onnx-diagnostic 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +43 -1
  3. onnx_diagnostic/export/dynamic_shapes.py +7 -3
  4. onnx_diagnostic/ext_test_case.py +1 -1
  5. onnx_diagnostic/helpers/cache_helper.py +11 -1
  6. onnx_diagnostic/helpers/config_helper.py +7 -2
  7. onnx_diagnostic/helpers/helper.py +31 -0
  8. onnx_diagnostic/helpers/torch_test_helper.py +6 -0
  9. onnx_diagnostic/tasks/__init__.py +6 -2
  10. onnx_diagnostic/tasks/automatic_speech_recognition.py +22 -4
  11. onnx_diagnostic/tasks/feature_extraction.py +76 -0
  12. onnx_diagnostic/tasks/fill_mask.py +14 -3
  13. onnx_diagnostic/tasks/image_classification.py +16 -3
  14. onnx_diagnostic/tasks/image_text_to_text.py +24 -4
  15. onnx_diagnostic/tasks/mixture_of_expert.py +76 -0
  16. onnx_diagnostic/tasks/sentence_similarity.py +14 -3
  17. onnx_diagnostic/tasks/text2text_generation.py +19 -3
  18. onnx_diagnostic/tasks/text_classification.py +14 -3
  19. onnx_diagnostic/tasks/text_generation.py +69 -48
  20. onnx_diagnostic/tasks/zero_shot_image_classification.py +18 -3
  21. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -2
  22. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +6 -1
  23. onnx_diagnostic/torch_models/hghub/hub_api.py +12 -5
  24. onnx_diagnostic/torch_models/hghub/hub_data.py +2 -0
  25. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +71 -0
  26. onnx_diagnostic/torch_models/hghub/model_inputs.py +7 -3
  27. onnx_diagnostic/torch_models/test_helper.py +23 -5
  28. {onnx_diagnostic-0.4.0.dist-info → onnx_diagnostic-0.4.2.dist-info}/METADATA +1 -1
  29. {onnx_diagnostic-0.4.0.dist-info → onnx_diagnostic-0.4.2.dist-info}/RECORD +32 -30
  30. {onnx_diagnostic-0.4.0.dist-info → onnx_diagnostic-0.4.2.dist-info}/WHEEL +1 -1
  31. {onnx_diagnostic-0.4.0.dist-info → onnx_diagnostic-0.4.2.dist-info}/licenses/LICENSE.txt +0 -0
  32. {onnx_diagnostic-0.4.0.dist-info → onnx_diagnostic-0.4.2.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
@@ -3,5 +3,5 @@ Investigates onnx models.
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """

-__version__ = "0.4.0"
+__version__ = "0.4.2"
 __author__ = "Xavier Dupré"
onnx_diagnostic/_command_lines_parser.py
@@ -214,6 +214,22 @@ def get_parser_config() -> ArgumentParser:
         action=BooleanOptionalAction,
         help="displays the task as well",
     )
+    parser.add_argument(
+        "-c",
+        "--cached",
+        default=True,
+        action=BooleanOptionalAction,
+        help="uses cached configuration, only available for some of them, "
+        "mostly for unit test purposes",
+    )
+    parser.add_argument(
+        "--mop",
+        metavar="KEY=VALUE",
+        nargs="*",
+        help="Additional model options, use to change some parameters of the model, "
+        "example: --mop attn_implementation=eager",
+        action=_ParseDict,
+    )
     return parser


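Hedged usage note: assuming the package's CLI entry point (``python -m onnx_diagnostic <command>``; the exact flag spelling for the model id may differ), the new options would be used like::

    python -m onnx_diagnostic config -m <model-id> --mop attn_implementation=eager --no-cached

``--no-cached`` comes for free from ``BooleanOptionalAction``, and each ``--mop`` pair goes through ``_ParseDict`` (see the next hunks), so values such as ``2`` or ``1e-3`` arrive typed rather than as strings.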
@@ -222,7 +238,11 @@ def _cmd_config(argv: List[Any]):

     parser = get_parser_config()
     args = parser.parse_args(argv[1:])
-    print(get_pretrained_config(args.mid))
+    conf = get_pretrained_config(args.mid, **(args.mop or {}))
+    print(conf)
+    for k, v in sorted(conf.__dict__.items()):
+        if "_implementation" in k:
+            print(f"config.{k}={v!r}")
     if args.task:
         print("------")
         print(f"task: {task_from_id(args.mid)}")
@@ -238,6 +258,19 @@ class _ParseDict(argparse.Action):
            key = split_items[0].strip()  # we remove blanks around keys, as is logical
            value = split_items[1]

+            if value in ("True", "true", "False", "false"):
+                d[key] = bool(value)
+                continue
+            try:
+                d[key] = int(value)
+                continue
+            except (TypeError, ValueError):
+                pass
+            try:
+                d[key] = float(value)
+                continue
+            except (TypeError, ValueError):
+                pass
            d[key] = value

        setattr(namespace, self.dest, d)
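One caveat worth flagging: ``bool(value)`` is True for any non-empty string, so "False" and "false" are also coerced to True by the code above. A stricter coercion (a sketch, not the released code) would map the literals explicitly::

    def _coerce(value: str):
        # bool("false") is True, so handle the four literals by hand.
        if value in ("True", "true"):
            return True
        if value in ("False", "false"):
            return False
        for cast in (int, float):
            try:
                return cast(value)
            except (TypeError, ValueError):
                pass
        return value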
@@ -321,6 +354,14 @@ def get_parser_validate() -> ArgumentParser:
         "inputs use to export, example: --iop cls_cache=SlidingWindowCache",
         action=_ParseDict,
     )
+    parser.add_argument(
+        "--mop",
+        metavar="KEY=VALUE",
+        nargs="*",
+        help="Additional model options, use to change some parameters of the model, "
+        "example: --mop attn_implementation=eager",
+        action=_ParseDict,
+    )
     return parser


@@ -371,6 +412,7 @@ def _cmd_validate(argv: List[Any]):
         drop_inputs=None if not args.drop else args.drop.split(","),
         ortfusiontype=args.ortfusiontype,
         input_options=args.iop,
+        model_options=args.mop,
     )
     print("")
     print("-- summary --")
onnx_diagnostic/export/dynamic_shapes.py
@@ -363,16 +363,20 @@ class CoupleInputsDynamicShapes:
            )
        if flatten_unflatten:
            flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
-           return cls._generic_walker_step(
+           res = cls._generic_walker_step(
                processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
            )
-       flat, _spec = torch.utils._pytree.tree_flatten(inputs)
+           # Should we restore the original class?
+           return res
+       flat, spec = torch.utils._pytree.tree_flatten(inputs)
        if all(isinstance(t, torch.Tensor) for t in flat):
            # We need to flatten dynamic shapes as well
            ds = flatten_dynamic_shapes(ds)
-           return cls._generic_walker_step(
+           res = cls._generic_walker_step(
                processor, flat, ds, flatten_unflatten=flatten_unflatten
            )
+           # Then we restore the original class.
+           return torch.utils._pytree.tree_unflatten(res, spec)

    class ChangeDimensionProcessor:
        def __init__(self, desired_values):
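The change stores the walker's result and runs it back through ``tree_unflatten`` so the caller gets its original container class back instead of a flat list. The round-trip itself, in isolation::

    import torch

    inputs = {"input_ids": torch.ones(2, 3), "attention_mask": torch.zeros(2, 3)}
    flat, spec = torch.utils._pytree.tree_flatten(inputs)
    processed = [t * 2 for t in flat]  # stands in for cls._generic_walker_step
    restored = torch.utils._pytree.tree_unflatten(processed, spec)
    assert isinstance(restored, dict) and set(restored) == set(inputs)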
onnx_diagnostic/ext_test_case.py
@@ -461,7 +461,7 @@ def requires_sklearn(version: str, msg: str = "") -> Callable:
     return lambda x: x


-def requires_experimental(version: str = "", msg: str = "") -> Callable:
+def requires_experimental(version: str = "0.0.0", msg: str = "") -> Callable:
     """Skips a unit test if :epkg:`experimental-experiment` is not recent enough."""
     import packaging.version as pv

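The new default matters because ``packaging`` rejects an empty version string, while "0.0.0" compares cleanly against any release::

    import packaging.version as pv

    assert pv.Version("0.4.2") >= pv.Version("0.0.0")
    try:
        pv.Version("")
    except pv.InvalidVersion:
        print("an empty string is not a parsable version")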
onnx_diagnostic/helpers/cache_helper.py
@@ -155,6 +155,7 @@ def make_mamba_cache(
     key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
 ) -> transformers.cache_utils.MambaCache:
     "Creates a :class:`transformers.cache_utils.MambaCache`."
+    dtype = key_value_pairs[0][0].dtype

     class _config:
         def __init__(self):
@@ -162,14 +163,23 @@ def make_mamba_cache(
             self.conv_kernel = key_value_pairs[0][0].shape[-1]
             self.state_size = key_value_pairs[0][1].shape[-1]
             self.num_hidden_layers = len(key_value_pairs)
-            self.dtype = key_value_pairs[0][0].dtype
+            self.dtype = dtype

     cache = transformers.cache_utils.MambaCache(
         _config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         device=key_value_pairs[0][0].device,
+        dtype=dtype,
     )
     for i in range(len(key_value_pairs)):
+        assert cache.conv_states[i].dtype == dtype, (
+            f"Type mismatch for cache.conv_states[{i}].dtype="
+            f"{cache.conv_states[i].dtype} != {dtype}"
+        )
+        assert cache.ssm_states[i].dtype == dtype, (
+            f"Type mismatch for cache.ssm_states[{i}].dtype="
+            f"{cache.ssm_states[i].dtype} != {dtype}"
+        )
         assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, (
             f"Shape mismatch, expected {cache.conv_states[i].shape}, "
             f"got {key_value_pairs[i][0].shape}"
onnx_diagnostic/helpers/config_helper.py
@@ -28,13 +28,18 @@ def check_hasattr(config: Any, *args: Union[str, Tuple[Any, ...]]):
 def update_config(config: Any, mkwargs: Dict[str, Any]):
     """Updates a configuration with different values."""
     for k, v in mkwargs.items():
+        if k == "attn_implementation":
+            config._attn_implementation = v
+            if getattr(config, "_attn_implementation_autoset", False):
+                config._attn_implementation_autoset = False
+            continue
         if isinstance(v, dict):
             assert hasattr(
                 config, k
             ), f"missing attribute {k!r} in config={config}, cannot update it with {v}"
             update_config(getattr(config, k), v)
-        else:
-            setattr(config, k, v)
+            continue
+        setattr(config, k, v)


 def _pick(config, *atts):
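With this special case, ``attn_implementation`` lands on the private ``_attn_implementation`` field (and the autoset flag is cleared) so transformers does not override it later; plain keys still go through ``setattr``. For example::

    from transformers import AutoConfig
    from onnx_diagnostic.helpers.config_helper import update_config

    config = AutoConfig.for_model("llama")  # any PretrainedConfig
    update_config(config, {"attn_implementation": "eager", "num_hidden_layers": 2})
    assert config._attn_implementation == "eager"
    assert config.num_hidden_layers == 2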
onnx_diagnostic/helpers/helper.py
@@ -666,6 +666,15 @@ def string_type(
             print(f"[string_type] CACHE4:{type(obj)}")
         return f"{obj.__class__.__name__}(...)"

+    if obj.__class__.__name__.endswith("Config"):
+        import transformers.configuration_utils as tcu
+
+        if isinstance(obj, tcu.PretrainedConfig):
+            if verbose:
+                print(f"[string_type] CONFIG:{type(obj)}")
+            s = str(obj.to_diff_dict()).replace("\n", "").replace(" ", "")
+            return f"{obj.__class__.__name__}(**{s})"
+
     if verbose:
         print(f"[string_type] END:{type(obj)}")
     raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
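``string_type`` previously hit the final ``AssertionError`` on configuration objects; it now renders any ``PretrainedConfig`` as a compact constructor-like string built from ``to_diff_dict()``. Roughly::

    from transformers import AutoConfig
    from onnx_diagnostic.helpers.helper import string_type

    conf = AutoConfig.for_model("bert")
    print(string_type(conf))  # e.g. BertConfig(**{...}) on one line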
@@ -1395,6 +1404,28 @@ def max_diff(
             f"level={level}"
         )

+    if expected.__class__.__name__ == "SlidingWindowCache":
+        if got.__class__.__name__ == "SlidingWindowCache":
+            if verbose >= 6:
+                print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}")
+            return max_diff(
+                [expected.key_cache, expected.value_cache],
+                [got.key_cache, got.value_cache],
+                verbose=verbose,
+            )
+        if isinstance(got, tuple) and len(got) == 2:
+            return max_diff(
+                [expected.key_cache, expected.value_cache],
+                [got[0], got[1]],
+                verbose=verbose,
+            )
+        raise AssertionError(
+            f"SlidingWindowCache not fully implemented with classes "
+            f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
+            f"and expected={string_type(expected)}, got={string_type(got)},\n"
+            f"level={level}"
+        )
+
     if expected.__class__.__name__ == "EncoderDecoderCache":
         if got.__class__.__name__ == "EncoderDecoderCache":
             if verbose >= 6:
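The new branch mirrors the DynamicCache handling (the verbose message still says "DynamicCache", a leftover from the branch it was copied from) and also accepts a plain 2-tuple on the ``got`` side. A sketch, assuming ``make_sliding_window_cache`` builds a cache from key/value pairs the way ``make_dynamic_cache`` does::

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_sliding_window_cache
    from onnx_diagnostic.helpers.helper import max_diff

    pairs = [(torch.randn(2, 4, 30, 8), torch.randn(2, 4, 30, 8))]
    cache = make_sliding_window_cache(pairs)
    print(max_diff(cache, cache))  # comparing a cache with itself reports 0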
onnx_diagnostic/helpers/torch_test_helper.py
@@ -8,6 +8,7 @@ from .cache_helper import (
     make_dynamic_cache,
     make_encoder_decoder_cache,
     make_sliding_window_cache,
+    make_mamba_cache,
 )


@@ -346,6 +347,8 @@ def torch_deepcopy(value: Any) -> Any:
     """
     Makes a deepcopy.
     """
+    if value is None:
+        return None
     if isinstance(value, (int, float, str)):
         return value
     if isinstance(value, tuple):
@@ -376,6 +379,9 @@ def torch_deepcopy(value: Any) -> Any:
             torch_deepcopy(value.self_attention_cache),
             torch_deepcopy(value.cross_attention_cache),
         )
+    if value.__class__.__name__ == "MambaCache":
+        return make_mamba_cache(list(zip(value.conv_states, value.ssm_states)))
+
     if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
         args, spec = torch.utils._pytree.tree_flatten(value)
         new_args = torch_deepcopy(args)
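``torch_deepcopy`` now short-circuits ``None`` and rebuilds ``MambaCache`` instances through ``make_mamba_cache`` (hence the new import above); the existing paths are untouched::

    import torch
    from onnx_diagnostic.helpers.torch_test_helper import torch_deepcopy

    assert torch_deepcopy(None) is None
    copied = torch_deepcopy({"x": torch.ones(2, 2)})
    assert copied["x"].shape == (2, 2)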
onnx_diagnostic/tasks/__init__.py
@@ -1,9 +1,11 @@
 from typing import Any, Callable, Dict, List, Tuple
 from . import (
     automatic_speech_recognition,
+    feature_extraction,
     fill_mask,
     image_classification,
     image_text_to_text,
+    mixture_of_expert,
     sentence_similarity,
     text_classification,
     text_generation,
@@ -13,9 +15,11 @@ from . import (

 __TASKS__ = [
     automatic_speech_recognition,
+    feature_extraction,
     fill_mask,
     image_classification,
     image_text_to_text,
+    mixture_of_expert,
     sentence_similarity,
     text_classification,
     text_generation,
@@ -33,7 +37,7 @@ def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
     """Reduces a model size."""
     tasks = {mod.__TASK__: mod.reduce_model_config for mod in __TASKS__}
     assert task in tasks, f"Task {task!r} not found in {sorted(tasks)}"
-    return tasks[task](config, task)
+    return tasks[task](config)


 def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
@@ -45,4 +49,4 @@ def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
     """
     tasks = {mod.__TASK__: mod.random_input_kwargs for mod in __TASKS__}
     assert task in tasks, f"Task {task!r} not found in {sorted(tasks)}"
-    return tasks[task](config, task)
+    return tasks[task](config)
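The dispatch keeps its two-argument signature; only the per-task callables drop the redundant ``task`` parameter, since the dictionary is already keyed by task::

    from onnx_diagnostic.tasks import random_input_kwargs

    kwargs, fn = random_input_kwargs(None, "feature-extraction")
    # fn is feature_extraction.get_inputs; kwargs holds typical dimensions.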
onnx_diagnostic/tasks/automatic_speech_recognition.py
@@ -7,7 +7,7 @@ from ..helpers.config_helper import update_config, check_hasattr
 __TASK__ = "automatic-speech-recognition"


-def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
+def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
     if hasattr(config, "num_decoder_layers"):
@@ -33,10 +33,11 @@ def get_inputs(
     head_dim: int,
     batch_size: int = 2,
     sequence_length: int = 30,
+    add_second_input: bool = False,
     **kwargs,  # unused
 ):
     """
-    Generates inputs for task ``text2text-generation``.
+    Generates inputs for task ``automatic-speech-recognition``.
     Example:

     ::
@@ -126,10 +127,27 @@ def get_inputs(
         # encoder_last_hidden_state=torch.randn(batch_size, sequence_length2, encoder_dim),
         # encoder_outputs=torch.randn(batch_size, sequence_length2, encoder_dim),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            max_source_positions=max_source_positions,
+            d_model=d_model,
+            num_hidden_layers=num_hidden_layers,
+            encoder_attention_heads=encoder_attention_heads,
+            encoder_layers=encoder_layers,
+            decoder_layers=decoder_layers,
+            head_dim=head_dim,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            **kwargs,
+        )["inputs"]
+    return res


-def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     """
     Inputs kwargs.

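This ``add_second_input`` pattern repeats in every task module below: the function calls itself with each dynamic dimension bumped by one and keeps only the nested ``inputs``, giving the validation pipeline a second shape to exercise the dynamic axes. In miniature (``make_inputs`` is a hypothetical stand-in for the per-task tensor construction)::

    import torch

    def make_inputs(batch_size, sequence_length):
        # hypothetical stand-in for the per-task tensors
        return dict(input_ids=torch.zeros(batch_size, sequence_length, dtype=torch.int64))

    def get_inputs(batch_size=2, sequence_length=30, add_second_input=False):
        res = dict(inputs=make_inputs(batch_size, sequence_length))
        if add_second_input:
            res["inputs2"] = get_inputs(batch_size + 1, sequence_length + 1)["inputs"]
        return res

    assert get_inputs(add_second_input=True)["inputs2"]["input_ids"].shape == (3, 31)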
onnx_diagnostic/tasks/feature_extraction.py (new file)
@@ -0,0 +1,76 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import update_config, check_hasattr
+
+__TASK__ = "feature-extraction"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    check_hasattr(config, "num_attention_heads", "num_hidden_layers")
+    kwargs = dict(
+        num_hidden_layers=min(config.num_hidden_layers, 2),
+        num_attention_heads=min(config.num_attention_heads, 4),
+    )
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    batch_size: int,
+    sequence_length: int,
+    dummy_max_token_id: int,
+    add_second_input: bool = False,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``feature-extraction``.
+    Example:
+
+    ::
+
+        input_ids:T7s1x13[101,72654:A16789.23076923077],
+        token_type_ids:T7s1x13[0,0:A0.0],
+        attention_mask:T7s1x13[1,1:A1.0])
+    """
+    batch = torch.export.Dim("batch", min=1, max=1024)
+    seq_length = "sequence_length"
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "attention_mask": {0: batch, 1: seq_length},
+    }
+    inputs = dict(
+        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
+            torch.int64
+        ),
+        attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            dummy_max_token_id=dummy_max_token_id,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "vocab_size")
+    kwargs = dict(
+        batch_size=2,
+        sequence_length=30,
+        dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
+    )
+    return kwargs, get_inputs
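Since ``random_input_kwargs`` falls back to typical dimensions when the configuration is None, the new module can be smoke-tested without any model::

    from onnx_diagnostic.tasks.feature_extraction import random_input_kwargs

    kwargs, get_inputs = random_input_kwargs(None)
    res = get_inputs(model=None, config=None, add_second_input=True, **kwargs)
    print(res["inputs"]["input_ids"].shape)   # torch.Size([2, 30])
    print(res["inputs2"]["input_ids"].shape)  # torch.Size([3, 31])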
onnx_diagnostic/tasks/fill_mask.py
@@ -5,7 +5,7 @@ from ..helpers.config_helper import update_config, check_hasattr
 __TASK__ = "fill-mask"


-def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
+def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     check_hasattr(config, "num_attention_heads", "num_hidden_layers")
     kwargs = dict(
@@ -22,6 +22,7 @@ def get_inputs(
     batch_size: int,
     sequence_length: int,
     dummy_max_token_id: int,
+    add_second_input: bool = False,
     **kwargs,  # unused
 ):
     """
@@ -48,10 +49,20 @@ def get_inputs(
         token_type_ids=torch.zeros((batch_size, sequence_length)).to(torch.int64),
         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            dummy_max_token_id=dummy_max_token_id,
+            **kwargs,
+        )["inputs"]
+    return res


-def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     """
     Inputs kwargs.

onnx_diagnostic/tasks/image_classification.py
@@ -5,7 +5,7 @@ from ..helpers.config_helper import update_config, check_hasattr
 __TASK__ = "image-classification"


-def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
+def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     check_hasattr(config, ("num_hidden_layers", "hidden_sizes"))
     kwargs = dict(
@@ -27,6 +27,7 @@ def get_inputs(
     input_channels: int,
     batch_size: int = 2,
     dynamic_rope: bool = False,
+    add_second_input: bool = False,
     **kwargs,  # unused
 ):
     """
@@ -59,10 +60,22 @@ def get_inputs(
             -1, 1
         ),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            input_width=input_width + 1,
+            input_height=input_height + 1,
+            input_channels=input_channels,
+            batch_size=batch_size + 1,
+            dynamic_rope=dynamic_rope,
+            **kwargs,
+        )["inputs"]
+    return res


-def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     """
     Inputs kwargs.

onnx_diagnostic/tasks/image_text_to_text.py
@@ -6,7 +6,7 @@ from ..helpers.config_helper import update_config, check_hasattr, _pick
 __TASK__ = "image-text-to-text"


-def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
+def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
     if hasattr(config, "num_hidden_layers"):
@@ -32,10 +32,11 @@ def get_inputs(
     sequence_length2: int = 3,
     n_images: int = 2,
     dynamic_rope: bool = False,
+    add_second_input: bool = False,
     **kwargs,  # unused
 ):
     """
-    Generates input for task ``text-generation``.
+    Generates input for task ``image-text-to-text``.

     :param model: model to get the missing information
     :param config: configuration used to generate the model
@@ -99,10 +100,29 @@ def get_inputs(
             torch.int64
         ),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            head_dim=head_dim,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            sequence_length2=sequence_length2 + 1,
+            n_images=n_images + 1,
+            dynamic_rope=dynamic_rope,
+            **kwargs,
+        )["inputs"]
+    return res


-def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     """
     Inputs kwargs.

onnx_diagnostic/tasks/mixture_of_expert.py (new file)
@@ -0,0 +1,76 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+
+# from ..helpers.cache_helper import make_dynamic_cache
+from ..helpers.config_helper import update_config  # , check_hasattr, _pick
+
+__TASK__ = "MoE"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    kwargs: Dict[str, Any] = {}
+    if hasattr(config, "num_hidden_layers"):
+        config.num_hidden_layers = min(config.num_hidden_layers, 2)
+    if hasattr(config, "vision_config") and hasattr(config.vision_config, "num_hidden_layers"):
+        config.vision_config.num_hidden_layers = min(config.vision_config.num_hidden_layers, 2)
+    if hasattr(config, "audio_processor") and hasattr(
+        config.audio_processor, "num_hidden_layers"
+    ):
+        config.audio_processor.num_hidden_layers = min(
+            config.audio_processor.num_hidden_layers, 2
+        )
+    if hasattr(config, "audio_processor") and hasattr(config.audio_processor, "attention_dim"):
+        config.audio_processor.attention_dim = min(config.audio_processor.attention_dim, 2)
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_key_value_heads: int,
+    num_hidden_layers: int,
+    head_dim: int,
+    width: int,
+    height: int,
+    num_channels: int,
+    batch_size: int = 2,
+    sequence_length: int = 30,
+    sequence_length2: int = 3,
+    n_images: int = 2,
+    dynamic_rope: bool = False,
+    add_second_input: bool = False,
+    **kwargs,  # unused
+):
+    """
+    Generates input for task ``MoE``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param head_dim: last dimension of the cache
+    :param dummy_max_token_id: dummy max token id
+    :param batch_size: batch size
+    :param sequence_length: sequence length
+    :param sequence_length2: new sequence length
+    :param n_images: number of images
+    :param width: width of the image
+    :param height: height of the image
+    :param num_channels: number of channels
+    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
+    :return: dictionary
+    """
+    assert not add_second_input, "add_second_input=True not yet implemented"
+    raise NotImplementedError(f"get_inputs not yet implemented for task {__TASK__!r}.")
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    raise NotImplementedError(
+        f"random_input_kwargs not yet implemented for task {__TASK__!r}."
+    )
onnx_diagnostic/tasks/sentence_similarity.py
@@ -5,7 +5,7 @@ from ..helpers.config_helper import update_config, check_hasattr
 __TASK__ = "sentence-similarity"


-def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
+def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     check_hasattr(config, "num_attention_heads", "num_hidden_layers")
     kwargs = dict(
@@ -22,6 +22,7 @@ def get_inputs(
     batch_size: int,
     sequence_length: int,
     dummy_max_token_id: int,
+    add_second_input: bool = False,
     **kwargs,  # unused
 ):
     """
@@ -48,10 +49,20 @@ def get_inputs(
         token_type_ids=torch.zeros((batch_size, sequence_length)).to(torch.int64),
         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            dummy_max_token_id=dummy_max_token_id,
+            **kwargs,
+        )["inputs"]
+    return res


-def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     """
     Inputs kwargs.
