onnx-diagnostic 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,5 +3,5 @@ Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.4.3"
6
+ __version__ = "0.4.4"
7
7
  __author__ = "Xavier Dupré"
@@ -336,6 +336,10 @@ def get_parser_validate() -> ArgumentParser:
336
336
  help="drops the following inputs names, it should be a list "
337
337
  "with comma separated values",
338
338
  )
339
+ parser.add_argument(
340
+ "--subfolder",
341
+ help="subfolder where to find the model and the configuration",
342
+ )
339
343
  parser.add_argument(
340
344
  "--ortfusiontype",
341
345
  required=False,
@@ -413,6 +417,7 @@ def _cmd_validate(argv: List[Any]):
413
417
  ortfusiontype=args.ortfusiontype,
414
418
  input_options=args.iop,
415
419
  model_options=args.mop,
420
+ subfolder=args.subfolder,
416
421
  )
417
422
  print("")
418
423
  print("-- summary --")
@@ -6,6 +6,7 @@ from . import (
6
6
  image_classification,
7
7
  image_text_to_text,
8
8
  mixture_of_expert,
9
+ object_detection,
9
10
  sentence_similarity,
10
11
  text_classification,
11
12
  text_generation,
@@ -20,6 +21,7 @@ __TASKS__ = [
20
21
  image_classification,
21
22
  image_text_to_text,
22
23
  mixture_of_expert,
24
+ object_detection,
23
25
  sentence_similarity,
24
26
  text_classification,
25
27
  text_generation,
@@ -7,6 +7,13 @@ __TASK__ = "image-classification"
7
7
 
8
8
  def reduce_model_config(config: Any) -> Dict[str, Any]:
9
9
  """Reduces a model size."""
10
+ if (
11
+ hasattr(config, "model_type")
12
+ and config.model_type == "timm_wrapper"
13
+ and not hasattr(config, "num_hidden_layers")
14
+ ):
15
+ # We cannot reduce.
16
+ return {}
10
17
  check_hasattr(config, ("num_hidden_layers", "hidden_sizes"))
11
18
  kwargs = dict(
12
19
  num_hidden_layers=(
@@ -82,6 +89,20 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
82
89
  If the configuration is None, the function selects typical dimensions.
83
90
  """
84
91
  if config is not None:
92
+ if (
93
+ hasattr(config, "model_type")
94
+ and config.model_type == "timm_wrapper"
95
+ and not hasattr(config, "num_hidden_layers")
96
+ ):
97
+ input_size = config.pretrained_cfg["input_size"]
98
+ kwargs = dict(
99
+ batch_size=2,
100
+ input_width=input_size[-2],
101
+ input_height=input_size[-1],
102
+ input_channels=input_size[-3],
103
+ )
104
+ return kwargs, get_inputs
105
+
85
106
  check_hasattr(config, ("image_size", "architectures"), "num_channels")
86
107
  if config is not None:
87
108
  if hasattr(config, "image_size"):
@@ -0,0 +1,123 @@
1
+ from typing import Any, Callable, Dict, Optional, Tuple
2
+ import torch
3
+ from ..helpers.config_helper import update_config, check_hasattr
4
+
5
+ __TASK__ = "object-detection"
6
+
7
+
8
+ def reduce_model_config(config: Any) -> Dict[str, Any]:
9
+ """Reduces a model size."""
10
+ check_hasattr(config, ("num_hidden_layers", "hidden_sizes"))
11
+ kwargs = dict(
12
+ num_hidden_layers=(
13
+ min(config.num_hidden_layers, 2)
14
+ if hasattr(config, "num_hidden_layers")
15
+ else len(config.hidden_sizes)
16
+ )
17
+ )
18
+ update_config(config, kwargs)
19
+ return kwargs
20
+
21
+
22
+ def get_inputs(
23
+ model: torch.nn.Module,
24
+ config: Optional[Any],
25
+ input_width: int,
26
+ input_height: int,
27
+ input_channels: int,
28
+ batch_size: int = 2,
29
+ dynamic_rope: bool = False,
30
+ add_second_input: bool = False,
31
+ **kwargs, # unused
32
+ ):
33
+ """
34
+ Generates inputs for task ``object-detection``.
35
+
36
+ :param model: model to get the missing information
37
+ :param config: configuration used to generate the model
38
+ :param batch_size: batch size
39
+ :param input_channels: input channel
40
+ :param input_width: input width
41
+ :param input_height: input height
42
+ :return: dictionary
43
+ """
44
+ assert isinstance(
45
+ input_width, int
46
+ ), f"Unexpected type for input_width {type(input_width)}{config}"
47
+ assert isinstance(
48
+ input_height, int
49
+ ), f"Unexpected type for input_height {type(input_height)}{config}"
50
+
51
+ shapes = {
52
+ "pixel_values": {
53
+ 0: torch.export.Dim("batch", min=1, max=1024),
54
+ 2: "width",
55
+ 3: "height",
56
+ }
57
+ }
58
+ inputs = dict(
59
+ pixel_values=torch.randn(batch_size, input_channels, input_width, input_height).clamp(
60
+ -1, 1
61
+ ),
62
+ )
63
+ res = dict(inputs=inputs, dynamic_shapes=shapes)
64
+ if add_second_input:
65
+ res["inputs2"] = get_inputs(
66
+ model=model,
67
+ config=config,
68
+ input_width=input_width + 1,
69
+ input_height=input_height + 1,
70
+ input_channels=input_channels,
71
+ batch_size=batch_size + 1,
72
+ dynamic_rope=dynamic_rope,
73
+ **kwargs,
74
+ )["inputs"]
75
+ return res
76
+
77
+
78
+ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
79
+ """
80
+ Inputs kwargs.
81
+
82
+ If the configuration is None, the function selects typical dimensions.
83
+ """
84
+ if config is not None:
85
+ if (
86
+ hasattr(config, "model_type")
87
+ and config.model_type == "timm_wrapper"
88
+ and not hasattr(config, "num_hidden_layers")
89
+ ):
90
+ input_size = config.pretrained_cfg["input_size"]
91
+ kwargs = dict(
92
+ batch_size=2,
93
+ input_width=input_size[-2],
94
+ input_height=input_size[-1],
95
+ input_channels=input_size[-3],
96
+ )
97
+ return kwargs, get_inputs
98
+
99
+ check_hasattr(config, ("image_size", "architectures"), "num_channels")
100
+ if config is not None:
101
+ if hasattr(config, "image_size"):
102
+ image_size = config.image_size
103
+ else:
104
+ assert config.architectures, f"empty architecture in {config}"
105
+ from ..torch_models.hghub.hub_api import get_architecture_default_values
106
+
107
+ default_values = get_architecture_default_values(config.architectures[0])
108
+ image_size = default_values["image_size"]
109
+ if config is None or isinstance(image_size, int):
110
+ kwargs = dict(
111
+ batch_size=2,
112
+ input_width=224 if config is None else image_size,
113
+ input_height=224 if config is None else image_size,
114
+ input_channels=3 if config is None else config.num_channels,
115
+ )
116
+ else:
117
+ kwargs = dict(
118
+ batch_size=2,
119
+ input_width=config.image_size[0],
120
+ input_height=config.image_size[1],
121
+ input_channels=config.num_channels,
122
+ )
123
+ return kwargs, get_inputs
@@ -19,12 +19,11 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
19
19
  ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),
20
20
  "num_hidden_layers",
21
21
  ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
22
- "intermediate_size",
23
22
  "hidden_size",
24
23
  "vocab_size",
25
24
  )
26
25
  if config.__class__.__name__ == "FalconMambaConfig":
27
- check_hasattr(config, "conv_kernel", "state_size") # 4 and 8
26
+ check_hasattr(config, "conv_kernel", "state_size", "intermediate_size") # 4 and 8
28
27
  kwargs = dict(
29
28
  num_hidden_layers=min(config.num_hidden_layers, 2),
30
29
  intermediate_size=256 if config is None else min(512, config.intermediate_size),
@@ -44,17 +43,18 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
44
43
  if hasattr(config, "num_key_value_heads")
45
44
  else config.num_attention_heads
46
45
  ),
47
- intermediate_size=(
48
- min(config.intermediate_size, 24576 // 4)
49
- if config.intermediate_size % 4 == 0
50
- else config.intermediate_size
51
- ),
52
46
  hidden_size=(
53
47
  min(config.hidden_size, 3072 // 4)
54
48
  if config.hidden_size % 4 == 0
55
49
  else config.hidden_size
56
50
  ),
57
51
  )
52
+ if config is None or hasattr(config, "intermediate_size"):
53
+ kwargs["intermediate_size"] = (
54
+ min(config.intermediate_size, 24576 // 4)
55
+ if config.intermediate_size % 4 == 0
56
+ else config.intermediate_size
57
+ )
58
58
  update_config(config, kwargs)
59
59
  return kwargs
60
60
 
@@ -228,11 +228,10 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
228
228
  "vocab_size",
229
229
  ("num_attention_heads", "use_mambapy"),
230
230
  ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
231
- "intermediate_size",
232
231
  "hidden_size",
233
232
  )
234
233
  if config.__class__.__name__ == "FalconMambaConfig":
235
- check_hasattr(config, "conv_kernel", "state_size") # 4 and 8
234
+ check_hasattr(config, "conv_kernel", "state_size", "intermediate_size") # 4 and 8
236
235
  kwargs = dict(
237
236
  batch_size=2,
238
237
  sequence_length=30,
@@ -263,7 +262,11 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
263
262
  if config is None
264
263
  else _pick(config, "num_key_value_heads", "num_attention_heads")
265
264
  ),
266
- intermediate_size=1024 if config is None else config.intermediate_size,
267
265
  hidden_size=512 if config is None else config.hidden_size,
268
266
  )
267
+ if config is None or hasattr(config, "intermediate_size"):
268
+ kwargs["intermediate_size"] = (
269
+ 1024 if config is None else config.intermediate_size
270
+ )
271
+
269
272
  return kwargs, get_inputs
@@ -6,3 +6,15 @@ from .onnx_export_errors import (
6
6
 
7
7
  # bypass_export_some_errors is the first name given to the patches.
8
8
  bypass_export_some_errors = torch_export_patches # type: ignore
9
+
10
+
11
+ def register_flattening_functions(verbose: int = 0):
12
+ """
13
+ Registers functions to serialize and deserialize caches or other classes
14
+ implemented in :epkg:`transformers` and used as inputs.
15
+ This is needed whenever a model must be exported through
16
+ :func:`torch.export.export`.
17
+ """
18
+ from .onnx_export_serialization import _register_cache_serialization
19
+
20
+ return _register_cache_serialization(verbose=verbose)
@@ -0,0 +1,304 @@
1
+ import ast
2
+ import inspect
3
+ import types
4
+ import textwrap
5
+ from typing import Callable, Dict
6
+ import torch
7
+
8
+ NODE_TYPES = tuple(
9
+ getattr(ast, k)
10
+ for k in dir(ast)
11
+ if "A" <= k[0] <= "Z" and isinstance(getattr(ast, k), type)
12
+ )
13
+
14
+
15
+ def _settl(node, lineno, level=0):
16
+ if isinstance(node, (str, int, float)):
17
+ return node
18
+ if isinstance(node, list):
19
+ for n in node:
20
+ _settl(n, lineno, level=level + 1)
21
+ return node
22
+ if isinstance(node, NODE_TYPES):
23
+ if not hasattr(node, "lineno") or node.lineno is None:
24
+ node.lineno = lineno
25
+ for k in dir(node):
26
+ if k in {"s", "n"}:
27
+ continue
28
+ if k[0] == "_":
29
+ continue
30
+ v = getattr(node, k)
31
+ _settl(v, max(lineno, node.lineno), level=level + 1)
32
+ return node
33
+
34
+
35
+ class RewriteControlFlow(ast.NodeTransformer):
36
+ def __init__(self, wrapper_name):
37
+ self.wrapper_name = wrapper_name
38
+ self.counter = 0
39
+ self.current_func_args = None
40
+
41
+ def visit_FunctionDef(self, node):
42
+ # Capture argument names for branch functions
43
+ old_args = self.current_func_args
44
+ self.current_func_args = [arg.arg for arg in node.args.args]
45
+ node.body = [self.visit(n) for n in node.body]
46
+ self.current_func_args = old_args
47
+ return node
48
+
49
+ def visit_If(self, node):
50
+ # First recurse into subnodes
51
+ node = self.generic_visit(node)
52
+ test_node = node.test
53
+
54
+ # Case 1: simple assignment in both branches
55
+ if (
56
+ len(node.body) == 1
57
+ and isinstance(node.body[0], ast.Assign)
58
+ and len(node.orelse) == 1
59
+ and isinstance(node.orelse[0], ast.Assign)
60
+ and self.current_func_args is not None
61
+ ):
62
+ then_assign = node.body[0]
63
+ else_assign = node.orelse[0]
64
+ tgt = then_assign.targets[0]
65
+ if (
66
+ isinstance(tgt, ast.Name)
67
+ and isinstance(else_assign.targets[0], ast.Name)
68
+ and tgt.id == else_assign.targets[0].id
69
+ ):
70
+ self.counter += 1
71
+ then_name = f"{self.wrapper_name}_then_{self.counter}"
72
+ else_name = f"{self.wrapper_name}_else_{self.counter}"
73
+ then_expr = then_assign.value
74
+ else_expr = else_assign.value
75
+ # extract free variables
76
+ then_vars = sorted(
77
+ {
78
+ n.id
79
+ for n in ast.walk(then_expr)
80
+ if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
81
+ }
82
+ )
83
+ else_vars = sorted(
84
+ {
85
+ n.id
86
+ for n in ast.walk(else_expr)
87
+ if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
88
+ }
89
+ )
90
+ # build local funcs
91
+ then_args = [ast.arg(arg=v, annotation=None) for v in then_vars]
92
+ then_def = ast.FunctionDef(
93
+ name=then_name,
94
+ args=ast.arguments(
95
+ posonlyargs=[],
96
+ args=then_args,
97
+ kwonlyargs=[],
98
+ kw_defaults=[],
99
+ defaults=[],
100
+ ),
101
+ body=[ast.Return(then_expr)],
102
+ decorator_list=[],
103
+ returns=None,
104
+ )
105
+ else_args = [ast.arg(arg=v, annotation=None) for v in else_vars]
106
+ else_def = ast.FunctionDef(
107
+ name=else_name,
108
+ args=ast.arguments(
109
+ posonlyargs=[],
110
+ args=else_args,
111
+ kwonlyargs=[],
112
+ kw_defaults=[],
113
+ defaults=[],
114
+ ),
115
+ body=[ast.Return(else_expr)],
116
+ decorator_list=[],
117
+ returns=None,
118
+ )
119
+ # fix locations
120
+ for n in (then_def, else_def):
121
+ ast.copy_location(n, node)
122
+ ast.fix_missing_locations(n)
123
+ assert hasattr(n, "lineno")
124
+ # wrapper call and assignment
125
+ then_args_tuple = ast.Tuple(
126
+ [ast.Name(id=v, ctx=ast.Load()) for v in then_vars],
127
+ ctx=ast.Load(),
128
+ )
129
+ else_args_tuple = ast.Tuple(
130
+ [ast.Name(id=v, ctx=ast.Load()) for v in else_vars],
131
+ ctx=ast.Load(),
132
+ )
133
+ call = ast.Call(
134
+ func=ast.Name(id=self.wrapper_name, ctx=ast.Load()),
135
+ args=[
136
+ test_node,
137
+ ast.Name(id=then_name, ctx=ast.Load()),
138
+ ast.Name(id=else_name, ctx=ast.Load()),
139
+ then_args_tuple,
140
+ else_args_tuple,
141
+ ],
142
+ keywords=[],
143
+ )
144
+ assign = ast.Assign(targets=[tgt], value=call)
145
+ ast.copy_location(assign, node)
146
+ ast.fix_missing_locations(assign)
147
+ return [then_def, else_def, assign]
148
+
149
+ # Case 2: simple return in both branches
150
+ if (
151
+ len(node.body) == 1
152
+ and isinstance(node.body[0], ast.Return)
153
+ and len(node.orelse) == 1
154
+ and isinstance(node.orelse[0], ast.Return)
155
+ and self.current_func_args is not None
156
+ ):
157
+ then_ret = node.body[0]
158
+ else_ret = node.orelse[0]
159
+ then_expr = then_ret.value
160
+ else_expr = else_ret.value
161
+ self.counter += 1
162
+ then_name = f"{self.wrapper_name}_then_{self.counter}"
163
+ else_name = f"{self.wrapper_name}_else_{self.counter}"
164
+ # extract free variables
165
+ then_vars = sorted(
166
+ {
167
+ n.id
168
+ for n in ast.walk(then_expr)
169
+ if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
170
+ }
171
+ )
172
+ else_vars = sorted(
173
+ {
174
+ n.id
175
+ for n in ast.walk(else_expr)
176
+ if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
177
+ }
178
+ )
179
+
180
+ then_else_vars = set(_ for _ in [*then_vars, *else_vars] if _ != "torch")
181
+
182
+ # build local funcs
183
+ then_args = [ast.arg(arg=v, annotation=None) for v in then_else_vars]
184
+ then_def = ast.FunctionDef(
185
+ name=then_name,
186
+ args=ast.arguments(
187
+ posonlyargs=[],
188
+ args=then_args,
189
+ kwonlyargs=[],
190
+ kw_defaults=[],
191
+ defaults=[],
192
+ ),
193
+ body=[ast.Return(then_expr)],
194
+ decorator_list=[],
195
+ returns=None,
196
+ )
197
+ else_args = [ast.arg(arg=v, annotation=None) for v in then_else_vars]
198
+ else_def = ast.FunctionDef(
199
+ name=else_name,
200
+ args=ast.arguments(
201
+ posonlyargs=[],
202
+ args=else_args,
203
+ kwonlyargs=[],
204
+ kw_defaults=[],
205
+ defaults=[],
206
+ ),
207
+ body=[ast.Return(else_expr)],
208
+ decorator_list=[],
209
+ returns=None,
210
+ )
211
+ for n in (then_def, else_def):
212
+ ast.copy_location(n, node)
213
+ ast.fix_missing_locations(n)
214
+ # wrapper call and return
215
+ then_else_args_list = ast.List(
216
+ [ast.Name(id=v, ctx=ast.Load()) for v in then_else_vars],
217
+ ctx=ast.Load(),
218
+ )
219
+
220
+ call = ast.Call(
221
+ func=ast.Name(id=self.wrapper_name, ctx=ast.Load()),
222
+ args=[
223
+ test_node,
224
+ ast.Name(id=then_name, ctx=ast.Load()),
225
+ ast.Name(id=else_name, ctx=ast.Load()),
226
+ then_else_args_list,
227
+ ],
228
+ keywords=[],
229
+ )
230
+ ret = ast.Return(call)
231
+ ast.copy_location(ret, node)
232
+ ast.fix_missing_locations(ret)
233
+ return [then_def, else_def, ret]
234
+ return node
235
+
236
+ def generic_visit(self, node):
237
+ return super().generic_visit(node)
238
+
239
+
240
+ class RewrittenMethod:
241
+ """
242
+ Stores a rewritten method using
243
+ :func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method`.
244
+
245
+ :param tree: ast tree
246
+ :param func: callable compiled from the tree
247
+ """
248
+
249
+ def __init__(self, tree, func):
250
+ self.tree = tree
251
+ self.func = func
252
+
253
+ @property
254
+ def code(self) -> str:
255
+ """Returns the source."""
256
+ return ast.unparse(self.tree)
257
+
258
+ def __repr__(self):
259
+ "usual"
260
+ return f"{self.__class__.__name__}({self.func})"
261
+
262
+
263
+ def transform_method(func: Callable, if_name="torch_cond") -> RewrittenMethod:
264
+ """
265
+ Returns a new function based on `func` where every test (if)
266
+ is replaced by a call to :func:`torch.cond`.
267
+
268
+ :param func: method or function to rewrite
269
+ :param if_name: function calling the test
270
+ :return: rewritten method
271
+ """
272
+ # Retrieve source of the function
273
+ src = inspect.getsource(func)
274
+ # Parse into AST
275
+ tree = ast.parse(textwrap.dedent(src))
276
+ # Apply transformation
277
+ transformer = RewriteControlFlow(if_name)
278
+ new_tree = transformer.visit(tree)
279
+ ast.fix_missing_locations(new_tree)
280
+ _settl(new_tree, 0)
281
+ try:
282
+ mod = compile(new_tree, filename="<ast>", mode="exec")
283
+ except TypeError as e:
284
+ if 'required field "lineno" missing from stmt' in str(e):
285
+ # Could not find a way to avoid compiling a string.
286
+ # The error message still pops up without indicating which node is not
287
+ # properly set.
288
+ code = ast.unparse(new_tree)
289
+ mod = compile(code, filename="<source>", mode="exec")
290
+ else:
291
+ kws = dict(include_attributes=True, annotate_fields=True, indent=4)
292
+ raise RuntimeError(
293
+ f"Unable to compile code\n--CODE--\n"
294
+ f"{ast.unparse(new_tree)}\n--TREE--\n"
295
+ f"{ast.dump(new_tree, **kws)}"
296
+ ) from e
297
+ namespace: Dict[str, type] = {}
298
+ globs = func.__globals__.copy()
299
+ globs[if_name] = torch.cond
300
+ exec(mod, globs, namespace)
301
+ new_func = namespace.get(func.__name__)
302
+ if not isinstance(new_func, types.FunctionType):
303
+ raise RuntimeError("Transformed function not found")
304
+ return RewrittenMethod(new_tree, new_func)
@@ -1,9 +1,10 @@
1
1
  import copy
2
2
  import functools
3
+ import json
3
4
  import os
4
5
  from typing import Any, Dict, List, Optional, Union
5
6
  import transformers
6
- from huggingface_hub import HfApi, model_info
7
+ from huggingface_hub import HfApi, model_info, hf_hub_download
7
8
  from ...helpers.config_helper import update_config
8
9
  from . import hub_data_cached_configs
9
10
  from .hub_data import __date__, __data_tasks__, load_architecture_task, __data_arch_values__
@@ -59,7 +60,11 @@ def get_cached_configuration(name: str, **kwargs) -> Optional[transformers.Pretr
59
60
 
60
61
 
61
62
  def get_pretrained_config(
62
- model_id: str, trust_remote_code: bool = True, use_preinstalled: bool = True, **kwargs
63
+ model_id: str,
64
+ trust_remote_code: bool = True,
65
+ use_preinstalled: bool = True,
66
+ subfolder: Optional[str] = None,
67
+ **kwargs,
63
68
  ) -> Any:
64
69
  """
65
70
  Returns the config for a model_id.
@@ -71,13 +76,32 @@ def get_pretrained_config(
71
76
  accessing the network, if available, it is returned by
72
77
  :func:`get_cached_configuration`, the cached list is mostly for
73
78
  unit tests
79
+ :param subfolder: subfolder for the given model id
74
80
  :param kwargs: additional kwargs
75
81
  :return: a configuration
76
82
  """
77
83
  if use_preinstalled:
78
- conf = get_cached_configuration(model_id, **kwargs)
84
+ conf = get_cached_configuration(model_id, subfolder=subfolder, **kwargs)
79
85
  if conf is not None:
80
86
  return conf
87
+ if subfolder:
88
+ try:
89
+ return transformers.AutoConfig.from_pretrained(
90
+ model_id, trust_remote_code=trust_remote_code, subfolder=subfolder, **kwargs
91
+ )
92
+ except ValueError:
93
+ # Then we try to download it.
94
+ config = hf_hub_download(
95
+ model_id, filename="config.json", subfolder=subfolder, **kwargs
96
+ )
97
+ try:
98
+ return transformers.AutoConfig.from_pretrained(
99
+ config, trust_remote_code=trust_remote_code, **kwargs
100
+ )
101
+ except ValueError:
102
+ # Diffusers uses a dictionary.
103
+ with open(config, "r") as f:
104
+ return json.load(f)
81
105
  return transformers.AutoConfig.from_pretrained(
82
106
  model_id, trust_remote_code=trust_remote_code, **kwargs
83
107
  )
@@ -88,6 +112,15 @@ def get_model_info(model_id) -> Any:
88
112
  return model_info(model_id)
89
113
 
90
114
 
115
+ def _guess_task_from_config(config: Any) -> Optional[str]:
116
+ """Tries to infer a task from the configuration."""
117
+ if hasattr(config, "bbox_loss_coefficient") and hasattr(config, "giou_loss_coefficient"):
118
+ return "object-detection"
119
+ if hasattr(config, "architecture") and config.architecture:
120
+ return task_from_arch(config.architecture)
121
+ return None
122
+
123
+
91
124
  @functools.cache
92
125
  def task_from_arch(arch: str, default_value: Optional[str] = None) -> str:
93
126
  """
@@ -126,7 +159,7 @@ def task_from_id(
126
159
  :param default_value: if specified, the function returns this value
127
160
  if the task cannot be determined
128
161
  :param pretrained: uses the config
129
- :param fall_back_to_pretrained: balls back to pretrained config
162
+ :param fall_back_to_pretrained: falls back to pretrained config
130
163
  :return: task
131
164
  """
132
165
  if not pretrained:
@@ -139,6 +172,9 @@ def task_from_id(
139
172
  try:
140
173
  return config.pipeline_tag
141
174
  except AttributeError:
175
+ guess = _guess_task_from_config(config)
176
+ if guess is not None:
177
+ return guess
142
178
  assert config.architectures is not None and len(config.architectures) == 1, (
143
179
  f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
144
180
  f"architectures={config.architectures} in config={config}"
@@ -78,6 +78,7 @@ __data_arch__ = textwrap.dedent(
78
78
  MobileBertModel,feature-extraction
79
79
  MobileNetV1Model,image-feature-extraction
80
80
  MobileNetV2Model,image-feature-extraction
81
+ mobilenetv3_small_100,image-classification
81
82
  MobileViTForImageClassification,image-classification
82
83
  ModernBertForMaskedLM,fill-mask
83
84
  Phi4MMForCausalLM,MoE
@@ -124,6 +125,7 @@ __data_arch__ = textwrap.dedent(
124
125
  T5ForConditionalGeneration,text2text-generation
125
126
  TableTransformerModel,image-feature-extraction
126
127
  TableTransformerForObjectDetection,object-detection
128
+ UNet2DConditionModel,text-to-image
127
129
  UniSpeechForSequenceClassification,audio-classification
128
130
  ViTForImageClassification,image-classification
129
131
  ViTMAEModel,image-feature-extraction
@@ -162,6 +164,7 @@ __data_tasks__ = [
162
164
  "sentence-similarity",
163
165
  "text-classification",
164
166
  "text-generation",
167
+ "text-to-image",
165
168
  "text-to-audio",
166
169
  "text2text-generation",
167
170
  "zero-shot-image-classification",
@@ -202,7 +205,7 @@ def load_models_testing() -> List[str]:
202
205
  @functools.cache
203
206
  def load_architecture_task() -> Dict[str, str]:
204
207
  """
205
- Returns a dictionary mapping architecture to task.
208
+ Returns a dictionary mapping architectures to tasks.
206
209
 
207
210
  import pprint
208
211
  from onnx_diagnostic.torch_models.hghub.hub_data import load_architecture_task
@@ -3640,3 +3640,213 @@ def _ccached_facebook_bart_base():
3640
3640
  "vocab_size": 50265,
3641
3641
  }
3642
3642
  )
3643
+
3644
+
3645
+ def _ccached_hustvl_yolos_tiny():
3646
+ "hustvl/yolos-tiny"
3647
+ return transformers.YolosConfig(
3648
+ **{
3649
+ "architectures": ["YolosForObjectDetection"],
3650
+ "attention_probs_dropout_prob": 0.0,
3651
+ "auxiliary_loss": false,
3652
+ "bbox_cost": 5,
3653
+ "bbox_loss_coefficient": 5,
3654
+ "class_cost": 1,
3655
+ "eos_coefficient": 0.1,
3656
+ "giou_cost": 2,
3657
+ "giou_loss_coefficient": 2,
3658
+ "hidden_act": "gelu",
3659
+ "hidden_dropout_prob": 0.0,
3660
+ "hidden_size": 192,
3661
+ "id2label": {
3662
+ "0": "N/A",
3663
+ "1": "person",
3664
+ "2": "bicycle",
3665
+ "3": "car",
3666
+ "4": "motorcycle",
3667
+ "5": "airplane",
3668
+ "6": "bus",
3669
+ "7": "train",
3670
+ "8": "truck",
3671
+ "9": "boat",
3672
+ "10": "traffic light",
3673
+ "11": "fire hydrant",
3674
+ "12": "N/A",
3675
+ "13": "stop sign",
3676
+ "14": "parking meter",
3677
+ "15": "bench",
3678
+ "16": "bird",
3679
+ "17": "cat",
3680
+ "18": "dog",
3681
+ "19": "horse",
3682
+ "20": "sheep",
3683
+ "21": "cow",
3684
+ "22": "elephant",
3685
+ "23": "bear",
3686
+ "24": "zebra",
3687
+ "25": "giraffe",
3688
+ "26": "N/A",
3689
+ "27": "backpack",
3690
+ "28": "umbrella",
3691
+ "29": "N/A",
3692
+ "30": "N/A",
3693
+ "31": "handbag",
3694
+ "32": "tie",
3695
+ "33": "suitcase",
3696
+ "34": "frisbee",
3697
+ "35": "skis",
3698
+ "36": "snowboard",
3699
+ "37": "sports ball",
3700
+ "38": "kite",
3701
+ "39": "baseball bat",
3702
+ "40": "baseball glove",
3703
+ "41": "skateboard",
3704
+ "42": "surfboard",
3705
+ "43": "tennis racket",
3706
+ "44": "bottle",
3707
+ "45": "N/A",
3708
+ "46": "wine glass",
3709
+ "47": "cup",
3710
+ "48": "fork",
3711
+ "49": "knife",
3712
+ "50": "spoon",
3713
+ "51": "bowl",
3714
+ "52": "banana",
3715
+ "53": "apple",
3716
+ "54": "sandwich",
3717
+ "55": "orange",
3718
+ "56": "broccoli",
3719
+ "57": "carrot",
3720
+ "58": "hot dog",
3721
+ "59": "pizza",
3722
+ "60": "donut",
3723
+ "61": "cake",
3724
+ "62": "chair",
3725
+ "63": "couch",
3726
+ "64": "potted plant",
3727
+ "65": "bed",
3728
+ "66": "N/A",
3729
+ "67": "dining table",
3730
+ "68": "N/A",
3731
+ "69": "N/A",
3732
+ "70": "toilet",
3733
+ "71": "N/A",
3734
+ "72": "tv",
3735
+ "73": "laptop",
3736
+ "74": "mouse",
3737
+ "75": "remote",
3738
+ "76": "keyboard",
3739
+ "77": "cell phone",
3740
+ "78": "microwave",
3741
+ "79": "oven",
3742
+ "80": "toaster",
3743
+ "81": "sink",
3744
+ "82": "refrigerator",
3745
+ "83": "N/A",
3746
+ "84": "book",
3747
+ "85": "clock",
3748
+ "86": "vase",
3749
+ "87": "scissors",
3750
+ "88": "teddy bear",
3751
+ "89": "hair drier",
3752
+ "90": "toothbrush",
3753
+ },
3754
+ "image_size": [800, 1333],
3755
+ "initializer_range": 0.02,
3756
+ "intermediate_size": 768,
3757
+ "label2id": {
3758
+ "N/A": 83,
3759
+ "airplane": 5,
3760
+ "apple": 53,
3761
+ "backpack": 27,
3762
+ "banana": 52,
3763
+ "baseball bat": 39,
3764
+ "baseball glove": 40,
3765
+ "bear": 23,
3766
+ "bed": 65,
3767
+ "bench": 15,
3768
+ "bicycle": 2,
3769
+ "bird": 16,
3770
+ "boat": 9,
3771
+ "book": 84,
3772
+ "bottle": 44,
3773
+ "bowl": 51,
3774
+ "broccoli": 56,
3775
+ "bus": 6,
3776
+ "cake": 61,
3777
+ "car": 3,
3778
+ "carrot": 57,
3779
+ "cat": 17,
3780
+ "cell phone": 77,
3781
+ "chair": 62,
3782
+ "clock": 85,
3783
+ "couch": 63,
3784
+ "cow": 21,
3785
+ "cup": 47,
3786
+ "dining table": 67,
3787
+ "dog": 18,
3788
+ "donut": 60,
3789
+ "elephant": 22,
3790
+ "fire hydrant": 11,
3791
+ "fork": 48,
3792
+ "frisbee": 34,
3793
+ "giraffe": 25,
3794
+ "hair drier": 89,
3795
+ "handbag": 31,
3796
+ "horse": 19,
3797
+ "hot dog": 58,
3798
+ "keyboard": 76,
3799
+ "kite": 38,
3800
+ "knife": 49,
3801
+ "laptop": 73,
3802
+ "microwave": 78,
3803
+ "motorcycle": 4,
3804
+ "mouse": 74,
3805
+ "orange": 55,
3806
+ "oven": 79,
3807
+ "parking meter": 14,
3808
+ "person": 1,
3809
+ "pizza": 59,
3810
+ "potted plant": 64,
3811
+ "refrigerator": 82,
3812
+ "remote": 75,
3813
+ "sandwich": 54,
3814
+ "scissors": 87,
3815
+ "sheep": 20,
3816
+ "sink": 81,
3817
+ "skateboard": 41,
3818
+ "skis": 35,
3819
+ "snowboard": 36,
3820
+ "spoon": 50,
3821
+ "sports ball": 37,
3822
+ "stop sign": 13,
3823
+ "suitcase": 33,
3824
+ "surfboard": 42,
3825
+ "teddy bear": 88,
3826
+ "tennis racket": 43,
3827
+ "tie": 32,
3828
+ "toaster": 80,
3829
+ "toilet": 70,
3830
+ "toothbrush": 90,
3831
+ "traffic light": 10,
3832
+ "train": 7,
3833
+ "truck": 8,
3834
+ "tv": 72,
3835
+ "umbrella": 28,
3836
+ "vase": 86,
3837
+ "wine glass": 46,
3838
+ "zebra": 24,
3839
+ },
3840
+ "layer_norm_eps": 1e-12,
3841
+ "model_type": "yolos",
3842
+ "num_attention_heads": 3,
3843
+ "num_channels": 3,
3844
+ "num_detection_tokens": 100,
3845
+ "num_hidden_layers": 12,
3846
+ "patch_size": 16,
3847
+ "qkv_bias": true,
3848
+ "torch_dtype": "float32",
3849
+ "transformers_version": "4.19.0.dev0",
3850
+ "use_mid_position_embeddings": false,
3851
+ }
3852
+ )
@@ -4,7 +4,7 @@ import torch
4
4
  import transformers
5
5
  from ...helpers.config_helper import update_config
6
6
  from ...tasks import reduce_model_config, random_input_kwargs
7
- from .hub_api import task_from_arch, get_pretrained_config
7
+ from .hub_api import task_from_arch, task_from_id, get_pretrained_config
8
8
 
9
9
 
10
10
  def get_untrained_model_with_inputs(
@@ -18,6 +18,7 @@ def get_untrained_model_with_inputs(
18
18
  same_as_pretrained: bool = False,
19
19
  use_preinstalled: bool = True,
20
20
  add_second_input: bool = False,
21
+ subfolder: Optional[str] = None,
21
22
  ) -> Dict[str, Any]:
22
23
  """
23
24
  Gets a non initialized model similar to the original model
@@ -37,6 +38,7 @@ def get_untrained_model_with_inputs(
37
38
  :param use_preinstalled: use preinstalled configurations
38
39
  :param add_second_input: provides a second inputs to check a model
39
40
  supports different shapes
41
+ :param subfolder: subfolder to use for this model id
40
42
  :return: dictionary with a model, inputs, dynamic shapes, and the configuration
41
43
 
42
44
  Example:
@@ -62,24 +64,39 @@ def get_untrained_model_with_inputs(
62
64
  print(f"[get_untrained_model_with_inputs] use preinstalled {model_id!r}")
63
65
  if config is None:
64
66
  config = get_pretrained_config(
65
- model_id, use_preinstalled=use_preinstalled, **(model_kwargs or {})
67
+ model_id,
68
+ use_preinstalled=use_preinstalled,
69
+ subfolder=subfolder,
70
+ **(model_kwargs or {}),
66
71
  )
67
- archs = config.architectures # type: ignore
68
- assert archs is not None and len(archs) == 1, (
72
+ if hasattr(config, "architecture") and config.architecture:
73
+ archs = [config.architecture]
74
+ elif type(config) is dict:
75
+ assert "_class_name" in config, f"Unable to get the architecture from config={config}"
76
+ archs = [config["_class_name"]]
77
+ else:
78
+ archs = config.architectures # type: ignore
79
+ task = None
80
+ if archs is None:
81
+ task = task_from_id(model_id)
82
+ assert task is not None or (archs is not None and len(archs) == 1), (
69
83
  f"Unable to determine the architecture for model {model_id!r}, "
70
84
  f"architectures={archs!r}, conf={config}"
71
85
  )
72
- arch = archs[0]
73
- if verbose:
74
- print(f"[get_untrained_model_with_inputs] architecture={arch!r}")
75
86
  if verbose:
87
+ print(f"[get_untrained_model_with_inputs] architectures={archs!r}")
76
88
  print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
77
- task = task_from_arch(arch)
89
+ if task is None:
90
+ task = task_from_arch(archs[0])
78
91
  if verbose:
79
92
  print(f"[get_untrained_model_with_inputs] task={task!r}")
80
93
 
81
94
  # model kwagrs
82
95
  if dynamic_rope is not None:
96
+ assert (
97
+ type(config) is not dict
98
+ ), f"Unable to set dynamic_rope if the configuration is a dictionary\n{config}"
99
+ assert hasattr(config, "rope_scaling"), f"Missing 'rope_scaling' in\n{config}"
83
100
  config.rope_scaling = (
84
101
  {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None
85
102
  )
@@ -106,7 +123,15 @@ def get_untrained_model_with_inputs(
106
123
  if inputs_kwargs:
107
124
  kwargs.update(inputs_kwargs)
108
125
 
109
- model = getattr(transformers, arch)(config)
126
+ if archs is not None:
127
+ model = getattr(transformers, archs[0])(config)
128
+ else:
129
+ assert same_as_pretrained, (
130
+ f"Model {model_id!r} cannot be built, the model cannot be built. "
131
+ f"It must be downloaded. Use same_as_pretrained=True."
132
+ )
133
+ model = None
134
+
110
135
  # This line is important. Some models may produce different
111
136
  # outputs even with the same inputs in training mode.
112
137
  model.eval()
@@ -4,6 +4,8 @@ import os
4
4
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
5
  import time
6
6
  import onnx
7
+ import onnxscript
8
+ import onnxscript.rewriter.ort_fusions as ort_fusions
7
9
  import torch
8
10
  from ..export import CoupleInputsDynamicShapes
9
11
  from ..helpers import max_diff, string_type, string_diff
@@ -107,9 +109,12 @@ def _make_folder_name(
107
109
  optimization: Optional[str] = None,
108
110
  dtype: Optional[Union[str, torch.dtype]] = None,
109
111
  device: Optional[Union[str, torch.device]] = None,
112
+ subfolder: Optional[str] = None,
110
113
  ) -> str:
111
114
  "Creates a filename unique based on the given options."
112
115
  els = [model_id.replace("/", "_")]
116
+ if subfolder:
117
+ els.append(subfolder.replace("/", "_"))
113
118
  if exporter:
114
119
  els.append(exporter)
115
120
  if optimization:
@@ -222,6 +227,7 @@ def validate_model(
222
227
  ortfusiontype: Optional[str] = None,
223
228
  input_options: Optional[Dict[str, Any]] = None,
224
229
  model_options: Optional[Dict[str, Any]] = None,
230
+ subfolder: Optional[str] = None,
225
231
  ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
226
232
  """
227
233
  Validates a model.
@@ -254,11 +260,11 @@ def validate_model(
254
260
  used to export
255
261
  :param model_options: additional options when creating the model such as
256
262
  ``num_hidden_layers`` or ``attn_implementation``
263
+ :param subfolder: version or subfolders to uses when retrieving a model id
257
264
  :return: two dictionaries, one with some metrics,
258
265
  another one with whatever the function produces
259
266
  """
260
267
  summary = version_summary()
261
-
262
268
  summary.update(
263
269
  dict(
264
270
  version_model_id=model_id,
@@ -280,7 +286,7 @@ def validate_model(
280
286
  folder_name = None
281
287
  if dump_folder:
282
288
  folder_name = _make_folder_name(
283
- model_id, exporter, optimization, dtype=dtype, device=device
289
+ model_id, exporter, optimization, dtype=dtype, device=device, subfolder=subfolder
284
290
  )
285
291
  dump_folder = os.path.join(dump_folder, folder_name)
286
292
  if not os.path.exists(dump_folder):
@@ -291,11 +297,15 @@ def validate_model(
291
297
  print(f"[validate_model] dump into {folder_name!r}")
292
298
 
293
299
  if verbose:
294
- print(f"[validate_model] validate model id {model_id!r}")
300
+ if subfolder:
301
+ print(f"[validate_model] validate model id {model_id!r}, subfolder={subfolder!r}")
302
+ else:
303
+ print(f"[validate_model] validate model id {model_id!r}")
295
304
  if model_options:
296
305
  print(f"[validate_model] model_options={model_options!r}")
297
306
  print(f"[validate_model] get dummy inputs with input_options={input_options}...")
298
307
  summary["model_id"] = model_id
308
+ summary["model_subfolder"] = subfolder or ""
299
309
 
300
310
  iop = input_options or {}
301
311
  mop = model_options or {}
@@ -305,7 +315,7 @@ def validate_model(
305
315
  summary,
306
316
  None,
307
317
  (
308
- lambda mid=model_id, v=verbose, task=task, tr=trained, iop=iop: (
318
+ lambda mid=model_id, v=verbose, task=task, tr=trained, iop=iop, sub=subfolder: (
309
319
  get_untrained_model_with_inputs(
310
320
  mid,
311
321
  verbose=v,
@@ -313,6 +323,7 @@ def validate_model(
313
323
  same_as_pretrained=tr,
314
324
  inputs_kwargs=iop,
315
325
  model_kwargs=mop,
326
+ subfolder=sub,
316
327
  )
317
328
  )
318
329
  ),
@@ -917,11 +928,10 @@ def call_torch_export_onnx(
917
928
  :return: two dictionaries, one with some metrics,
918
929
  another one with whatever the function produces
919
930
  """
920
- assert optimization in {
921
- "",
922
- "ir",
923
- None,
924
- }, f"unexpected value for optimization={optimization}"
931
+ available = {"", "ir", "os_ort"}
932
+ assert (
933
+ optimization in available
934
+ ), f"unexpected value for optimization={optimization}, available={available}"
925
935
  assert exporter in {
926
936
  "onnx-dynamo",
927
937
  "onnx-script",
@@ -1001,16 +1011,25 @@ def call_torch_export_onnx(
1001
1011
  print(epo)
1002
1012
  print("[call_torch_export_onnx] -- End of ONNXProgram")
1003
1013
 
1004
- if optimization == "ir":
1014
+ if optimization in {"ir", "os_ort"}:
1005
1015
  if verbose:
1006
1016
  print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
1007
- _quiet_or_not_quiet(
1008
- quiet,
1009
- "export_onnx_opt_ir",
1010
- summary,
1011
- data,
1012
- (lambda epo=epo: epo.optimize()),
1013
- )
1017
+ if optimization == "ir":
1018
+ label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
1019
+ else:
1020
+
1021
+ def _os_ort_optim(epo):
1022
+ onnxscript.optimizer.optimize_ir(epo.model)
1023
+ optimized = ort_fusions.optimize_for_ort(epo.model)
1024
+ if isinstance(optimized, tuple):
1025
+ for k, v in optimized[1].items():
1026
+ summary[f"op_opt_fused_{k}"] = v
1027
+ epo.model = optimized[0]
1028
+ else:
1029
+ epo.model = optimized
1030
+
1031
+ label, f_optim = "export_onnx_opt_os_ort", (lambda epo=epo: _os_ort_optim(epo))
1032
+ _quiet_or_not_quiet(quiet, label, summary, data, f_optim)
1014
1033
  if "ERR_export_onnx_opt_ir" in summary:
1015
1034
  return summary, data
1016
1035
  if verbose:
@@ -1039,21 +1058,27 @@ def call_torch_export_custom(
1039
1058
  :return: two dictionaries, one with some metrics,
1040
1059
  another one with whatever the function produces
1041
1060
  """
1042
- assert optimization in {
1061
+ available = {
1043
1062
  "",
1044
1063
  "default",
1045
1064
  "default+onnxruntime",
1065
+ "default+os_ort",
1066
+ "default+onnxruntime+os_ort",
1046
1067
  None,
1047
- }, f"unexpected value for optimization={optimization}"
1048
- assert exporter in {
1068
+ }
1069
+ assert (
1070
+ optimization in available
1071
+ ), f"unexpected value for optimization={optimization}, available={available}"
1072
+ available = {
1049
1073
  "custom",
1050
1074
  "custom-strict",
1051
- "custom-strict-dec",
1075
+ "custom-strict-default",
1052
1076
  "custom-strict-all",
1053
1077
  "custom-nostrict",
1054
- "custom-nostrict-dec",
1078
+ "custom-nostrict-default",
1055
1079
  "custom-nostrict-all",
1056
- }, f"Unexpected value for exporter={exporter!r}"
1080
+ }
1081
+ assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
1057
1082
  assert "model" in data, f"model is missing from data: {sorted(data)}"
1058
1083
  assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
1059
1084
  summary: Dict[str, Union[str, int, float]] = {}
@@ -1078,10 +1103,14 @@ def call_torch_export_custom(
1078
1103
  from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
1079
1104
  from experimental_experiment.xbuilder import OptimizationOptions
1080
1105
 
1106
+ spl = optimization.split("+") if optimization else []
1107
+ os_ort = "os_ort" in spl
1108
+ optimization = "+".join(_ for _ in spl if _ != "os_ort")
1109
+
1081
1110
  export_options = ExportOptions(
1082
1111
  strict=strict,
1083
1112
  decomposition_table=(
1084
- "dec" if "-dec" in exporter else ("all" if "-all" in exporter else None)
1113
+ "default" if "-default" in exporter else ("all" if "-all" in exporter else None)
1085
1114
  ),
1086
1115
  )
1087
1116
  options = OptimizationOptions(patterns=optimization) if optimization else None
@@ -1181,6 +1210,31 @@ def call_torch_export_custom(
1181
1210
  assert epo is not None, "no onnx export was found"
1182
1211
  if verbose:
1183
1212
  print("[call_torch_export_custom] done (export)")
1213
+
1214
+ if os_ort:
1215
+ if verbose:
1216
+ print("[call_torch_export_custom] conversion to IR...")
1217
+ begin = time.perf_counter()
1218
+ ir_model = epo.to_ir()
1219
+ duration = time.perf_counter() - begin
1220
+ summary["time_optim_to_ir"] = duration
1221
+ if verbose:
1222
+ print(f"[call_torch_export_custom] done in {duration}")
1223
+ print("[call_torch_export_custom] start optimization...")
1224
+ begin = time.perf_counter()
1225
+ onnxscript.optimizer.optimize_ir(ir_model)
1226
+ ir_optimized = ort_fusions.optimize_for_ort(ir_model)
1227
+ if isinstance(ir_optimized, tuple):
1228
+ report = ir_optimized[1]
1229
+ for k, v in report.items():
1230
+ summary[f"op_opt_fused_{k}"] = v
1231
+ ir_optimized = ir_optimized[0]
1232
+ epo.model = ir_optimized
1233
+ duration = time.perf_counter() - begin
1234
+ summary["time_optim_os_ort"] = duration
1235
+ if verbose:
1236
+ print(f"[call_torch_export_custom] done in {duration}")
1237
+
1184
1238
  data["onnx_program"] = epo
1185
1239
  return summary, data
1186
1240
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-diagnostic
3
- Version: 0.4.3
3
+ Version: 0.4.4
4
4
  Summary: Investigate ONNX models
5
5
  Home-page: https://github.com/sdpython/onnx-diagnostic
6
6
  Author: Xavier Dupré
@@ -93,6 +93,11 @@ or
93
93
  Enlightening Examples
94
94
  +++++++++++++++++++++
95
95
 
96
+ **Where to start to export a model**
97
+
98
+ * `Export microsoft/phi-2
99
+ <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_phi2.html>`_
100
+
96
101
  **Torch Export**
97
102
 
98
103
  * `Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints
@@ -1,6 +1,6 @@
1
- onnx_diagnostic/__init__.py,sha256=opPQ2jwxhWOe2Y2oDiKVTmNL4w0H1Gl0G921alsi0NM,164
1
+ onnx_diagnostic/__init__.py,sha256=VQWqNIBvo3L_1EFs9AQFHV8psqXB7jqUazJt4CrKt04,164
2
2
  onnx_diagnostic/__main__.py,sha256=YmyV_Aq_ianDlHyKLHMa6h8YK3ZmFPpLVHLKjM91aCk,79
3
- onnx_diagnostic/_command_lines_parser.py,sha256=kOECT1BccZc38vmVc3jF3xvXGDpcocvLuUGoPkzte08,14753
3
+ onnx_diagnostic/_command_lines_parser.py,sha256=hwh-o6DfkO9Oj1TNqC2RGgfO_lXnZlyobH16Kzffyx4,14917
4
4
  onnx_diagnostic/doc.py,sha256=MTuT7Kxyvn7KEy84liQeFeqhugJrUQhjjpx21F72Uxw,926
5
5
  onnx_diagnostic/ext_test_case.py,sha256=JTKlAXZL0EIzRCNUjP2VWHq7g9ueFSFd18JVrJY9SFM,41404
6
6
  onnx_diagnostic/export/__init__.py,sha256=yEIoWiOeTwBsDhyYt2fTKuhtA0Ya1J9u9ZzMTOTWaWs,101
@@ -49,40 +49,42 @@ onnx_diagnostic/reference/ops/op_skip_layer_normalization.py,sha256=oJ7fQNx2iQh9
49
49
  onnx_diagnostic/reference/ops/op_slice.py,sha256=yRxfYBs8b7QezyyG9JHCD8MIJHij2qR2NNDpBmD3FJI,705
50
50
  onnx_diagnostic/reference/ops/op_transpose_cast.py,sha256=ifef74rvh0Yvq1Zx51B4mfnISbxV9uRg9DFjkdL1_68,361
51
51
  onnx_diagnostic/reference/ops/op_tri_matrix.py,sha256=Yn2gxAyygcwtF5Hjau9ihXDAzul0BAkdqVimVahtFBU,519
52
- onnx_diagnostic/tasks/__init__.py,sha256=AkZDewmxKmLN86obwTIcPhpMWX2nOTPwKCSLwoJhXTM,1566
52
+ onnx_diagnostic/tasks/__init__.py,sha256=G9q-yZf-bU8udb7YtwPXf4bkfZh0lEXXzSKcG2HKvEU,1610
53
53
  onnx_diagnostic/tasks/automatic_speech_recognition.py,sha256=oRoYy56M0Yv_WOcn1hJXv-R9wgHkJ8rbym7j7y8oslw,6851
54
54
  onnx_diagnostic/tasks/feature_extraction.py,sha256=V-T5NpZ6EimOz00weWWxGfksZ9jQ5ZQyaP-mxuCEuJo,2223
55
55
  onnx_diagnostic/tasks/fill_mask.py,sha256=POUtgvOWv8wTOVLqxPNsj_C2WBiBWkmM72Z9mNlNqxI,2341
56
- onnx_diagnostic/tasks/image_classification.py,sha256=hxbvjSNLtewfY2PCkFhjxfBu0yHcTeRSRd79WSzBe2k,3527
56
+ onnx_diagnostic/tasks/image_classification.py,sha256=qgT9tbXby3dACZyXXjvfpm0a7-ey2-vxMCXtjoDusJw,4210
57
57
  onnx_diagnostic/tasks/image_text_to_text.py,sha256=6rKbts_p05VZL8wufJa6NP-MhxUOU-fuTAks5QfUVVQ,6037
58
58
  onnx_diagnostic/tasks/mixture_of_expert.py,sha256=orMx8Ly4DO0Po0tEmme4gi2flPIGip4TaAyxVik4Zgg,2685
59
+ onnx_diagnostic/tasks/object_detection.py,sha256=o1T8NMztjdFAFA-Z5efx-8nd9W7YZZcbE8Ag5wKVxZA,3930
59
60
  onnx_diagnostic/tasks/sentence_similarity.py,sha256=okQ-TQR8j1a92_N-eT6xN56rjtu26CdlU_pk88gdbGs,2356
60
61
  onnx_diagnostic/tasks/text2text_generation.py,sha256=jaJLQqKk38mAop7O3zCFQjUvmYmFTYWDVgzT7dSJPOw,7032
61
62
  onnx_diagnostic/tasks/text_classification.py,sha256=OgC_G9iumzTjTNUEvMoFFNTHCD8_BkdvdYC4jUsfpHM,2412
62
- onnx_diagnostic/tasks/text_generation.py,sha256=fTasu-igW-f9dyhYN4qXYkTWZU1ppgK37cmpvXV3i08,10215
63
+ onnx_diagnostic/tasks/text_generation.py,sha256=Wv8DamBHte355wXe_tAeVxG4EL20y86fu7JEmUM75to,10385
63
64
  onnx_diagnostic/tasks/zero_shot_image_classification.py,sha256=N3cEG1Lq95wS1N_CWUUUCU5j-4Tp5eR8Ce68U8THYAk,4380
64
- onnx_diagnostic/torch_export_patches/__init__.py,sha256=uRqg0-KSs_DhDnzrVp-TG2vfcDBO4HlsNkEg7RomQL0,246
65
+ onnx_diagnostic/torch_export_patches/__init__.py,sha256=SGR2aCWfxl8--rBF3_Zm0h6ovGyp8HknhkUsS1NyVNM,675
65
66
  onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=9WQUBAp5okQL9aJJKyp23ZumBnKt-qahcr94a9MYWxA,16083
66
67
  onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=1s1LqgqOL_hV6yqT7sgxzTKSDAL267CcZgNq8K4oTZM,14898
67
68
  onnx_diagnostic/torch_export_patches/patch_inputs.py,sha256=FQrMjwvEgPqvYY7ptfULzfexW5yJHo6Pzq_p1HDkNrY,7680
69
+ onnx_diagnostic/torch_export_patches/patch_module.py,sha256=66DXLeblDw5oLFJ-tGJYJzso4V2_q-96pQXR3L9nsvM,10891
68
70
  onnx_diagnostic/torch_export_patches/patches/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
69
71
  onnx_diagnostic/torch_export_patches/patches/patch_torch.py,sha256=TKLxrIJUrQsy0mbQwFcvn5e2l_s4OGKu5jnwWDHC6LE,15670
70
72
  onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=exiIq8zNZsY6QTzZVDMgU2ywGzs6-54Ic4vzTQ-26YQ,21863
71
73
  onnx_diagnostic/torch_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
72
74
  onnx_diagnostic/torch_models/llms.py,sha256=soyg4yC87ptGoeulJhKqw5opGmuLvH1pn_ZDXZ4Jr8E,90
73
- onnx_diagnostic/torch_models/test_helper.py,sha256=ad4C1vpnYkLMTrPMBdyk0OuUIUIAa7wP-zgCzPH88zc,46604
75
+ onnx_diagnostic/torch_models/test_helper.py,sha256=odRN3eIrK3kkvJCWTnZruUZGqGZrKR6r_Dp0dXQRDbw,49091
74
76
  onnx_diagnostic/torch_models/hghub/__init__.py,sha256=vi1Q7YHdddj1soiBN42MSvJdFqe2_KUoWafHISjwOu8,58
75
- onnx_diagnostic/torch_models/hghub/hub_api.py,sha256=EjwsmdHhf9ub1K5UCQPxsKiTMZy1dsdcRvNmxoZrc98,8621
76
- onnx_diagnostic/torch_models/hghub/hub_data.py,sha256=K9fu3NA530QEqYJcFTwACfsoRAIUrJoJLNtxDrUzM3c,7863
77
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py,sha256=IYinZGm6-Ob7fQHg7rE1OE0t5muCsEq5TpQiMgEsrgs,245009
78
- onnx_diagnostic/torch_models/hghub/model_inputs.py,sha256=B5c_-T_Ub9Mxs_DxpP4_yb4im-85ftVvAcUBgsISp1o,5743
77
+ onnx_diagnostic/torch_models/hghub/hub_api.py,sha256=BgM_p57Q0gT9GOhdrmOYcnbuTTzCWp80jS4OQqWwFhs,9990
78
+ onnx_diagnostic/torch_models/hghub/hub_data.py,sha256=LLXq4KIYdz80JpRTV_liHNkJu6qBjdC2cH6E4v98yGg,7972
79
+ onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py,sha256=dYL4E7-12JwhZQKuNhoDT1avb7V3Go5nZyaoPCGR92o,251610
80
+ onnx_diagnostic/torch_models/hghub/model_inputs.py,sha256=XIFyadFg2cbDSqOCFMm7aimh15PVp7kxOnZ_oVhnjps,6770
79
81
  onnx_diagnostic/torch_models/untrained/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
80
82
  onnx_diagnostic/torch_models/untrained/llm_phi2.py,sha256=ynBTDHJHCk44NjLT_t6OiFDBdPP0rFGPteiONDxvztw,3708
81
83
  onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py,sha256=7N3fGvT_4Mn4NbIo0Qk57c6DMc3OXGWyvj_P41rjwSY,3513
82
84
  onnx_diagnostic/torch_onnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
83
85
  onnx_diagnostic/torch_onnx/sbs.py,sha256=HEGDHhV9pfXxpBQrpOWPNWGMsNfOebWewyAazi9poV8,16872
84
- onnx_diagnostic-0.4.3.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
85
- onnx_diagnostic-0.4.3.dist-info/METADATA,sha256=AI3BrV-xYBj8qVBlnpPOMsKnIAtERv1fsjx7sA6hD-A,5496
86
- onnx_diagnostic-0.4.3.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
87
- onnx_diagnostic-0.4.3.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
88
- onnx_diagnostic-0.4.3.dist-info/RECORD,,
86
+ onnx_diagnostic-0.4.4.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
87
+ onnx_diagnostic-0.4.4.dist-info/METADATA,sha256=9c5HUQruT58E_auU0qmqSKxoJMCEKz_MU3dGvhIi_G8,5659
88
+ onnx_diagnostic-0.4.4.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
89
+ onnx_diagnostic-0.4.4.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
90
+ onnx_diagnostic-0.4.4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (79.0.1)
2
+ Generator: setuptools (80.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5