onnx-diagnostic 0.7.12__py3-none-any.whl → 0.7.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +7 -2
  3. onnx_diagnostic/export/dynamic_shapes.py +11 -2
  4. onnx_diagnostic/helpers/helper.py +11 -5
  5. onnx_diagnostic/helpers/log_helper.py +53 -17
  6. onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
  7. onnx_diagnostic/helpers/model_builder_helper.py +1 -0
  8. onnx_diagnostic/helpers/rt_helper.py +2 -1
  9. onnx_diagnostic/helpers/torch_helper.py +31 -7
  10. onnx_diagnostic/reference/torch_evaluator.py +2 -2
  11. onnx_diagnostic/tasks/data/__init__.py +13 -0
  12. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  13. onnx_diagnostic/tasks/image_text_to_text.py +256 -141
  14. onnx_diagnostic/tasks/text_generation.py +30 -0
  15. onnx_diagnostic/torch_export_patches/eval/__init__.py +184 -151
  16. onnx_diagnostic/torch_export_patches/eval/model_cases.py +20 -5
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +52 -20
  18. onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
  19. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +540 -10
  20. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
  21. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
  22. onnx_diagnostic/torch_models/hghub/model_inputs.py +55 -5
  23. onnx_diagnostic/torch_models/validate.py +116 -50
  24. onnx_diagnostic/torch_onnx/sbs.py +2 -1
  25. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/METADATA +11 -31
  26. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/RECORD +29 -27
  27. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/WHEEL +0 -0
  28. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/licenses/LICENSE.txt +0 -0
  29. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1,13 +1,20 @@
  import inspect
  import math
+ import os
  from dataclasses import dataclass
  from functools import wraps
- from typing import Callable, List, Optional, Tuple
+ from typing import Callable, List, Optional, Tuple, Union
  import packaging.version as pv
  import torch
  import transformers
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
  from transformers.cache_utils import StaticCache, Cache
+ from transformers.generation.utils import (
+     GenerateNonBeamOutput,
+     GenerationConfig,
+     StoppingCriteriaList,
+     LogitsProcessorList,
+ )

  try:
      from transformers.cache_utils import parse_processor_args  # noqa: F401
@@ -114,6 +121,7 @@ if patch_masking_utils:
          """manual patch for function ``transformers.masking_utils.eager_mask``."""
          # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
          _ = kwargs.pop("allow_is_causal_skip", None)
+         # PATCHED: this line called the patched version of sdpa_mask
          mask = patched_sdpa_mask_recent_torch(
              batch_size=batch_size,
              cache_position=cache_position,
@@ -126,7 +134,7 @@ if patch_masking_utils:
              **kwargs,
          )
          min_dtype = torch.finfo(dtype).min
-         # The patched line.
+         # PATCHED: the following line
          # we need 0s where the tokens should be taken into account,
          # and -inf otherwise (mask is already of boolean type)
          # mask =
@@ -158,6 +166,7 @@ if patch_masking_utils:
              mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
          batch_arange = torch.arange(batch_size, device=cache_position.device)
          head_arange = torch.arange(1, device=cache_position.device)
+         # PATCHED: this line calls the patched version of vmap_for_bhqkv
          causal_mask = patched__vmap_for_bhqkv(mask_function)(
              batch_arange, head_arange, cache_position, kv_arange
          )
@@ -214,6 +223,7 @@ if patch_DynamicLayer:
          self.dtype, self.device = key_states.dtype, key_states.device
          new_shape = list(key_states.shape)
          new_shape[-2] = 0
+         # PATCHED: used a tensor with an empty shape and not en empty list to initialize
          self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
          self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
  if patch_is_initialized:
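A minimal sketch of the empty-shape initialization above (shapes are hypothetical): allocating the cache tensors with a zero-length sequence axis, rather than from an empty Python list, keeps rank, dtype and device explicit for the exporter.

    import torch

    key_states = torch.zeros(2, 8, 4, 64)   # hypothetical (batch, heads, seq, head_dim)
    new_shape = list(key_states.shape)
    new_shape[-2] = 0                        # zero-length sequence axis
    keys = torch.empty(new_shape, dtype=key_states.dtype, device=key_states.device)
    print(keys.shape)                        # torch.Size([2, 8, 0, 64])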
@@ -248,6 +258,8 @@ def _patch_make_causal_mask(
          diagonal = past_key_values_length - sliding_window - 1

          context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+         # PATCHED: removed if is_torchdynamo_compiling(): mask = mask.clone()
+         # and used masked_fill instead of masked_fill_
          # In this case, the current implementation of torch fails (17/12/2024).
          # Try model Phi-3.5-Mini-Instruct.
          mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)
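A minimal sketch of the out-of-place masking used above, with hypothetical shapes: masked_fill returns a new tensor, so the patched code no longer relies on the in-place masked_fill_ (or on cloning under torchdynamo) that this code path used before.

    import torch

    dtype = torch.float32
    mask = torch.zeros(4, 4, dtype=dtype)
    context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=-1)
    mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)  # new tensor, no mutation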
@@ -455,7 +467,16 @@ class patched_GenerationMixin:
      _PATCHES_ = [
          "_cache_dependant_input_preparation",
          "_cache_dependant_input_preparation_exporting",
-         "prepare_inputs_for_generation",
+         (
+             None
+             if pv.Version(transformers.__version__) >= pv.Version("4.56")
+             else "prepare_inputs_for_generation"
+         ),
+         (
+             "_sample"
+             if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0")
+             else None
+         ),
      ]
      _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin

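The conditional entries above gate each patched method on the installed transformers version; a minimal sketch of how such a list can be filtered before being applied (the helper name is hypothetical):

    import packaging.version as pv
    import transformers

    def active_patches(patches):
        # drop the None placeholders produced by the version checks
        return [name for name in patches if name is not None]

    patches = [
        "_cache_dependant_input_preparation",
        (
            None
            if pv.Version(transformers.__version__) >= pv.Version("4.56")
            else "prepare_inputs_for_generation"
        ),
    ]
    print(active_patches(patches))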
@@ -588,7 +609,7 @@ class patched_GenerationMixin:
          model_inputs = {}
          # - some models don't have `Cache` support
          # (which implies they don't expect `cache_position` in `forward`)
-         if self._supports_cache_class:
+         if getattr(self, "_supports_cache_class", False):
              model_inputs["cache_position"] = cache_position
          # - `cache_position` was not a mandatory input in
          # `prepare_inputs_for_generation` for those models, and this
@@ -728,6 +749,192 @@ class patched_GenerationMixin:
          model_inputs.pop("labels", None)
          return model_inputs

+     def _sample(
+         self,
+         input_ids: torch.LongTensor,
+         logits_processor: "LogitsProcessorList",  # noqa: F821
+         stopping_criteria: "StoppingCriteriaList",  # noqa: F821
+         generation_config: "GenerationConfig",  # noqa: F821
+         synced_gpus: bool = False,
+         streamer: Optional["BaseStreamer"] = None,  # noqa: F821
+         **model_kwargs,
+     ) -> Union["GenerateNonBeamOutput", torch.LongTensor]:  # noqa: F821
+         """
+         2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.
+         """
+         # init values
+         pad_token_id = generation_config._pad_token_tensor
+         output_attentions = generation_config.output_attentions
+         output_hidden_states = generation_config.output_hidden_states
+         output_scores = generation_config.output_scores
+         output_logits = generation_config.output_logits
+         return_dict_in_generate = generation_config.return_dict_in_generate
+         has_eos_stopping_criteria = any(
+             hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
+         )
+         do_sample = generation_config.do_sample
+
+         # init attention / hidden states / scores tuples
+         scores = () if (return_dict_in_generate and output_scores) else None
+         raw_logits = () if (return_dict_in_generate and output_logits) else None
+         decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+         cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+         decoder_hidden_states = (
+             () if (return_dict_in_generate and output_hidden_states) else None
+         )
+
+         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+         if return_dict_in_generate and self.config.is_encoder_decoder:
+             encoder_attentions = (
+                 model_kwargs["encoder_outputs"].get("attentions")
+                 if output_attentions
+                 else None
+             )
+             encoder_hidden_states = (
+                 model_kwargs["encoder_outputs"].get("hidden_states")
+                 if output_hidden_states
+                 else None
+             )
+
+         # keep track of which sequences are already finished
+         batch_size, cur_len = input_ids.shape[:2]
+         this_peer_finished = False
+         unfinished_sequences = torch.ones(
+             batch_size, dtype=torch.long, device=input_ids.device
+         )
+         model_kwargs = self._get_initial_cache_position(
+             cur_len, input_ids.device, model_kwargs
+         )
+
+         model_forward = self.__call__
+         compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+         if compile_forward:
+             os.environ["TOKENIZERS_PARALLELISM"] = "0"
+             # If we use FA2 and a static cache, we cannot compile with fullgraph
+             if self.config._attn_implementation == "flash_attention_2":
+                 # only raise warning if the user passed an explicit compile-config
+                 if (
+                     generation_config.compile_config is not None
+                     and generation_config.compile_config.fullgraph
+                 ):
+                     generation_config.compile_config.fullgraph = False
+             model_forward = self.get_compiled_call(generation_config.compile_config)
+
+         if generation_config.prefill_chunk_size is not None:
+             model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
+             is_prefill = False
+         else:
+             is_prefill = True
+
+         while self._has_unfinished_sequences(
+             this_peer_finished, synced_gpus, device=input_ids.device
+         ):
+             # prepare model inputs
+             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+             if is_prefill:
+                 outputs = self(**model_inputs, return_dict=True)
+                 is_prefill = False
+             else:
+                 outputs = model_forward(**model_inputs, return_dict=True)
+
+             model_kwargs = self._update_model_kwargs_for_generation(
+                 outputs,
+                 model_kwargs,
+                 is_encoder_decoder=self.config.is_encoder_decoder,
+             )
+             if synced_gpus and this_peer_finished:
+                 continue
+
+             next_token_logits = outputs.logits[:, -1, :].to(
+                 copy=True, dtype=torch.float32, device=input_ids.device
+             )
+
+             # pre-process distribution
+             next_token_scores = logits_processor(input_ids, next_token_logits)
+
+             # Store scores, attentions and hidden_states when required
+             if return_dict_in_generate:
+                 if output_scores:
+                     scores += (next_token_scores,)
+                 if output_logits:
+                     raw_logits += (next_token_logits,)
+                 if output_attentions:
+                     decoder_attentions += (
+                         (outputs.decoder_attentions,)
+                         if self.config.is_encoder_decoder
+                         else (outputs.attentions,)
+                     )
+                     if self.config.is_encoder_decoder:
+                         cross_attentions += (outputs.cross_attentions,)
+
+                 if output_hidden_states:
+                     decoder_hidden_states += (
+                         (outputs.decoder_hidden_states,)
+                         if self.config.is_encoder_decoder
+                         else (outputs.hidden_states,)
+                     )
+
+             # token selection
+             if do_sample:
+                 probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
+                 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+             else:
+                 next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+             # finished sentences should have their next token be a padding token
+             if has_eos_stopping_criteria:
+                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                     1 - unfinished_sequences
+                 )
+
+             # update generated ids, model inputs, and length for next step
+             # PATCHED: the two following lines, next_tokens can 2D already for this model
+             next_tokens_2d = (
+                 next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
+             )
+             input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
+             if streamer is not None:
+                 streamer.put(next_tokens.cpu())
+
+             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+             this_peer_finished = unfinished_sequences.max() == 0
+             cur_len += 1
+
+             # This is needed to properly delete outputs.logits which may be very large
+             # for first iteration
+             # Otherwise a reference to outputs is kept which keeps
+             # the logits alive in the next iteration
+             del outputs
+
+         if streamer is not None:
+             streamer.end()
+
+         if return_dict_in_generate:
+             if self.config.is_encoder_decoder:
+                 return transformers.generation.utils.GenerateEncoderDecoderOutput(
+                     sequences=input_ids,
+                     scores=scores,
+                     logits=raw_logits,
+                     encoder_attentions=encoder_attentions,
+                     encoder_hidden_states=encoder_hidden_states,
+                     decoder_attentions=decoder_attentions,
+                     cross_attentions=cross_attentions,
+                     decoder_hidden_states=decoder_hidden_states,
+                     past_key_values=model_kwargs.get("past_key_values"),
+                 )
+             else:
+                 return transformers.generation.utils.GenerateDecoderOnlyOutput(
+                     sequences=input_ids,
+                     scores=scores,
+                     logits=raw_logits,
+                     attentions=decoder_attentions,
+                     hidden_states=decoder_hidden_states,
+                     past_key_values=model_kwargs.get("past_key_values"),
+                 )
+         else:
+             return input_ids
+


  def patched__compute_dynamic_ntk_parameters(
      config: Optional[transformers.PretrainedConfig] = None,
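The PATCHED lines in the loop above handle models whose token selection already yields a 2D tensor; a minimal sketch with hypothetical values showing why the extra dimension is only added when missing:

    import torch

    input_ids = torch.tensor([[1, 2, 3]])
    next_tokens = torch.tensor([4])          # usual case: shape (batch,)
    next_tokens_2d = (
        next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
    )
    print(torch.cat([input_ids, next_tokens_2d], dim=-1))  # tensor([[1, 2, 3, 4]])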
@@ -791,6 +998,7 @@ def patched__compute_dynamic_ntk_parameters(
      if seq_len is None:
          seq_len = max_position_embeddings
      else:
+         # PATCHED: remove the line using max
          torch._check(isinstance(seq_len, torch.Tensor))
          seq_len = torch.maximum(
              seq_len,
@@ -896,6 +1104,7 @@ def patched_dynamic_rope_update(rope_forward):
          )
          original_inv_freq = self.original_inv_freq.to(device)

+         # PATCHED: uses torch.cond instead of a test
          cond = (seq_len > original_max_position_embeddings).item()
          inv_freq = torch.cond(
              cond,
@@ -967,6 +1176,7 @@ def patched_dynamic_rope_update(rope_forward):

          original_inv_freq = self.original_inv_freq.to(device)
          cond = (seq_len >= self.original_max_seq_len).item()
+         # PATCHED: uses torch.cond instead of a test
          inv_freq = torch.cond(
              cond,
              (lambda x, y: x.clone()),
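Both hunks above replace a data-dependent Python if with torch.cond so that both branches survive tracing by torch.export; a minimal standalone sketch (names and threshold are hypothetical):

    import torch

    def pick_inv_freq(seq_len, inv_freq, factor=8.0):
        cond = (seq_len > 4096).item()
        return torch.cond(
            cond,
            lambda f: f / factor,   # long-sequence branch
            lambda f: f.clone(),    # default branch; clone avoids aliasing the input
            [inv_freq],
        )

    print(pick_inv_freq(torch.tensor(8192), torch.arange(4, dtype=torch.float32) + 1.0))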
@@ -1002,6 +1212,7 @@ def common_eager_attention_forward(

      attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
      if attention_mask is not None:
+         # PATCHED
          # The two following lines were added.
          if attention_mask is not None and attention_mask.ndim == 4:
              attention_mask = attention_mask[:, :, :, : key.shape[-2]]
@@ -1074,6 +1285,7 @@ def patched_modeling_marian_eager_attention_forward(
  class common_RotaryEmbedding(torch.nn.Module):
      # This may cause some issues.
      # @torch.no_grad()
+     # PATCHED: the decorator
      @patched_dynamic_rope_update
      def forward(self, x, position_ids):
          inv_freq_expanded = (
@@ -1629,3 +1841,56 @@ if patch_qwen3:
                  batch_size, sequence_length, hidden_dim
              )
              return final_hidden_states, router_logits
+
+
+ try:
+     from transformers.models.gemma3.modeling_gemma3 import Gemma3Model  # noqa: F401
+
+     patch_gemma3 = True
+ except ImportError:
+     patch_gemma3 = False
+
+
+ if patch_gemma3:
+
+     class patched_Gemma3Model(torch.nn.Module):
+         _PATCHES_ = ["get_placeholder_mask"]
+         _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3Model
+         _PATCHED_PR_ = "https://github.com/huggingface/transformers/pull/41319"
+
+         def get_placeholder_mask(
+             self,
+             input_ids: torch.LongTensor,
+             inputs_embeds: torch.FloatTensor,
+             image_features: torch.FloatTensor,
+         ):
+             if input_ids is None:
+                 special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                     torch.tensor(
+                         self.config.image_token_id,
+                         dtype=torch.long,
+                         device=inputs_embeds.device,
+                     )
+                 )
+                 special_image_mask = special_image_mask.all(-1)
+             else:
+                 special_image_mask = input_ids == self.config.image_token_id
+
+             n_image_tokens = special_image_mask.sum()
+             special_image_mask = (
+                 special_image_mask.unsqueeze(-1)
+                 .expand_as(inputs_embeds)
+                 .to(inputs_embeds.device)
+             )
+             n_image_features = image_features.shape[0] * image_features.shape[1]
+             # PATCHED: torch._check
+             # if inputs_embeds[special_image_mask].numel() != image_features.numel():
+             #     raise ValueError( ... )
+             torch._check(
+                 inputs_embeds[special_image_mask].numel() == image_features.numel(),
+                 lambda: (
+                     f"Image features and image tokens do not match: tokens: "
+                     f"{n_image_tokens}, features {n_image_features}"
+                 ),
+             )
+             return special_image_mask
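The patch above swaps a data-dependent raise ValueError for torch._check, which records the constraint for torch.export instead of branching on tensor values; a minimal standalone sketch (tensors are hypothetical):

    import torch

    inputs_embeds = torch.zeros(1, 4, 8)
    special_image_mask = torch.tensor([[False, True, True, False]])[..., None].expand(1, 4, 8)
    image_features = torch.zeros(1, 2, 8)

    torch._check(
        inputs_embeds[special_image_mask].numel() == image_features.numel(),
        lambda: "Image features and image tokens do not match",
    )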
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
@@ -4829,3 +4829,39 @@ def _ccached_microsoft_phi3_mini_128k_instruct():
              "vocab_size": 32064,
          }
      )
+
+
+ def _ccached_google_gemma_3_4b_it_like():
+     "google/gemma-3-4b-it"
+     return transformers.Gemma3Config(
+         **{
+             "architectures": ["Gemma3ForConditionalGeneration"],
+             "boi_token_index": 255999,
+             "eoi_token_index": 256000,
+             "eos_token_id": [1, 106],
+             "image_token_index": 262144,
+             "initializer_range": 0.02,
+             "mm_tokens_per_image": 256,
+             "model_type": "gemma3",
+             "text_config": {
+                 "hidden_size": 2560,
+                 "intermediate_size": 10240,
+                 "model_type": "gemma3_text",
+                 "num_hidden_layers": 34,
+                 "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
+                 "sliding_window": 1024,
+             },
+             "torch_dtype": "bfloat16",
+             "transformers_version": "4.50.0.dev0",
+             "vision_config": {
+                 "hidden_size": 1152,
+                 "image_size": 896,
+                 "intermediate_size": 4304,
+                 "model_type": "siglip_vision_model",
+                 "num_attention_heads": 16,
+                 "num_hidden_layers": 27,
+                 "patch_size": 14,
+                 "vision_use_head": False,
+             },
+         }
+     )
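A minimal usage sketch for the cached configuration above, assuming a transformers version that ships Gemma3Config: it builds the config object locally, without contacting the hub.

    from onnx_diagnostic.torch_models.hghub.hub_data_cached_configs import (
        _ccached_google_gemma_3_4b_it_like,
    )

    config = _ccached_google_gemma_3_4b_it_like()
    print(type(config).__name__, config.text_config.hidden_size)  # Gemma3Config 2560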
onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -25,6 +25,20 @@ def _code_needing_rewriting(model: Any) -> Any:
      return code_needing_rewriting(model)


+ def _preprocess_model_id(
+     model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
+ ) -> Tuple[str, Optional[str], bool, bool]:
+     if subfolder or "//" not in model_id:
+         return model_id, subfolder, same_as_pretrained, use_pretrained
+     spl = model_id.split("//")
+     if spl[-1] == "pretrained":
+         return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
+     if spl[-1] in {"transformer", "vae"}:
+         # known subfolder
+         return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
+     return model_id, subfolder, same_as_pretrained, use_pretrained
+
+
  def get_untrained_model_with_inputs(
      model_id: str,
      config: Optional[Any] = None,
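A minimal sketch of the "//" suffix convention handled by _preprocess_model_id (the model ids are placeholders):

    _preprocess_model_id("org/model//vae", None, False, False)
    # -> ("org/model", "vae", False, False): known subfolder split off the id
    _preprocess_model_id("org/model//pretrained", None, False, False)
    # -> ("org/model", "", True, True): forces same_as_pretrained and use_pretrained
    _preprocess_model_id("org/model", None, False, False)
    # -> ("org/model", None, False, False): unchanged when there is no "//" marker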
@@ -57,7 +71,7 @@ def get_untrained_model_with_inputs(
          to get a smaller model
      :param use_pretrained: download the pretrained weights as well
      :param use_preinstalled: use preinstalled configurations
-     :param add_second_input: provides a second inputs to check a model
+     :param add_second_input: provides others inputs to check a model
          supports different shapes
      :param subfolder: subfolder to use for this model id
      :param use_only_preinstalled: use only preinstalled version
@@ -85,8 +99,16 @@ def get_untrained_model_with_inputs(
              f"model_id={model_id!r}, preinstalled model is only available "
              f"if use_only_preinstalled is False."
          )
+     model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+         model_id,
+         subfolder,
+         same_as_pretrained=same_as_pretrained,
+         use_pretrained=use_pretrained,
+     )
      if verbose:
-         print(f"[get_untrained_model_with_inputs] model_id={model_id!r}")
+         print(
+             f"[get_untrained_model_with_inputs] model_id={model_id!r}, subfolder={subfolder!r}"
+         )
          if use_preinstalled:
              print(f"[get_untrained_model_with_inputs] use preinstalled {model_id!r}")
      if config is None:
@@ -178,7 +200,7 @@ def get_untrained_model_with_inputs(

      if verbose:
          print(
-             f"[get_untrained_model_with_inputs] package_source={package_source.__name__} é"
+             f"[get_untrained_model_with_inputs] package_source={package_source.__name__} "
              f"from {package_source.__file__}"
          )
      if use_pretrained:
@@ -193,7 +215,7 @@ def get_untrained_model_with_inputs(
              )
              if verbose:
                  print(
-                     f"[get_untrained_model_with_inputs] -- done in "
+                     f"[get_untrained_model_with_inputs] -- done(1) in "
                      f"{time.perf_counter() - begin}s"
                  )
          else:
@@ -250,14 +272,36 @@ def get_untrained_model_with_inputs(
          )
          if verbose:
              print(
-                 f"[get_untrained_model_with_inputs] -- done in "
+                 f"[get_untrained_model_with_inputs] -- done(2) in "
                  f"{time.perf_counter() - begin}s"
              )

      seed = int(os.environ.get("SEED", "17"))
      torch.manual_seed(seed)
+
+     if verbose:
+         begin = time.perf_counter()
+         print(
+             f"[get_untrained_model_with_inputs] "
+             f"instantiate_specific_model {cls_model}"
+         )
+
      model = instantiate_specific_model(cls_model, config)
+
+     if verbose:
+         print(
+             f"[get_untrained_model_with_inputs] -- done(3) in "
+             f"{time.perf_counter() - begin}s (model is {type(model)})"
+         )
+
      if model is None:
+
+         if verbose:
+             print(
+                 f"[get_untrained_model_with_inputs] "
+                 f"instantiate_specific_model(2) {cls_model}"
+             )
+
          try:
              if type(config) is dict:
                  model = cls_model(**config)
@@ -268,6 +312,12 @@ def get_untrained_model_with_inputs(
                  f"Unable to instantiate class {cls_model.__name__} with\n{config}"
              ) from e

+     if verbose:
+         print(
+             f"[get_untrained_model_with_inputs] -- done(4) in "
+             f"{time.perf_counter() - begin}s (model is {type(model)})"
+         )
+
      # input kwargs
      seed = int(os.environ.get("SEED", "17")) + 1
      torch.manual_seed(seed)