onnx-diagnostic 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,5 +3,5 @@ Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.2.1"
6
+ __version__ = "0.2.2"
7
7
  __author__ = "Xavier Dupré"
@@ -4,6 +4,28 @@ import torch
4
4
  import transformers
5
5
  import transformers.cache_utils
6
6
 
7
+
8
def is_cache_dynamic_registered() -> bool:
    """
    Tells if class :class:`transformers.cache_utils.DynamicCache` can be
    serialized and deserialized with :mod:`torch.utils._pytree`.
    Only then, :func:`torch.export.export` can export a model
    taking such a cache as input.

    :return: True if flattening a small cache produces its tensors and
        unflattening them rebuilds a cache with the same number of layers
    """
    bsize, nheads, slen, dim = 2, 4, 3, 7
    cache = make_dynamic_cache(
        [
            (
                torch.randn(bsize, nheads, slen, dim),
                torch.randn(bsize, nheads, slen, dim),
            )
            for _ in range(2)
        ]
    )
    values, spec = torch.utils._pytree.tree_flatten(cache)
    # A class unknown to the pytree registry is treated as one opaque leaf:
    # flattening returns the cache object itself and unflattening hands the
    # very same object back, which would make the round-trip check below
    # succeed even though nothing is registered.
    if len(values) == 1 and values[0] is cache:
        return False
    cache2 = torch.utils._pytree.tree_unflatten(values, spec)
    return len(cache2.key_cache) == len(cache.key_cache)
27
+
28
+
7
29
  if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
8
30
 
9
31
  def make_dynamic_cache(
onnx_diagnostic/doc.py ADDED
@@ -0,0 +1,32 @@
1
def reset_torch_transformers(gallery_conf, fname):
    """
    Reset hook for :epkg:`sphinx-gallery`.

    Switches matplotlib to the ``ggplot`` style and resets torch dynamo.

    :param gallery_conf: sphinx-gallery configuration, not used
    :param fname: name of the example file, not used
    """
    import matplotlib.pyplot as pyplot
    import torch as torch_module

    pyplot.style.use("ggplot")
    torch_module._dynamo.reset()
8
+
9
+
10
def plot_legend(
    text: str, text_bottom: str = "", color: str = "green", fontsize: int = 35
) -> "matplotlib.axes.Axes":  # noqa: F821
    """
    Draws a figure with the word ``END`` at the top, *text* in a colored box
    in the middle and, optionally, *text_bottom* in the bottom right corner.
    Axes and grid are hidden.

    :param text: text displayed in the middle of the figure
    :param text_bottom: optional text displayed in the bottom right corner,
        nothing is drawn there if empty
    :param color: face color of the box surrounding *text*
    :param fontsize: font size of *text*
    :return: the matplotlib axes holding the drawing
    """
    import matplotlib.pyplot as plt

    axes = plt.figure().add_subplot()
    # Fixed coordinate system so the text positions below are stable.
    axes.axis([0, 5, 0, 5])
    axes.text(2.5, 4, "END", fontsize=50, horizontalalignment="center")

    box = {"facecolor": color, "alpha": 0.5, "pad": 10}
    axes.text(
        2.5,
        2.5,
        text,
        fontsize=fontsize,
        bbox=box,
        horizontalalignment="center",
        verticalalignment="center",
    )

    if text_bottom:
        axes.text(4.5, 0.5, text_bottom, fontsize=20, horizontalalignment="right")

    axes.grid(False)
    axes.set_axis_off()
    return axes
@@ -1090,3 +1090,7 @@ class ExtTestCase(unittest.TestCase):
1090
1090
  and not numpy.isnan(diff["rel"])
1091
1091
  and diff["rel"] <= rtol
1092
1092
  ), f"discrepancies in {test_name!r}, diff={string_diff(diff)}"
1093
+
1094
+ def _debug(self):
1095
+ "Tells if DEBUG=1 is set up."
1096
+ return os.environ.get("DEBUG") in BOOLEAN_VALUES
@@ -240,7 +240,16 @@ class InferenceSessionForNumpy(_InferenceSession):
240
240
 
241
241
  el_type = ortvalues[i].element_type()
242
242
  if el_type < onnx.TensorProto.BFLOAT16:
243
- res.append(np.from_dlpack(ortvalues[i]))
243
+ try:
244
+ a = np.from_dlpack(ortvalues[i])
245
+ except RuntimeError as e:
246
+ assert "ORT only supports contiguous tensor for now." in str(e), (
247
+ f"As it says, non-contiguous OrtValue are not supported "
248
+ f"though DLPack, i={i}, the error is different {e}"
249
+ )
250
+ # We make a copy in that case.
251
+ a = ortvalues[i].numpy()
252
+ res.append(a)
244
253
  continue
245
254
 
246
255
  # no easy conversion, let's use torch
@@ -430,6 +439,8 @@ class InferenceSessionForTorch(_InferenceSession):
430
439
  new_feeds = {}
431
440
  for k, v in feeds.items():
432
441
  assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
442
+ if not v.is_contiguous():
443
+ v = v.contiguous()
433
444
  new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), v.dtype == torch.bool)
434
445
  if self.nvtx:
435
446
  self.torch.cuda.nvtx.range_push("run_with_ort_values")
@@ -145,6 +145,11 @@ def _unregister(cls: type, verbose: int = 0):
145
145
  # torch >= 2.7
146
146
  torch.utils._pytree._deregister_pytree_node(cls)
147
147
  optree.unregister_pytree_node(cls, namespace="torch")
148
+ if cls in torch.utils._pytree.SUPPORTED_NODES:
149
+ import packaging.version as pv
150
+
151
+ if pv.Version(torch.__version__) < pv.Version("2.7.0"):
152
+ del torch.utils._pytree.SUPPORTED_NODES[cls]
148
153
  assert cls not in torch.utils._pytree.SUPPORTED_NODES, (
149
154
  f"{cls} was not successful unregistered "
150
155
  f"from torch.utils._pytree.SUPPORTED_NODES="
@@ -190,6 +195,7 @@ def bypass_export_some_errors(
190
195
  patch_torch: bool = True,
191
196
  patch_transformers: bool = False,
192
197
  catch_constraints: bool = True,
198
+ stop_if_static: bool = False,
193
199
  verbose: int = 0,
194
200
  patch: bool = True,
195
201
  ) -> Callable:
@@ -203,8 +209,12 @@ def bypass_export_some_errors(
203
209
  as a result, some dynamic dimension may turn into static ones,
204
210
  the environment variable ``SKIP_SOLVE_CONSTRAINTS=0``
205
211
  can be put to stop at that stage.
212
+ :param stop_if_static: see example :ref:`l-plot-export-locale-issue`,
213
+ to stop the export as soon as an issue is detected with dynamic shapes
214
+ and show a stack trace indicating the exact location of the issue
206
215
  :param patch: if False, disable all patches except the registration of
207
216
  serialization function
217
+ :param verbose: to show which patches is applied
208
218
 
209
219
  The list of available patches.
210
220
 
@@ -348,6 +358,18 @@ def bypass_export_some_errors(
348
358
  )
349
359
  )
350
360
 
361
+ if stop_if_static:
362
+ if verbose:
363
+ print(
364
+ "[bypass_export_some_errors] assert when a dynamic dimension turns static"
365
+ )
366
+
367
+ from torch.fx.experimental.symbolic_shapes import ShapeEnv
368
+ from .patches.patch_torch import patched_ShapeEnv
369
+
370
+ f_shape_env__set_replacement = ShapeEnv._set_replacement
371
+ ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
372
+
351
373
  ####################
352
374
  # patch transformers
353
375
  ####################
@@ -401,6 +423,12 @@ def bypass_export_some_errors(
401
423
  if verbose:
402
424
  print("[bypass_export_some_errors] restored pytorch functions")
403
425
 
426
+ if stop_if_static:
427
+ if verbose:
428
+ print("[bypass_export_some_errors] restored ShapeEnv._set_replacement")
429
+
430
+ ShapeEnv._set_replacement = f_shape_env__set_replacement
431
+
404
432
  if catch_constraints:
405
433
  # to catch or skip dynamic_shapes issues
406
434
  torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
@@ -146,3 +146,186 @@ def patched__broadcast_shapes(*_shapes):
146
146
  common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx])
147
147
 
148
148
  return common_shape
149
+
150
+
151
class patched_ShapeEnv:
    """
    Holds a replacement for
    ``torch.fx.experimental.symbolic_shapes.ShapeEnv._set_replacement``.

    ``bypass_export_some_errors(stop_if_static=True)`` assigns
    ``patched_ShapeEnv._set_replacement`` to ``ShapeEnv._set_replacement``
    and restores the original method afterwards. The patched method raises
    an ``AssertionError`` as soon as a symbol is specialized to a constant
    because its range was refined to a single value. The stack trace then
    shows where a dynamic dimension turned static.
    """

    def _set_replacement(
        self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str  # noqa: F821
    ) -> None:
        """
        Adds or updates a replacement for a symbol.
        Use this instead of `self.replacements[a] = tgt`.

        The code follows the structure of torch's
        ``ShapeEnv._set_replacement``. NOTE(review): it presumably mirrors one
        specific torch version; confirm against the installed torch since it
        relies on private attributes (``var_to_range``, ``var_to_range_sloc``,
        ``replacements_slocs``, ...).

        :param a: symbol to replace
        :param tgt: expression replacing ``a``
        :param msg: reason of the replacement, used for logging and in the
            assertion raised when ``msg == "range_refined_to_singleton"``
            and ``tgt`` is a constant
        """
        # Nothing to do if the same replacement is already registered.
        if tgt == self.replacements.get(a, None):
            return

        # A self-referencing replacement (a == f(a)) is ignored.
        if a in tgt.free_symbols:
            return

        import sympy
        from torch._logging import structured
        from torch.utils._traceback import CapturedTraceback
        from torch._logging import trace_structured
        from torch._guards import TracingContext
        from torch.utils._sympy.functions import FloorToInt, CeilToInt
        from torch.utils._sympy.solve import try_solve
        from torch.fx.experimental.symbolic_shapes import (
            _is_supported_equivalence,
            ValueRanges,
        )

        # Precondition: a == tgt
        assert isinstance(a, sympy.Symbol)

        if self.allow_complex_guards_as_runtime_asserts and not _is_supported_equivalence(tgt):
            # continuing leads to placeholder shapes
            # having complex expressions that we can't resolve
            return

        # Handles nested tensor symbolic variables which don't have
        # var_to_range bounds
        tgt_bound = None
        if a in self.var_to_range:
            src_bound = self.var_to_range[a]

            # First, refine the value range of a based on the computed value range
            # of tgt. This is always OK to do, even if we decide not to do the
            # substitution in the end. This might be a no-op, if a already has
            # a tighter bound
            tgt_bound = self.bound_sympy(tgt)
            self._update_var_to_range(a, tgt_bound)

            # Next, check if we can update the range of free symbols in tgt
            # based on the range in a. But only do it if:
            # - the source bound non-trivially improves over what we get out of
            # the existing bounds.
            # - the replacement is univariate and we can invert the tgt expression
            if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1:
                b = next(iter(tgt.free_symbols))
                # Try to invert the equality
                r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
                if r is not None:
                    self.log.debug(
                        "set_replacement: solve for %s in %s == %s gives %s",
                        b,
                        a,
                        tgt,
                        r,
                    )
                    # The solution here can be non-integral, for example, if
                    # we have s0 = 2*s1, then s1 = s0/2. What we would like
                    # to do is calculated the bounds in arbitrary precision,
                    # and then requantize the bound to integers when we are
                    # done.
                    rat_b_bound = self.bound_sympy(r[1])
                    b_bound = ValueRanges(
                        CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper)
                    )
                    self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a])
                    tgt_bound = self.bound_sympy(tgt)
                    assert tgt_bound.issubset(
                        src_bound
                    ), f"{tgt_bound=} not a subset of {src_bound=}"

            # TODO: Should we propagate size-like-ness?
            #
            # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1
            # to become size-like.
            #
            # Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T
            # propagate in this case, because what if u0 == 0, then u1 is negative
            # and clearly isn't a size. So, at minimum, any f(x) whose value
            # range isn't [0, inf] given x in [0, inf] cannot propagate
            # size-like-ness. But there are many situations where you could
            # imagine u1 is going to be size-like and actually you just didn't
            # have a refined enough value range on u0. Since even innocuous
            # looking arithmetic operations can destroy size-like-ness, it's
            # best to not propagate it at all and force the user to annotate it
            # as necessary.
            #
            # Compromise: we preserve size-like-ness only for exact equality
            # and nothing else.
            if a in self.size_like and isinstance(tgt, sympy.Symbol):
                self.size_like.add(tgt)
            elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like:
                self.size_like.add(a)

            # Now, decide if we will do the substitution.
            #
            # - If the source has a non-trivial range, only substitute if
            # we preserve this range. Note that we may have propagated
            # the src_range to free variables in tgt when tgt is univariate
            # and we could find an inverse, which helps us achieve this.
            # This ensures we never "forget" about user defined ranges,
            # even if they end up being defined on composite formulas
            # like s0 + s1.
            #
            # - If the variable is unbacked, only substitute if the substitution
            # would preserve the bounds also under size-like-ness conditions.

            if not tgt_bound.issubset(src_bound):
                self.log.debug(
                    "skipped set_replacement %s = %s (%s) [%s not subset of %s]",
                    a,
                    tgt,
                    msg,
                    tgt_bound,
                    src_bound,
                )
                return
            elif a in self.size_like:
                tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True)
                src_bound_so = self.bound_sympy(a, size_oblivious=True)
                if not tgt_bound_so.issubset(src_bound_so):
                    self.log.debug(
                        "skipped set_replacement %s = %s (%s) "
                        "[%s not subset of %s (size-oblivious conditions)]",
                        a,
                        tgt,
                        msg,
                        tgt_bound_so,
                        src_bound_so,
                    )
                    return

        if isinstance(tgt, (sympy.Integer, sympy.Float)):
            # specializing to a constant, which is likely unexpected (unless
            # you specified dynamic=True)

            user_tb = TracingContext.extract_stack()
            trace_structured(
                "symbolic_shape_specialization",
                metadata_fn=lambda: {
                    "symbol": repr(a),
                    "sources": [s.name() for s in self.var_to_sources.get(a, [])],
                    "value": repr(tgt),
                    "reason": msg,
                    "stack": structured.from_traceback(
                        CapturedTraceback.extract(skip=1).summary()
                    ),
                    "user_stack": (structured.from_traceback(user_tb) if user_tb else None),
                },
            )

            # if config.print_specializations:
            #    self.log.warning(
            #         "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt
            #    )
            #    self.log.debug("SPECIALIZATION", stack_info=True)
            # Check added for stop_if_static: fail immediately when a symbol's
            # range has been refined to a single value, i.e. a dimension
            # declared dynamic is about to become static.
            assert msg != "range_refined_to_singleton", (
                f"A dynamic dimension becomes static! "
                f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
            )
        # log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
        self.replacements[a] = tgt
        # NB: the replacement may get refined, but the user will find the
        # FIRST one most useful (TODO: Maybe we could consider tracking all of
        # them)
        if a not in self.replacements_slocs:
            self.replacements_slocs[a] = self._get_sloc()
        self._update_version_counter()

        # When specializing 'a == tgt', the equality should be also conveyed to
        # Z3, in case an expression uses 'a'.
        self._add_target_expr(sympy.Eq(a, tgt, evaluate=False))
@@ -102,9 +102,22 @@ class patched_DynamicCache:
102
102
  `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
103
103
  """
104
104
 
105
- _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits"]
105
+ _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
106
106
  _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
107
107
 
108
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
109
+ """Returns the sequence length of the cached states.
110
+ A layer index can be optionally passed."""
111
+ # TODO: deprecate this function in favor of `cache_position`
112
+ is_empty_layer = (
113
+ len(self.key_cache) == 0 # no cache in any layer
114
+ or len(self.key_cache)
115
+ <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
116
+ or self.key_cache[layer_idx].numel() == 0 # the layer has no cache
117
+ )
118
+ layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
119
+ return layer_seq_length
120
+
108
121
  def reorder_cache(self, beam_idx: torch.LongTensor):
109
122
  """Reorders the cache for beam search, given the selected beam indices."""
110
123
  for layer_idx in range(len(self.key_cache)):
@@ -1,96 +1,2 @@
1
- from typing import Any, Dict
2
- import torch
3
- import transformers
4
- from ..cache_helpers import make_dynamic_cache
5
-
6
-
7
- def get_tiny_llm(
8
- batch_size: int = 2,
9
- input_cache: bool = True,
10
- dynamic_rope: bool = False,
11
- **kwargs,
12
- ) -> Dict[str, Any]:
13
- """
14
- Gets a non initialized model.
15
-
16
- :param batch_size: batch size
17
- :param input_cache: generate data for this iteration with or without cache
18
- :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
19
- :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
20
- :return: dictionary
21
-
22
- See :ref:`l-plot-tiny-llm-export` for an example.
23
- """
24
- config = {
25
- "architectures": ["LlamaForCausalLM"],
26
- "bos_token_id": 1,
27
- "eos_token_id": 2,
28
- "hidden_act": "silu",
29
- "hidden_size": 192,
30
- "initializer_range": 0.02,
31
- "intermediate_size": 1024,
32
- "max_position_embeddings": 1024,
33
- "model_type": "llama",
34
- "num_attention_heads": 2,
35
- "num_hidden_layers": 1,
36
- "num_key_value_heads": 1,
37
- "pretraining_tp": 1,
38
- "rms_norm_eps": 1e-05,
39
- "rope_scaling": {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None,
40
- "tie_word_embeddings": False,
41
- "torch_dtype": "float32",
42
- "transformers_version": "4.31.0.dev0",
43
- "use_cache": True,
44
- "vocab_size": 32000,
45
- }
46
-
47
- config.update(**kwargs)
48
- conf = transformers.LlamaConfig(**config)
49
- model = transformers.LlamaForCausalLM(conf)
50
- model.eval()
51
-
52
- # now the inputs
53
- cache_last_dim = 96
54
- sequence_length = 30
55
- sequence_length2 = 3
56
- num_key_value_heads = 1
57
- max_token_id = config["vocab_size"] - 1
58
- n_layers = config["num_hidden_layers"]
59
-
60
- batch = torch.export.Dim("batch", min=1, max=1024)
61
- seq_length = torch.export.Dim("seq_length", min=1, max=4096)
62
- cache_length = torch.export.Dim("cache_length", min=1, max=4096)
63
-
64
- shapes = {
65
- "input_ids": {0: batch, 1: seq_length},
66
- "attention_mask": {
67
- 0: batch,
68
- 1: torch.export.Dim.DYNAMIC, # cache_length + seq_length
69
- },
70
- "past_key_values": [
71
- [{0: batch, 2: cache_length} for _ in range(n_layers)],
72
- [{0: batch, 2: cache_length} for _ in range(n_layers)],
73
- ],
74
- }
75
- inputs = dict(
76
- input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
77
- torch.int64
78
- ),
79
- attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
80
- torch.int64
81
- ),
82
- past_key_values=make_dynamic_cache(
83
- [
84
- (
85
- torch.randn(
86
- batch_size, num_key_value_heads, sequence_length, cache_last_dim
87
- ),
88
- torch.randn(
89
- batch_size, num_key_value_heads, sequence_length, cache_last_dim
90
- ),
91
- )
92
- for i in range(n_layers)
93
- ]
94
- ),
95
- )
96
- return dict(inputs=inputs, model=model, dynamic_shapes=shapes)
1
+ from .untrained.llm_phi2 import get_phi2
2
+ from .untrained.llm_tiny_llm import get_tiny_llm
File without changes
@@ -0,0 +1,109 @@
1
+ from typing import Any, Dict
2
+ import torch
3
+ import transformers
4
+ from ...cache_helpers import make_dynamic_cache
5
+
6
+
7
def get_phi2(
    batch_size: int = 1,
    sequence_length: int = 30,
    sequence_length2: int = 3,
    dynamic_rope: bool = False,
    **kwargs,
) -> Dict[str, Any]:
    """
    Gets a non initialized model
    similar to `microsoft/phi-2 <https://huggingface.co/microsoft/phi-2>`_

    :param batch_size: batch size
    :param sequence_length: sequence length, length of the cache given as input
    :param sequence_length2: new sequence length, number of new tokens
    :param dynamic_rope: use dynamic rope (see :class:`transformers.PhiConfig`)
    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
    :return: dictionary with keys ``inputs``, ``model``, ``dynamic_shapes``

    See :ref:`l-plot-tiny-llm-export-patched` for an example with a similar model.
    """
    config = {
        "_name_or_path": "microsoft/phi-2",
        "architectures": ["PhiForCausalLM"],
        "attention_dropout": 0.0,
        "bos_token_id": 50256,
        "embd_pdrop": 0.0,
        "eos_token_id": 50256,
        "hidden_act": "gelu_new",
        "hidden_size": 2560,
        "initializer_range": 0.02,
        "intermediate_size": 10240,
        "layer_norm_eps": 1e-05,
        "max_position_embeddings": 2048,
        "model_type": "phi",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 32,
        "partial_rotary_factor": 0.4,
        "qk_layernorm": False,
        "resid_pdrop": 0.1,
        "rope_scaling": {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None,
        "rope_theta": 10000.0,
        "tie_word_embeddings": False,
        "torch_dtype": "float16",
        "transformers_version": "4.37.0",
        "use_cache": True,
        "vocab_size": 51200,
    }
    config.update(**kwargs)
    conf = transformers.PhiConfig(**config)
    model = transformers.PhiForCausalLM(conf)
    model.eval()

    # now the inputs
    # The last dimension of the cache is the head dimension. It is derived
    # from the configuration (2560 // 32 = 80 by default) so that overriding
    # hidden_size or num_attention_heads through kwargs keeps the cache
    # consistent with the model.
    cache_last_dim = conf.hidden_size // conf.num_attention_heads
    max_token_id = config["vocab_size"] - 1
    n_layers = config["num_hidden_layers"]
    num_key_value_heads = config["num_key_value_heads"]

    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
    cache_length = torch.export.Dim("cache_length", min=1, max=4096)

    shapes = {
        "input_ids": {0: batch, 1: seq_length},
        "position_ids": {
            0: batch,
            1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
        },
        "attention_mask": {
            0: batch,
            1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
        },
        "past_key_values": [
            [{0: batch, 2: cache_length} for _ in range(n_layers)],
            [{0: batch, 2: cache_length} for _ in range(n_layers)],
        ],
    }
    inputs = dict(
        input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
            torch.int64
        ),
        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
            torch.int64
        ),
        # expand returns a non-contiguous view shared across the batch
        position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
        .to(torch.int64)
        .expand((batch_size, -1)),
        past_key_values=make_dynamic_cache(
            [
                (
                    torch.randn(
                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
                    ),
                    torch.randn(
                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
                    ),
                )
                for _ in range(n_layers)
            ]
        ),
    )
    return dict(inputs=inputs, model=model, dynamic_shapes=shapes)
@@ -0,0 +1,104 @@
1
+ from typing import Any, Dict
2
+ import torch
3
+ import transformers
4
+ from ...cache_helpers import make_dynamic_cache
5
+
6
+
7
def get_tiny_llm(
    batch_size: int = 2,
    sequence_length: int = 30,
    sequence_length2: int = 3,
    dynamic_rope: bool = False,
    **kwargs,
) -> Dict[str, Any]:
    """
    Gets a non initialized model
    similar to `arnir0/Tiny-LLM <https://huggingface.co/arnir0/Tiny-LLM>`_

    :param batch_size: batch size
    :param sequence_length: sequence length, length of the cache given as input
    :param sequence_length2: new sequence length, number of new tokens
    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
    :return: dictionary with keys ``inputs``, ``model``, ``dynamic_shapes``

    See :ref:`l-plot-tiny-llm-export` or :ref:`l-plot-tiny-llm-export-patched` for examples.
    """
    config = {
        "architectures": ["LlamaForCausalLM"],
        "bos_token_id": 1,
        "eos_token_id": 2,
        "hidden_act": "silu",
        "hidden_size": 192,
        "initializer_range": 0.02,
        "intermediate_size": 1024,
        "max_position_embeddings": 1024,
        "model_type": "llama",
        "num_attention_heads": 2,
        "num_hidden_layers": 1,
        "num_key_value_heads": 1,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None,
        "tie_word_embeddings": False,
        "torch_dtype": "float32",
        "transformers_version": "4.31.0.dev0",
        "use_cache": True,
        "vocab_size": 32000,
    }

    config.update(**kwargs)
    conf = transformers.LlamaConfig(**config)
    model = transformers.LlamaForCausalLM(conf)
    model.eval()

    # now the inputs
    # The last dimension of the cache is the head dimension. It is derived
    # from the configuration (192 // 2 = 96 by default) so that overriding
    # hidden_size, num_attention_heads or head_dim through kwargs keeps the
    # cache consistent with the model. Older LlamaConfig versions have no
    # head_dim attribute, hence the fallback.
    cache_last_dim = getattr(conf, "head_dim", None) or (
        conf.hidden_size // conf.num_attention_heads
    )
    max_token_id = config["vocab_size"] - 1
    n_layers = config["num_hidden_layers"]
    num_key_value_heads = config["num_key_value_heads"]

    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
    cache_length = torch.export.Dim("cache_length", min=1, max=4096)

    shapes = {
        "input_ids": {0: batch, 1: seq_length},
        "attention_mask": {
            0: batch,
            1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
        },
        "position_ids": {
            0: batch,
            1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
        },
        "past_key_values": [
            [{0: batch, 2: cache_length} for _ in range(n_layers)],
            [{0: batch, 2: cache_length} for _ in range(n_layers)],
        ],
    }
    inputs = dict(
        input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
            torch.int64
        ),
        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
            torch.int64
        ),
        # expand returns a non-contiguous view shared across the batch
        position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
        .to(torch.int64)
        .expand((batch_size, -1)),
        past_key_values=make_dynamic_cache(
            [
                (
                    torch.randn(
                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
                    ),
                    torch.randn(
                        batch_size, num_key_value_heads, sequence_length, cache_last_dim
                    ),
                )
                for _ in range(n_layers)
            ]
        ),
    )
    return dict(inputs=inputs, model=model, dynamic_shapes=shapes)
@@ -13,6 +13,10 @@ def is_torchdynamo_exporting() -> bool:
13
13
  """Tells if torch is exporting a model."""
14
14
  import torch
15
15
 
16
+ if not hasattr(torch.compiler, "is_exporting"):
17
+ # torch.compiler.is_exporting requires torch>=2.7
18
+ return False
19
+
16
20
  try:
17
21
  return torch.compiler.is_exporting()
18
22
  except Exception:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-diagnostic
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: Investigate ONNX models
5
5
  Home-page: https://github.com/sdpython/onnx-diagnostic
6
6
  Author: Xavier Dupré
@@ -87,6 +87,8 @@ Enlightening Examples
87
87
 
88
88
  * `Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints
89
89
  <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_shapes_auto.html>`_
90
+ * `Find and fix an export issue due to dynamic shapes
91
+ <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_locate_issue.html>`_
90
92
  * `Export with DynamicCache and dynamic shapes
91
93
  <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_cache.html>`_
92
94
  * `Steal method forward to guess the dynamic shapes (with Tiny-LLM)
@@ -1,11 +1,12 @@
1
- onnx_diagnostic/__init__.py,sha256=_Od-M8LjUrc4uDtezDVkld8vhaBzkJ0r5WdsOWcCUG0,164
1
+ onnx_diagnostic/__init__.py,sha256=G8fJzYgyo9GW9yrpyPU-kbp-3X-LMyXMXxZiqPLDLEg,164
2
2
  onnx_diagnostic/args.py,sha256=7pTrw1A1wuNvLdXJdpda5spPI140FylwSmxxZTGu_4E,4389
3
- onnx_diagnostic/cache_helpers.py,sha256=3wMsOseENzbsjzo7yhgt44cH0S1t-n1e-_dj-os3uGc,2698
4
- onnx_diagnostic/ext_test_case.py,sha256=XvrgU7_TBo2MB3ilhOawatmSwvtHeGfYnl_a0t69QAo,37222
3
+ onnx_diagnostic/cache_helpers.py,sha256=rtiJZDzuIXR88p3e9ALeMhOiJ66SONM7oUO1t5-N7FM,3374
4
+ onnx_diagnostic/doc.py,sha256=qKBHk5pFGK0WwecMqAB5fhAdTRF5g9-rsMWLsyczYFc,913
5
+ onnx_diagnostic/ext_test_case.py,sha256=dlpwu4CunVA35W1VycCYVreEIOZyeYQpvU8fHKfurNE,37340
5
6
  onnx_diagnostic/helpers.py,sha256=qOArpaRV2W3LcD48Ewi6QsBHWZcSpvaHaA0f9bOMGWA,61494
6
7
  onnx_diagnostic/onnx_tools.py,sha256=avk1GkA8QRPb_a6JEmN55YfbefxL4wCBRKZMNMs_VuI,8304
7
- onnx_diagnostic/ort_session.py,sha256=hcXwSOLLosHzmQYQRhzFVd5kHbj7frymZiw_WTXSg44,26221
8
- onnx_diagnostic/torch_test_helper.py,sha256=ZRKZDzuv_IZogt79ubJpbD6YDxRX52BX74er7OI0VOQ,9336
8
+ onnx_diagnostic/ort_session.py,sha256=z_uBekAkQMHbt1XdxQUM3HSA9SPdrkoEJqETe-yDdd0,26746
9
+ onnx_diagnostic/torch_test_helper.py,sha256=f6cbHwEexsY9-wwTDu9_y5_-57QzUW8UrC8YWIa4ID4,9468
9
10
  onnx_diagnostic/export/__init__.py,sha256=nngeNcP71r_Mmvt2U-qbG-t4-0Kr623zcUhHgNA3IiA,40
10
11
  onnx_diagnostic/export/dynamic_shapes.py,sha256=9lVZEq9fY_97HGeXVbZlStYXJQ3bs_T5jVSp3iu46Kc,15662
11
12
  onnx_diagnostic/reference/__init__.py,sha256=0Al5kins8LlBICAsszEZ59thMwmaARBO6fMwtYpKOOQ,98
@@ -41,15 +42,18 @@ onnx_diagnostic/reference/ops/op_slice.py,sha256=yRxfYBs8b7QezyyG9JHCD8MIJHij2qR
41
42
  onnx_diagnostic/reference/ops/op_transpose_cast.py,sha256=ifef74rvh0Yvq1Zx51B4mfnISbxV9uRg9DFjkdL1_68,361
42
43
  onnx_diagnostic/reference/ops/op_tri_matrix.py,sha256=Yn2gxAyygcwtF5Hjau9ihXDAzul0BAkdqVimVahtFBU,519
43
44
  onnx_diagnostic/torch_export_patches/__init__.py,sha256=RZzVGgouNNXaPirQJYQThiq5wrliwH4unVszeU18oJw,116
44
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=B7OSjS9dbwa5NNh7kpPexL1eb03vSD7L7TgaMfiH60Q,16223
45
+ onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=o_Q6DHWUL9_zwHNwGHnzkBvhdJd3LV968-6mVKEr3Ko,17441
45
46
  onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=0DAa1wP4kFRoru2J3i2r1sNNxUdpF0Op5noTcmLWRow,4087
46
47
  onnx_diagnostic/torch_export_patches/patches/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
47
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py,sha256=uUraqyHGnj9QykRtA_CFjlTGmk8NsmbY2Ufu7ITPSJQ,5783
48
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=okNlYyWxM4P6TDxZv-_vEKI8nJSx4tnwXZRmcO1ZnPg,21118
48
+ onnx_diagnostic/torch_export_patches/patches/patch_torch.py,sha256=SPvwag9oelMDmK5St0c1DXOKPO7_5iA9Nck8P9JCGk0,14111
49
+ onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=RSoIPTbiZj9IV7Gd9NjAnJGhLHveyB2Tw1kKpATw3f8,21814
49
50
  onnx_diagnostic/torch_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
50
- onnx_diagnostic/torch_models/llms.py,sha256=pEMzD5GwmQ-X2WmYHw1OnNT9RKMRwJdGVyoGxWEbaio,3111
51
- onnx_diagnostic-0.2.1.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
52
- onnx_diagnostic-0.2.1.dist-info/METADATA,sha256=_mDC2LPuSjXECIosSLHGsPLqjWlUGmQH62_-xLKOT0w,4696
53
- onnx_diagnostic-0.2.1.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
54
- onnx_diagnostic-0.2.1.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
55
- onnx_diagnostic-0.2.1.dist-info/RECORD,,
51
+ onnx_diagnostic/torch_models/llms.py,sha256=soyg4yC87ptGoeulJhKqw5opGmuLvH1pn_ZDXZ4Jr8E,90
52
+ onnx_diagnostic/torch_models/untrained/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
53
+ onnx_diagnostic/torch_models/untrained/llm_phi2.py,sha256=wfc_sp7DDbVPNMPtrYrKSndOPSuL3YLd_vTgjjaeGAc,3720
54
+ onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py,sha256=9pculGIzkhjKSNwna6nbpUoLBpm2JEWhYvHeq_TFEts,3525
55
+ onnx_diagnostic-0.2.2.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
56
+ onnx_diagnostic-0.2.2.dist-info/METADATA,sha256=ea8SGO-nBmEa4OHFkbMDY3VhqGGdvZn3iS-l-YXz97A,4851
57
+ onnx_diagnostic-0.2.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
58
+ onnx_diagnostic-0.2.2.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
59
+ onnx_diagnostic-0.2.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.0.2)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5