onnx-diagnostic 0.6.3__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +281 -80
  3. onnx_diagnostic/doc.py +22 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +48 -20
  5. onnx_diagnostic/export/shape_helper.py +126 -0
  6. onnx_diagnostic/ext_test_case.py +1 -1
  7. onnx_diagnostic/helpers/cache_helper.py +78 -8
  8. onnx_diagnostic/helpers/config_helper.py +8 -4
  9. onnx_diagnostic/helpers/helper.py +30 -3
  10. onnx_diagnostic/helpers/log_helper.py +1744 -0
  11. onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
  12. onnx_diagnostic/helpers/model_builder_helper.py +54 -73
  13. onnx_diagnostic/helpers/torch_helper.py +18 -2
  14. onnx_diagnostic/reference/__init__.py +1 -0
  15. onnx_diagnostic/reference/ort_evaluator.py +29 -4
  16. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  17. onnx_diagnostic/reference/torch_evaluator.py +21 -0
  18. onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
  19. onnx_diagnostic/tasks/feature_extraction.py +3 -0
  20. onnx_diagnostic/tasks/fill_mask.py +3 -0
  21. onnx_diagnostic/tasks/image_classification.py +7 -1
  22. onnx_diagnostic/tasks/image_text_to_text.py +72 -18
  23. onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
  24. onnx_diagnostic/tasks/object_detection.py +3 -0
  25. onnx_diagnostic/tasks/sentence_similarity.py +3 -0
  26. onnx_diagnostic/tasks/summarization.py +3 -0
  27. onnx_diagnostic/tasks/text2text_generation.py +3 -0
  28. onnx_diagnostic/tasks/text_classification.py +3 -0
  29. onnx_diagnostic/tasks/text_generation.py +90 -43
  30. onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
  31. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
  32. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
  33. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
  34. onnx_diagnostic/torch_models/hghub/hub_api.py +81 -8
  35. onnx_diagnostic/torch_models/hghub/hub_data.py +6 -2
  36. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +58 -14
  38. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
  39. onnx_diagnostic/torch_models/{test_helper.py → validate.py} +166 -106
  40. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/METADATA +2 -2
  41. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/RECORD +44 -41
  42. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/WHEEL +0 -0
  43. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
  44. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -11,7 +11,7 @@ from ...helpers.torch_helper import is_torchdynamo_exporting
 
 
 def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
-    """Patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
+    """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
     from ...helpers import string_type
 
     dimensions: List[Tuple[Optional[int], ...]] = [
@@ -534,19 +534,169 @@ class patched_GenerationMixin:
         return model_inputs
 
 
-def patched_dynamic_rope_update(rope_forward):
+def patched__compute_dynamic_ntk_parameters(
+    config: Optional[transformers.PretrainedConfig] = None,
+    device: Optional["torch.device"] = None,
+    seq_len: Optional[int] = None,
+    **rope_kwargs,
+) -> Tuple["torch.Tensor", float]:
+    """
+    manual patch:
+    ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``
+
+    Computes the inverse frequencies with NTK scaling.
+    Credits to the Reddit users /u/bloc97 and /u/emozilla.
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length,
+            used to update the dynamic RoPE at inference time.
+        rope_kwargs (`Dict`, *optional*):
+            BC compatibility with the previous
+            RoPE class instantiation, will be removed in v4.45.
+
+    Returns:
+        Tuple of (`torch.Tensor`, `float`),
+        containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the
+        computed cos/sin (unused in this type of RoPE).
     """
-    patch:transformers.modeling_rope_utils.dynamic_rope_update
+    if config is not None and len(rope_kwargs) > 0:
+        raise ValueError(
+            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
+            f"`_compute_dynamic_ntk_parameters`, got "
+            f"`rope_kwargs`={rope_kwargs} and `config`={config}"
+        )
+    if len(rope_kwargs) > 0:
+        base = rope_kwargs["base"]
+        dim = rope_kwargs["dim"]
+        max_position_embeddings = rope_kwargs["max_position_embeddings"]
+        factor = rope_kwargs["factor"]
+    elif config is not None:
+        base = config.rope_theta
+        partial_rotary_factor = (
+            config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+        )
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        dim = int(head_dim * partial_rotary_factor)
+        max_position_embeddings = config.max_position_embeddings
+        factor = config.rope_scaling["factor"]
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # seq_len: default to max_position_embeddings, e.g. at init time
+    # seq_len = seq_len if seq_len is not None and
+    #     seq_len > max_position_embeddings else max_position_embeddings
+    if seq_len is None:
+        seq_len = max_position_embeddings
+    else:
+        torch._check(isinstance(seq_len, torch.Tensor))
+        seq_len = torch.maximum(
+            seq_len,
+            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
+        )
+
+    # Compute the inverse frequencies
+    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (
+        dim / (dim - 2)
+    )
+    inv_freq = 1.0 / (
+        base
+        ** (
+            torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
+            / dim
+        )
+    )
+    return inv_freq, attention_factor
+
+
+def patched_dynamic_rope_update(rope_forward):
+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
+
+    ``rope_type`` is determined in the constructor of class
+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
+
+    .. code-block:: python
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+
+    The original code of the patched function:
+
+    .. code-block:: python
+
+        def dynamic_rope_update(rope_forward):
+            def longrope_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if hasattr(self.config, "original_max_position_embeddings"):
+                    original_max_position_embeddings = (
+                        self.config.original_max_position_embeddings
+                    )
+                else:
+                    original_max_position_embeddings = (
+                        self.config.max_position_embeddings
+                    )
+                if seq_len > original_max_position_embeddings:
+                    if not hasattr(self, "long_inv_freq"):
+                        self.long_inv_freq, _ = self.rope_init_fn(
+                            self.config, device, seq_len=original_max_position_embeddings + 1
+                        )
+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+                else:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+            def dynamic_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if seq_len > self.max_seq_len_cached:  # growth
+                    inv_freq, self.attention_scaling = self.rope_init_fn(
+                        self.config, device, seq_len=seq_len)
+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
+                    self.max_seq_len_cached = seq_len
+
+                if (seq_len < self.original_max_seq_len
+                        and self.max_seq_len_cached > self.original_max_seq_len):
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+                    self.max_seq_len_cached = self.original_max_seq_len
+
+            @wraps(rope_forward)
+            def wrapper(self, x, position_ids):
+                if "dynamic" in self.rope_type:
+                    dynamic_frequency_update(self, position_ids, device=x.device)
+                elif self.rope_type == "longrope":
+                    longrope_frequency_update(self, position_ids, device=x.device)
+                return rope_forward(self, x, position_ids)
+
+            return wrapper
+
     """
 
     def longrope_frequency_update(self, position_ids, device):
+        # There is no point patching the function after the model is created:
+        # rope_init_fn is an attribute bound to one function when the model
+        # is created, before any patch is applied.
+        # So we select the patched version here.
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
         seq_len = torch.max(position_ids) + 1
         if hasattr(self.config, "original_max_position_embeddings"):
             original_max_position_embeddings = self.config.original_max_position_embeddings
         else:
             original_max_position_embeddings = self.config.max_position_embeddings
         # At export time, seq_len is unknown.
-        long_inv_freq, _ = self.rope_init_fn(
+        long_inv_freq, _ = rope_init_fn(
             self.config, device, seq_len=original_max_position_embeddings + 1
         )
         original_inv_freq = self.original_inv_freq.to(device)
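
The new `patched__compute_dynamic_ntk_parameters` keeps `seq_len` as a tensor so the NTK rescaling stays traceable under `torch.export`: a Python `max` on a symbolic value would force a data-dependent branch, while `torch.maximum` does not. A minimal standalone sketch of the same arithmetic, with illustrative values that are not taken from the diff:

.. code-block:: python

    import torch

    # Illustrative constants; the real values come from the model configuration.
    base, dim, factor, max_pos = 10000.0, 128, 2.0, 4096

    seq_len = torch.tensor(8192.0)  # stays a tensor, symbolic at export time
    seq_len = torch.maximum(seq_len, torch.tensor(float(max_pos)))  # never below max_pos

    # NTK scaling: grow the RoPE base once the sequence exceeds the training length.
    scaled = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
    inv_freq = 1.0 / (scaled ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    print(inv_freq.shape)  # torch.Size([64])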
@@ -565,21 +715,70 @@ def patched_dynamic_rope_update(rope_forward):
         # self.inv_freq = self.original_inv_freq
 
     def dynamic_frequency_update(self, position_ids, device):
+        # constructor:
+        # - self.max_seq_len_cached = config.max_position_embeddings
+        # - self.original_max_seq_len = config.max_position_embeddings
+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+        # There is no point patching the function after the model is created:
+        # rope_init_fn is an attribute bound to one function when the model
+        # is created, before any patch is applied.
+        # So we select the patched version here.
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
+
+        # This behaviour is difficult to translate.
+        # The sequence always grows, so the test should always be True
+        # and self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
+        #
+        # if seq_len > self.max_seq_len_cached:  # growth
+        #     inv_freq, self.attention_scaling = self.rope_init_fn(
+        #         self.config, device, seq_len=seq_len
+        #     )
+        #     self.register_buffer("inv_freq", inv_freq, persistent=False)
+        #     self.max_seq_len_cached = seq_len
+        #
+        # So we should not need what follows.
+        #
+        # cond = (seq_len > self.max_seq_len_cached).item()
+        # self.attention_scaling = torch.cond(
+        #     cond,
+        #     (lambda x, y: x.clone()),
+        #     (lambda x, y: y.clone()),
+        #     [attention_scaling, self.attention_scaling],
+        # )
+
         seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(
-                self.config, device, seq_len=seq_len
-            )
-            self.register_buffer("inv_freq", inv_freq, persistent=False)
-            self.max_seq_len_cached = seq_len
+        long_inv_freq, self.attention_scaling = rope_init_fn(
+            self.config, device, seq_len=seq_len
+        )
 
-        if (
-            seq_len < self.original_max_seq_len
-            and self.max_seq_len_cached > self.original_max_seq_len
-        ):
-            self.original_inv_freq = self.original_inv_freq.to(device)
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-            self.max_seq_len_cached = self.original_max_seq_len
+        # Second test to translate.
+        # Keep in mind that self.max_seq_len_cached = seq_len is likely to be True,
+        # in which case the following condition is a way to restore the original cache.
+
+        # if (
+        #     seq_len < self.original_max_seq_len
+        #     and self.max_seq_len_cached > self.original_max_seq_len
+        # ):
+        #     self.original_inv_freq = self.original_inv_freq.to(device)
+        #     self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+        #     self.max_seq_len_cached = self.original_max_seq_len
+
+        original_inv_freq = self.original_inv_freq.to(device)
+        cond = (seq_len >= self.original_max_seq_len).item()
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq, original_inv_freq],
+        )
+        self.inv_freq = inv_freq
 
     @wraps(rope_forward)
     def wrapper(self, x, position_ids):
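
The `torch.cond` call above makes the cache-growth decision exportable: both branches are traced and the predicate selects between them at runtime. A minimal eager-mode sketch of the pattern (values are illustrative; assumes a PyTorch recent enough to expose `torch.cond`):

.. code-block:: python

    import torch

    long_inv_freq = torch.ones(4)       # frequencies recomputed for the longer sequence
    original_inv_freq = torch.zeros(4)  # frequencies cached at model creation
    seq_len, original_max_seq_len = torch.tensor(5000), 4096

    inv_freq = torch.cond(
        (seq_len >= original_max_seq_len).item(),  # predicate, as in the patch
        lambda x, y: x.clone(),                    # sequence grew past the cache
        lambda x, y: y.clone(),                    # restore the original cache
        [long_inv_freq, original_inv_freq],
    )
    print(inv_freq)  # tensor([1., 1., 1., 1.])

The `.clone()` calls keep each branch from returning an alias of its inputs, which the export machinery rejects.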
@@ -619,3 +818,152 @@ class patched_Phi3RotaryEmbedding(torch.nn.Module):
         sin = emb.sin() * self.attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class patched_IdeficsEmbedding(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # if seq_len > self.max_seq_len_cached:
+        #     self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
+            t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
+            freqs = torch.einsum("i,j->ij", t, inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
+
+        def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached):
+            torch._check(seq_len.item() <= cos_cached.shape[0])
+            co = cos_cached[: seq_len.item()].detach().clone()
+            torch._check(seq_len.item() <= sin_cached.shape[0])
+            si = sin_cached[: seq_len.item()].detach().clone()
+            return co.to(dtype=x.dtype), si.to(dtype=x.dtype)
+
+        cos_cached, sin_cached = torch.cond(
+            (seq_len > self.max_seq_len_cached).item(),
+            _set_cos_sin_cache_then,
+            _set_cos_sin_cache_else,
+            [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
+        )
+        return cos_cached, sin_cached
+
+
+class patched_IdeficsAttention(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsAttention
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        is_cross_attention = self.is_cross_attention or key_value_states is not None
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = (
+            self.q_proj(hidden_states)
+            .view(bsz, q_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+        if not is_cross_attention:
+            key_states = (
+                self.k_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            value_states = (
+                self.v_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+        else:
+            _, kv_len, _ = (
+                key_value_states.size()
+            )  # Note that, in this case, `kv_len` == `kv_seq_len`
+            key_states = (
+                self.k_proj(key_value_states)
+                .view(bsz, kv_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            value_states = (
+                self.v_proj(key_value_states)
+                .view(bsz, kv_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += cache_position[0]
+
+        if not is_cross_attention:
+            rotary_length = torch.maximum(
+                torch.tensor(kv_seq_len, dtype=torch.int64),
+                torch.tensor(q_len, dtype=torch.int64),
+            )
+            cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
+            query_states, key_states = (
+                transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
+                    query_states, key_states, cos, sin, position_ids
+                )
+            )
+            # [bsz, nh, t, hd]
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models;
+            # cache_position needed for the static cache
+            cache_kwargs = {"cache_position": cache_position}
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        if self.qk_layer_norms:
+            query_states = self.q_layer_norm(query_states)
+            key_states = self.k_layer_norm(key_states)
+
+        attention_interface: Callable = (
+            transformers.models.idefics.modeling_idefics.eager_attention_forward
+        )
+
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and output_attentions:
+                transformers.models.idefics.modeling_idefics.logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support "
+                    "`output_attentions=True`. Falling back to "
+                    "eager attention. This warning can be removed using the argument "
+                    '`attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
+                    self.config._attn_implementation
+                ]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
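
`patched_IdeficsEmbedding` relies on the same `torch.cond` trick, plus `torch._check` to register the data-dependent bound that makes the cache slice provably in range for the exporter. A hedged sketch of that guard in isolation (shapes are illustrative):

.. code-block:: python

    import torch

    cos_cached = torch.randn(4096, 64)  # illustrative cache
    seq_len = torch.tensor(128)

    torch._check(seq_len.item() <= cos_cached.shape[0])  # records the slice bound
    co = cos_cached[: seq_len.item()].detach().clone()
    print(co.shape)  # torch.Size([128, 64])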
onnx_diagnostic/torch_models/hghub/hub_api.py

@@ -2,9 +2,11 @@ import copy
 import functools
 import json
 import os
+import pprint
+import sys
 from typing import Any, Dict, List, Optional, Union
 import transformers
-from huggingface_hub import HfApi, model_info, hf_hub_download
+from huggingface_hub import HfApi, model_info, hf_hub_download, list_repo_files
 from ...helpers.config_helper import update_config
 from . import hub_data_cached_configs
 from .hub_data import __date__, __data_tasks__, load_architecture_task, __data_arch_values__
@@ -33,10 +35,14 @@ def _retrieve_cached_configurations() -> Dict[str, transformers.PretrainedConfig
     return res
 
 
-def get_cached_configuration(name: str, **kwargs) -> Optional[transformers.PretrainedConfig]:
+def get_cached_configuration(
+    name: str, exc: bool = False, **kwargs
+) -> Optional[transformers.PretrainedConfig]:
     """
     Returns a cached configuration to avoid too many accesses to the internet.
     It returns None if the model is not cached. The list of cached models follows.
+    If *exc* is True or if environment variable ``NOHTTP`` is defined,
+    the function raises an exception if *name* is not found.
 
     .. runpython::
 
@@ -54,8 +60,11 @@ def get_cached_configuration(name: str, **kwargs) -> Optional[transformers.Pretr
         conf = copy.deepcopy(conf)
         update_config(conf, kwargs)
         return conf
-    if os.environ.get("NOHTTP", ""):
-        raise AssertionError(f"Unable to find {name!r} in {sorted(cached)}")
+    assert not exc and not os.environ.get("NOHTTP", ""), (
+        f"Unable to find {name!r} (exc={exc}, "
+        f"NOHTTP={os.environ.get('NOHTTP', '')!r}) "
+        f"in {pprint.pformat(sorted(cached))}"
+    )
     return None
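A hedged usage sketch of the new `exc` flag; the model id is illustrative and whether it is cached depends on the installed package:

.. code-block:: python

    from onnx_diagnostic.torch_models.hghub.hub_api import get_cached_configuration

    conf = get_cached_configuration("arnir0/Tiny-LLM")            # None if not cached
    conf = get_cached_configuration("arnir0/Tiny-LLM", exc=True)  # raises if not cached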
@@ -64,6 +73,7 @@ def get_pretrained_config(
     trust_remote_code: bool = True,
     use_preinstalled: bool = True,
     subfolder: Optional[str] = None,
+    use_only_preinstalled: bool = False,
     **kwargs,
 ) -> Any:
     """
@@ -77,13 +87,20 @@ def get_pretrained_config(
     :func:`get_cached_configuration`, the cached list is mostly for
     unit tests
     :param subfolder: subfolder for the given model id
+    :param use_only_preinstalled: if True, raises an exception if not preinstalled
     :param kwargs: additional kwargs
     :return: a configuration
     """
     if use_preinstalled:
-        conf = get_cached_configuration(model_id, subfolder=subfolder, **kwargs)
+        conf = get_cached_configuration(
+            model_id, exc=use_only_preinstalled, subfolder=subfolder, **kwargs
+        )
         if conf is not None:
             return conf
+    assert not use_only_preinstalled, (
+        f"Inconsistencies: use_only_preinstalled={use_only_preinstalled}, "
+        f"use_preinstalled={use_preinstalled!r}"
+    )
     if subfolder:
         try:
             return transformers.AutoConfig.from_pretrained(
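
Since `use_only_preinstalled=True` asserts as soon as the cache misses, it only makes sense together with `use_preinstalled=True`. A hedged sketch (model id illustrative):

.. code-block:: python

    from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config

    # Never touches the network: fails fast when the configuration is not preinstalled.
    conf = get_pretrained_config(
        "arnir0/Tiny-LLM", use_preinstalled=True, use_only_preinstalled=True
    )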
@@ -122,12 +139,15 @@ def _guess_task_from_config(config: Any) -> Optional[str]:
 
 
 @functools.cache
-def task_from_arch(arch: str, default_value: Optional[str] = None) -> str:
+def task_from_arch(
+    arch: str, default_value: Optional[str] = None, model_id: Optional[str] = None
+) -> str:
     """
     This function relies on stored information. That information needs to be refreshed.
 
     :param arch: architecture name
     :param default_value: default value in case the task cannot be determined
+    :param model_id: unused unless the architecture does not help
     :return: task
 
     .. runpython::
 
@@ -140,9 +160,16 @@ def task_from_arch(arch: str, default_value: Optional[str] = None) -> str:
         <onnx_diagnostic.torch_models.hghub.hub_data.load_architecture_task>`.
     """
     data = load_architecture_task()
+    if arch not in data and model_id:
+        # Let's try with the model id.
+        return task_from_id(model_id)
     if default_value is not None:
         return data.get(arch, default_value)
-    assert arch in data, f"Architecture {arch!r} is unknown, last refresh in {__date__}"
+    assert arch in data, (
+        f"Architecture {arch!r} is unknown, last refresh in {__date__}. "
+        f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__`` "
+        f"needs to be updated (model_id={(model_id or '?')!r})."
+    )
     return data[arch]
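A short sketch of the resolution order after this change (values illustrative): a known architecture resolves directly from the table, an unknown one defers to the model id through `task_from_id`:

.. code-block:: python

    from onnx_diagnostic.torch_models.hghub.hub_api import task_from_arch

    print(task_from_arch("Phi3ForCausalLM"))  # 'text-generation'
    # An architecture missing from the table now falls back to the model id, e.g.
    # task_from_arch("UnknownArch", model_id="emilyalsentzer/Bio_ClinicalBERT")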
@@ -160,6 +187,7 @@ def task_from_id(
         if the task cannot be determined
     :param pretrained: uses the config
     :param fall_back_to_pretrained: falls back to pretrained config
+    :param exc: raises an exception if True
     :return: task
     """
     if not pretrained:
@@ -175,9 +203,14 @@ def task_from_id(
     guess = _guess_task_from_config(config)
     if guess is not None:
         return guess
+    data = load_architecture_task()
+    if model_id in data:
+        return data[model_id]
     assert config.architectures is not None and len(config.architectures) == 1, (
         f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
-        f"architectures={config.architectures} in config={config}"
+        f"architectures={config.architectures} in config={config}. "
+        f"The task can be added in "
+        f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__``."
     )
     return task_from_arch(config.architectures[0], default_value=default_value)
@@ -295,3 +328,43 @@ def enumerate_model_list(
         n -= 1
         if n == 0:
             break
+
+
+def download_code_modelid(
+    model_id: str, verbose: int = 0, add_path_to_sys_path: bool = True
+) -> List[str]:
+    """
+    Downloads the code for a given model id.
+
+    :param model_id: model id
+    :param verbose: verbosity
+    :param add_path_to_sys_path: add folder where the files are downloaded to sys.path
+    :return: list of downloaded files
+    """
+    if verbose:
+        print(f"[download_code_modelid] retrieve file list for {model_id!r}")
+    files = list_repo_files(model_id)
+    pyfiles = [name for name in files if os.path.splitext(name)[-1] == ".py"]
+    if verbose:
+        print(f"[download_code_modelid] python files {pyfiles}")
+    absfiles = []
+    paths = set()
+    for i, name in enumerate(pyfiles):
+        if verbose:
+            print(f"[download_code_modelid] download file {i+1}/{len(pyfiles)}: {name!r}")
+        r = hf_hub_download(repo_id=model_id, filename=name)
+        p = os.path.split(r)[0]
+        paths.add(p)
+        absfiles.append(r)
+    if add_path_to_sys_path:
+        for p in paths:
+            init = os.path.join(p, "__init__.py")
+            if not os.path.exists(init):
+                with open(init, "w"):
+                    pass
+            if p in sys.path:
+                continue
+            if verbose:
+                print(f"[download_code_modelid] add {p!r} to 'sys.path'")
+            sys.path.insert(0, p)
+    return absfiles
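
A hedged usage sketch for `download_code_modelid`; the model id is illustrative and must point to a repository that ships custom modeling code (the `trust_remote_code` case):

.. code-block:: python

    from onnx_diagnostic.torch_models.hghub.hub_api import download_code_modelid

    # Downloads every *.py file of the repository and makes its folder importable.
    files = download_code_modelid("microsoft/Phi-3.5-MoE-instruct", verbose=1)
    print(files)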
onnx_diagnostic/torch_models/hghub/hub_data.py

@@ -3,7 +3,7 @@ import functools
 import textwrap
 from typing import Dict, List
 
-__date__ = "2025-03-26"
+__date__ = "2025-06-21"
 
 __data_arch_values__ = {"ResNetForImageClassification": dict(image_size=224)}
 
@@ -52,6 +52,8 @@ __data_arch__ = textwrap.dedent(
     GPTNeoModel,feature-extraction
     GPTNeoXForCausalLM,text-generation
     GemmaForCausalLM,text-generation
+    Gemma2ForCausalLM,text-generation
+    Gemma3ForConditionalGeneration,image-text-to-text
     GraniteForCausalLM,text-generation
     GroupViTModel,feature-extraction
     HieraForImageClassification,image-classification
@@ -97,6 +99,7 @@ __data_arch__ = textwrap.dedent(
     PegasusModel,feature-extraction
     Phi3ForCausalLM,text-generation
     PhiForCausalLM,text-generation
+    PhiMoEForCausalLM,text-generation
     Pix2StructForConditionalGeneration,image-to-text
     PLBartForConditionalGeneration,text2text-generation
     PoolFormerModel,image-feature-extraction
@@ -144,7 +147,8 @@ __data_arch__ = textwrap.dedent(
     XLMRobertaModel,sentence-similarity
     Wav2Vec2ForCTC,automatic-speech-recognition
     YolosForObjectDetection,object-detection
-    YolosModel,image-feature-extraction"""
+    YolosModel,image-feature-extraction
+    emilyalsentzer/Bio_ClinicalBERT,fill-mask"""
 )
 
 __data_tasks__ = [
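
The dedented CSV above is parsed into an architecture-to-task mapping, now also keyed by a plain model id for the `Bio_ClinicalBERT` entry. A quick sketch, assuming `load_architecture_task` returns that dictionary as its use in `hub_api.py` suggests:

.. code-block:: python

    from onnx_diagnostic.torch_models.hghub.hub_data import load_architecture_task

    data = load_architecture_task()
    print(data["Gemma2ForCausalLM"])                # 'text-generation'
    print(data["emilyalsentzer/Bio_ClinicalBERT"])  # 'fill-mask' (keyed by model id)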