keras-hub 0.21.1.dev0__py3-none-any.whl → 0.22.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +9 -0
- keras_hub/models/__init__.py +47 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +6 -3
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +17 -3
- keras_hub/src/layers/preprocessing/start_end_packer.py +24 -6
- keras_hub/src/models/backbone.py +13 -10
- keras_hub/src/models/clip/clip_backbone.py +3 -102
- keras_hub/src/models/clip/clip_layers.py +295 -0
- keras_hub/src/models/clip/clip_preprocessor.py +57 -48
- keras_hub/src/models/clip/clip_text_encoder.py +2 -2
- keras_hub/src/models/clip/clip_vision_encoder.py +3 -3
- keras_hub/src/models/deit/__init__.py +5 -0
- keras_hub/src/models/deit/deit_backbone.py +154 -0
- keras_hub/src/models/deit/deit_image_classifier.py +171 -0
- keras_hub/src/models/deit/deit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/deit/deit_image_converter.py +8 -0
- keras_hub/src/models/deit/deit_layers.py +519 -0
- keras_hub/src/models/deit/deit_presets.py +49 -0
- keras_hub/src/models/dinov2/__init__.py +5 -0
- keras_hub/src/models/dinov2/dinov2_backbone.py +228 -0
- keras_hub/src/models/dinov2/dinov2_image_converter.py +8 -0
- keras_hub/src/models/dinov2/dinov2_layers.py +886 -0
- keras_hub/src/models/dinov2/dinov2_presets.py +89 -0
- keras_hub/src/models/esm/__init__.py +5 -0
- keras_hub/src/models/esm/esm_attention.py +95 -0
- keras_hub/src/models/esm/esm_backbone.py +229 -0
- keras_hub/src/models/esm/esm_classifier.py +184 -0
- keras_hub/src/models/esm/esm_classifier_preprocessor.py +135 -0
- keras_hub/src/models/esm/esm_encoder.py +134 -0
- keras_hub/src/models/esm/esm_masked_plm.py +117 -0
- keras_hub/src/models/esm/esm_masked_plm_preprocessor.py +143 -0
- keras_hub/src/models/esm/esm_presets.py +53 -0
- keras_hub/src/models/esm/esm_tokenizer.py +82 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +6 -2
- keras_hub/src/models/gemma/gemma_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +2 -2
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +1 -1
- keras_hub/src/models/hgnetv2/__init__.py +5 -0
- keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +193 -0
- keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +148 -0
- keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py +216 -0
- keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_preprocessor.py +14 -0
- keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py +8 -0
- keras_hub/src/models/hgnetv2/hgnetv2_layers.py +918 -0
- keras_hub/src/models/hgnetv2/hgnetv2_presets.py +58 -0
- keras_hub/src/models/llama3/llama3_presets.py +3 -3
- keras_hub/src/models/mistral/mistral_presets.py +17 -1
- keras_hub/src/models/mixtral/mixtral_presets.py +2 -2
- keras_hub/src/models/mobilenet/mobilenet_presets.py +4 -4
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +2 -2
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +2 -2
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +17 -17
- keras_hub/src/models/qwen3/__init__.py +5 -0
- keras_hub/src/models/qwen3/qwen3_attention.py +369 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +191 -0
- keras_hub/src/models/qwen3/qwen3_causal_lm.py +390 -0
- keras_hub/src/models/qwen3/qwen3_causal_lm_preprocessor.py +10 -0
- keras_hub/src/models/qwen3/qwen3_decoder.py +309 -0
- keras_hub/src/models/qwen3/qwen3_layernorm.py +38 -0
- keras_hub/src/models/qwen3/qwen3_presets.py +73 -0
- keras_hub/src/models/qwen3/qwen3_tokenizer.py +48 -0
- keras_hub/src/models/qwen_moe/qwen_moe_attention.py +1 -0
- keras_hub/src/models/qwen_moe/qwen_moe_presets.py +2 -2
- keras_hub/src/models/roformer_v2/roformer_v2_attention.py +0 -2
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +16 -7
- keras_hub/src/models/stable_diffusion_3/mmdit.py +61 -4
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +31 -32
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +6 -2
- keras_hub/src/models/vit/vit_backbone.py +31 -11
- keras_hub/src/models/vit/vit_image_converter.py +0 -70
- keras_hub/src/models/vit/vit_layers.py +33 -18
- keras_hub/src/models/vit/vit_presets.py +11 -11
- keras_hub/src/utils/keras_utils.py +17 -0
- keras_hub/src/utils/preset_utils.py +19 -4
- keras_hub/src/utils/tensor_utils.py +14 -0
- keras_hub/src/utils/transformers/convert_deit.py +155 -0
- keras_hub/src/utils/transformers/convert_dinov2.py +180 -0
- keras_hub/src/utils/transformers/convert_esm.py +159 -0
- keras_hub/src/utils/transformers/convert_llama3.py +6 -0
- keras_hub/src/utils/transformers/convert_qwen3.py +145 -0
- keras_hub/src/utils/transformers/export/gemma.py +89 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +98 -0
- keras_hub/src/utils/transformers/preset_loader.py +14 -2
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +1 -0
- {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dev0.dist-info}/METADATA +4 -4
- {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dev0.dist-info}/RECORD +92 -48
- keras_hub/src/models/clip/clip_encoder_block.py +0 -111
- keras_hub/src/models/clip/clip_vision_embedding.py +0 -101
- {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dev0.dist-info}/WHEEL +0 -0
- {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dev0.dist-info}/top_level.txt +0 -0
```diff
+++ b/keras_hub/src/models/dinov2/dinov2_presets.py
@@ -0,0 +1,89 @@
+"""DINOV2 model preset configurations."""
+
+# Metadata for loading pretrained model weights.
+backbone_presets = {
+    "dinov2_small": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (small-sized model) trained using DINOv2."
+            ),
+            "params": 22_582_656,
+            "path": "dinov2",
+        },
+        "kaggle_handle": "kaggle://keras/dinov2/keras/dinov2_small/1",
+    },
+    "dinov2_base": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (base-sized model) trained using DINOv2."
+            ),
+            "params": 87_632_640,
+            "path": "dinov2",
+        },
+        "kaggle_handle": "kaggle://keras/dinov2/keras/dinov2_base/1",
+    },
+    "dinov2_large": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (large-sized model) trained using DINOv2."
+            ),
+            "params": 305_771_520,
+            "path": "dinov2",
+        },
+        "kaggle_handle": "kaggle://keras/dinov2/keras/dinov2_large/1",
+    },
+    "dinov2_giant": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (giant-sized model) trained using DINOv2."
+            ),
+            "params": 1_138_585_088,
+            "path": "dinov2",
+        },
+        "kaggle_handle": "kaggle://keras/dinov2/keras/dinov2_giant/1",
+    },
+    "dinov2_with_registers_small": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (small-sized model) trained using DINOv2, "
+                "with registers."
+            ),
+            "params": 22_584_192,
+            "path": "dinov2",
+        },
+        "kaggle_handle": "kaggle://keras/dinov2/keras/dinov2_with_registers_small/1",
+    },
+    "dinov2_with_registers_base": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (base-sized model) trained using DINOv2, "
+                "with registers."
+            ),
+            "params": 87_635_712,
+            "path": "dinov2",
+        },
+        "kaggle_handle": "kaggle://keras/dinov2/keras/dinov2_with_registers_base/1",
+    },
+    "dinov2_with_registers_large": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (large-sized model) trained using DINOv2, "
+                "with registers."
+            ),
+            "params": 305_775_616,
+            "path": "dinov2",
+        },
+        "kaggle_handle": "kaggle://keras/dinov2/keras/dinov2_with_registers_large/1",
+    },
+    "dinov2_with_registers_giant": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (giant-sized model) trained using DINOv2, "
+                "with registers."
+            ),
+            "params": 1_138_591_232,
+            "path": "dinov2",
+        },
+        "kaggle_handle": "kaggle://keras/dinov2/keras/dinov2_with_registers_giant/1",
+    },
+}
```
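
The presets above register under the generic KerasHub preset API, so a new DINOv2 backbone can be instantiated by name. A minimal sketch (not part of the diff; it assumes the standard `keras_hub.models.Backbone.from_preset` flow and uses the `dinov2_base` name from the table above):

```python
import keras_hub

# `from_preset` resolves the name against the registered
# `backbone_presets` table and returns the matching backbone subclass.
backbone = keras_hub.models.Backbone.from_preset("dinov2_base")
backbone.summary()
```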
```diff
+++ b/keras_hub/src/models/esm/esm_attention.py
@@ -0,0 +1,95 @@
+import keras
+from keras import ops
+from packaging import version
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.roformer_v2.roformer_v2_attention import (
+    RoformerAttention,
+)
+
+
+class ESMRotaryEmbedding(RotaryEmbedding):
+    def _compute_cos_sin_embedding(self, x, position=1):
+        dim = x.shape[-1]
+        inv_freq = self.scaling_factor / (
+            self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim)
+        )
+        t = ops.arange(x.shape[position], dtype=x.dtype)
+        freqs = ops.outer(t, inv_freq)
+        emb = ops.concatenate((freqs, freqs), axis=-1)
+
+        cos_emb = ops.cos(emb)[None, :, None, :]
+        sin_emb = ops.sin(emb)[None, :, None, :]
+        return cos_emb, sin_emb
+
+    def call(self, q, k, position=1):
+        cos_emb, sin_emb = self._compute_cos_sin_embedding(q, position)
+
+        return (
+            self.apply_rotary_pos_emb(q, cos_emb, sin_emb),
+            self.apply_rotary_pos_emb(k, cos_emb, sin_emb),
+        )
+
+    def rotate_half(self, x):
+        x1, x2 = ops.split(x, 2, -1)
+        return ops.concatenate((-x2, x1), axis=-1)
+
+    def apply_rotary_pos_emb(self, x, cos, sin):
+        cos = cos[:, : x.shape[1], :, :]
+        sin = sin[:, : x.shape[1], :, :]
+
+        return (x * cos) + (self.rotate_half(x) * sin)
+
+
+class EsmSelfAttention(RoformerAttention):
+    """Multi-head attention for ESM2.
+
+    Based on the HuggingFace implementation. The computation here is
+    exactly the same as in RoFormer; only the calculation of the rotary
+    embeddings differs.
+    """
+
+    def __init__(self, use_rotary=True, **kwargs):
+        super().__init__(**kwargs)
+        self.use_rotary = use_rotary
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        if self.use_rotary:
+            self.rotary_embedding_layer = ESMRotaryEmbedding(
+                max_wavelength=self.max_wavelength, dtype=self.dtype_policy
+            )
+            self.rotary_embedding_layer.build([])
+
+    def call(self, x, attention_mask=None):
+        qw = self.q_dense(x)
+        kw = self.k_dense(x)
+        vw = self.v_dense(x)
+
+        b, s = ops.shape(qw)[:2]
+        qw = ops.reshape(qw, (b, s, self.heads, self.head_size))
+        kw = ops.reshape(kw, (b, s, self.heads, self.head_size))
+        vw = ops.reshape(vw, (b, s, self.heads, self.head_size))
+
+        if self.use_rotary:
+            qw, kw = self.rotary_embedding_layer(qw, kw)
+        if version.parse(keras.__version__) < version.parse("3.6"):
+            raise ValueError("Please make sure your Keras version is >=3.6.")
+        flash_attention = keras.config.is_flash_attention_enabled()
+        attention_mask = ops.reshape(attention_mask, [b, 1, s, 1])
+        if keras.config.backend() == "torch":
+            attention_mask = ops.repeat(attention_mask, s, -1)
+            attention_mask = ops.transpose(attention_mask, [0, 1, 3, 2])
+        o = ops.dot_product_attention(
+            qw, kw, vw, mask=attention_mask, flash_attention=flash_attention
+        )
+        return self.o_dense(ops.reshape(o, [b, s, -1]))
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "use_rotary": self.use_rotary,
+            }
+        )
+        return config
```
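
For intuition, the rotary update in `ESMRotaryEmbedding.apply_rotary_pos_emb` rotates each `(x[i], x[i + dim/2])` pair by an angle proportional to the token position, which preserves the vector's norm. A NumPy re-derivation of that step on a toy vector (illustrative only, not part of the diff):

```python
import numpy as np

def rotate_half(x):
    # Same half-split rotation as `ESMRotaryEmbedding.rotate_half` above.
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate((-x2, x1), axis=-1)

dim = 4
position = 3  # token index
inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
freqs = position * inv_freq
emb = np.concatenate((freqs, freqs))
cos, sin = np.cos(emb), np.sin(emb)

x = np.array([1.0, 2.0, 3.0, 4.0])
rotated = x * cos + rotate_half(x) * sin

# Each (x[i], x[i + dim/2]) pair is rotated by angle position * inv_freq[i],
# so the norm is unchanged.
assert np.isclose(np.linalg.norm(rotated), np.linalg.norm(x))
```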
```diff
+++ b/keras_hub/src/models/esm/esm_backbone.py
@@ -0,0 +1,229 @@
+import keras
+from keras import activations
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.esm.esm_encoder import ESMEncoder
+
+
+def esm2_kernel_initializer(stddev=0.02):
+    return keras.initializers.TruncatedNormal(stddev=stddev)
+
+
+@keras_hub_export(
+    ["keras_hub.models.ESM2Backbone", "keras_hub.models.ESMBackbone"]
+)
+class ESMBackbone(Backbone):
+    """An ESM2 and ESM encoder network.
+
+    This class implements a bi-directional Transformer-based encoder as
+    described in ["ESM"](https://github.com/facebookresearch/esm).
+
+    The default constructor gives a fully customizable, randomly initialized
+    ESM2 encoder with any number of layers, heads, and embedding dims. To
+    load preset architectures and weights, use the `from_preset()` constructor.
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_heads: int. The number of attention heads for each transformer.
+            The hidden size must be divisible by the number of attention heads.
+        hidden_dim: int. The size of the transformer encoding and pooler layers.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            a two-layer feedforward network for each transformer.
+        dropout: float. Dropout probability for the Transformer encoder.
+            Defaults to `0.1`.
+        use_pre_layer_norm: bool. If True, apply layer normalization to the
+            embeddings before the first transformer block.
+            Defaults to False.
+        max_sequence_length: int. The maximum sequence length that this encoder
+            can consume. If None, `max_sequence_length` defaults to the input
+            sequence length. This determines the variable shape for positional
+            embeddings.
+        position_embedding_type: str. The position embedding type to use.
+            One of "absolute" and "rotary". Use "absolute" for ESM1 and
+            "rotary" for ESM2. Defaults to "rotary".
+        max_wavelength: int. The maximum angular wavelength of
+            the sine/cosine curves, for rotary embeddings.
+            Defaults to `10000`.
+        activation: str or callable. The activation to
+            use for the transformer.
+            Defaults to `"gelu"`.
+        pad_token_id: int. The padding token id. Normally 0, but it is set
+            to 1 in the ESM2 models.
+            Defaults to 0.
+        dtype: None or str or keras.mixed_precision.DTypePolicy. The dtype to
+            use for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype.
+
+    Examples:
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+    }
+
+    # Pretrained ESM2 encoder.
+    model = keras_hub.models.ESM2Backbone.from_preset(
+        "hf://facebook/esm2_t6_8M_UR50D"
+    )
+    model(input_data)
+
+    # Randomly initialized ESM2 encoder with a custom config.
+    model = keras_hub.models.ESM2Backbone(
+        vocabulary_size=30552,
+        num_layers=4,
+        num_heads=4,
+        hidden_dim=256,
+        intermediate_dim=512,
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        intermediate_dim,
+        use_bias=True,
+        activation="gelu",
+        dropout=0.1,
+        dtype=None,
+        max_sequence_length=1024,
+        max_wavelength=10000,
+        layer_norm_eps=1e-12,
+        use_pre_layer_norm=False,
+        position_embedding_type="rotary",
+        pad_token_id=0,
+        **kwargs,
+    ):
+        if position_embedding_type not in (
+            "rotary",
+            "absolute",
+        ):
+            raise ValueError(
+                '`position_embedding_type` must be either `"rotary"`, or '
+                '`"absolute"`. Received '
+                f"position_embedding_type={position_embedding_type}."
+            )
+        head_size = hidden_dim // num_heads
+        # === Layers ===
+        self.token_embedding = keras.layers.Embedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            embeddings_initializer=esm2_kernel_initializer(),
+            dtype=dtype,
+            name="token_embedding",
+        )
+        if position_embedding_type == "absolute":
+            self.position_embedding = PositionEmbedding(
+                initializer=esm2_kernel_initializer(),
+                sequence_length=max_sequence_length,
+                dtype=dtype,
+                name="position_embedding",
+            )
+            self.embeddings_add = keras.layers.Add(
+                dtype=dtype,
+                name="embeddings_add",
+            )
+
+        self.output_layer_norm = keras.layers.LayerNormalization(
+            epsilon=layer_norm_eps,
+            dtype=dtype,
+            name="output_layer_norm",
+        )
+        if use_pre_layer_norm:
+            self.emb_layer_norm = keras.layers.LayerNormalization(
+                epsilon=layer_norm_eps,
+                dtype=dtype,
+                name="emb_layer_norm",
+            )
+        self.transformer_layers = []
+        for i in range(num_layers):
+            layer = ESMEncoder(
+                heads=num_heads,
+                head_size=head_size,
+                intermediate_size=intermediate_dim,
+                use_bias=use_bias,
+                max_wavelength=max_wavelength,
+                dropout=dropout,
+                activation=activation,
+                kernel_initializer=esm2_kernel_initializer(),
+                layer_norm_eps=layer_norm_eps,
+                dtype=dtype,
+                use_rotary=position_embedding_type == "rotary",
+                name=f"transformer_layer_{i}",
+            )
+            self.transformer_layers.append(layer)
+
+        # === Functional Model ===
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+
+        attention_mask = keras.ops.not_equal(token_id_input, pad_token_id)
+
+        token_vector = self.token_embedding(token_id_input)
+        if position_embedding_type == "absolute":
+            position_vector = self.position_embedding(
+                token_vector, start_index=pad_token_id
+            )
+            x = self.embeddings_add([token_vector, position_vector])
+        else:
+            x = token_vector
+        if use_pre_layer_norm:
+            x = self.emb_layer_norm(x)
+        for transformer_layer in self.transformer_layers:
+            x = transformer_layer(x, attention_mask=attention_mask)
+        output = self.output_layer_norm(x)
+        super().__init__(
+            inputs={
+                "token_ids": token_id_input,
+            },
+            outputs=output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.vocabulary_size = vocabulary_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.dropout = dropout
+        self.max_wavelength = max_wavelength
+        self.head_size = head_size
+        self.activation = activations.get(activation)
+        self.use_bias = use_bias
+        self.start_token_index = 0
+        self.layer_norm_eps = layer_norm_eps
+        self.max_sequence_length = max_sequence_length
+        self.use_pre_layer_norm = use_pre_layer_norm
+        self.position_embedding_type = position_embedding_type
+        self.pad_token_id = pad_token_id
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_wavelength": self.max_wavelength,
+                "use_bias": self.use_bias,
+                "activation": activations.serialize(self.activation),
+                "layer_norm_eps": self.layer_norm_eps,
+                "use_pre_layer_norm": self.use_pre_layer_norm,
+                "position_embedding_type": self.position_embedding_type,
+                "max_sequence_length": self.max_sequence_length,
+                "pad_token_id": self.pad_token_id,
+            }
+        )
+        return config
```
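
Note the design choice in the functional graph above: the backbone takes only `token_ids`, and the padding mask is derived internally with `keras.ops.not_equal(token_id_input, pad_token_id)` rather than accepting a separate `padding_mask` input. A NumPy stand-in for that derivation (illustrative; the pad id of 1 follows the ESM2 convention noted in the docstring):

```python
import numpy as np

pad_token_id = 1  # ESM2 convention; the constructor's default is 0.
token_ids = np.array([[0, 5, 6, 7, 2, 1, 1, 1]])  # right-padded sequence

# Mirrors `keras.ops.not_equal(token_id_input, pad_token_id)` above.
attention_mask = token_ids != pad_token_id
print(attention_mask.astype(int))  # [[1 1 1 1 1 0 0 0]]
```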
```diff
+++ b/keras_hub/src/models/esm/esm_classifier.py
@@ -0,0 +1,184 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.esm.esm_backbone import ESMBackbone
+from keras_hub.src.models.esm.esm_backbone import esm2_kernel_initializer
+from keras_hub.src.models.esm.esm_classifier_preprocessor import (
+    ESMProteinClassifierPreprocessor,
+)
+from keras_hub.src.models.text_classifier import TextClassifier
+
+
+@keras_hub_export("keras_hub.models.ESMProteinClassifier")
+class ESMProteinClassifier(TextClassifier):
+    """An end-to-end ESM model for classification tasks.
+
+    This model attaches a classification head to
+    `keras_hub.models.ESMBackbone`, mapping from the backbone outputs
+    to logits suitable for a classification task. For usage of this model with
+    pre-trained weights, use the `from_preset()` constructor.
+
+    This model can optionally be configured with a `preprocessor` layer, in
+    which case it will automatically apply preprocessing to raw inputs during
+    `fit()`, `predict()`, and `evaluate()`. This is done by default when
+    creating the model with `from_preset()`.
+
+    Args:
+        backbone: A `keras_hub.models.ESMBackbone` instance.
+        num_classes: int. Number of classes to predict.
+        preprocessor: A `keras_hub.models.ESMProteinClassifierPreprocessor`
+            or `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
+        activation: Optional `str` or callable. The
+            activation function to use on the model outputs. Set
+            `activation="softmax"` to return output probabilities.
+            Defaults to `None`.
+        dropout: float. The dropout probability value, applied after the dense
+            layer.
+
+    Examples:
+
+    Raw string data.
+    ```python
+    features = ["The quick brown fox jumped.", "I forgot my homework."]
+    labels = [0, 3]
+
+    # Pretrained classifier.
+    classifier = keras_hub.models.ESMProteinClassifier.from_preset(
+        "hf://facebook/esm2_t6_8M_UR50D",
+        num_classes=4,
+    )
+    classifier.fit(x=features, y=labels, batch_size=2)
+    classifier.predict(x=features, batch_size=2)
+
+    # Re-compile (e.g., with a new learning rate).
+    classifier.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+        jit_compile=True,
+    )
+    # Access backbone programmatically (e.g., to change `trainable`).
+    classifier.backbone.trainable = False
+    # Fit again.
+    classifier.fit(x=features, y=labels, batch_size=2)
+    ```
+
+    Preprocessed integer data.
+    ```python
+    features = {
+        "token_ids": np.ones(shape=(2, 12), dtype="int32"),
+    }
+    labels = [0, 3]
+
+    # Pretrained classifier without preprocessing.
+    classifier = keras_hub.models.ESMProteinClassifier.from_preset(
+        "hf://facebook/esm2_t6_8M_UR50D",
+        num_classes=4,
+        preprocessor=None,
+    )
+    classifier.fit(x=features, y=labels, batch_size=2)
+    ```
+
+    Custom backbone and vocabulary.
+    ```python
+    features = ["The quick brown fox jumped.", "I forgot my homework."]
+    labels = [0, 3]
+
+    vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+    vocab += ["The", "quick", "brown", "fox", "jumped", "."]
+    tokenizer = keras_hub.models.ESMTokenizer(
+        vocabulary=vocab,
+    )
+    preprocessor = keras_hub.models.ESMProteinClassifierPreprocessor(
+        tokenizer=tokenizer,
+        sequence_length=128,
+    )
+    backbone = keras_hub.models.ESMBackbone(
+        vocabulary_size=30552,
+        num_layers=4,
+        num_heads=4,
+        hidden_dim=256,
+        intermediate_dim=512,
+        max_wavelength=128,
+    )
+    classifier = keras_hub.models.ESMProteinClassifier(
+        backbone=backbone,
+        preprocessor=preprocessor,
+        num_classes=4,
+    )
+    classifier.fit(x=features, y=labels, batch_size=2)
+    ```
+    """
+
+    backbone_cls = ESMBackbone
+    preprocessor_cls = ESMProteinClassifierPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        preprocessor=None,
+        activation=None,
+        hidden_dim=None,
+        dropout=0.0,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+        self.pooled_dropout = keras.layers.Dropout(
+            dropout,
+            dtype=backbone.dtype_policy,
+            name="pooled_dropout",
+        )
+        hidden_dim = hidden_dim or backbone.hidden_dim
+        self.pooled_dense = keras.layers.Dense(
+            hidden_dim,
+            activation="tanh",
+            dtype=backbone.dtype_policy,
+            name="pooled_dense",
+        )
+        self.output_dropout = keras.layers.Dropout(
+            dropout,
+            dtype=backbone.dtype_policy,
+            name="output_dropout",
+        )
+        self.output_dense = keras.layers.Dense(
+            num_classes,
+            kernel_initializer=esm2_kernel_initializer(),
+            activation=activation,
+            dtype=backbone.dtype_policy,
+            name="logits",
+        )
+
+        # === Functional Model ===
+        inputs = backbone.input
+        x = backbone(inputs)[:, backbone.start_token_index, :]
+        x = self.pooled_dropout(x)
+        x = self.pooled_dense(x)
+        x = self.output_dropout(x)
+        outputs = self.output_dense(x)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.num_classes = num_classes
+        self.activation = keras.activations.get(activation)
+        self.hidden_dim = hidden_dim
+        self.dropout = dropout
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "activation": keras.activations.serialize(self.activation),
+                "hidden_dim": self.hidden_dim,
+                "dropout": self.dropout,
+            }
+        )
+        return config
```
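
The pooling in the functional model above is worth spelling out: the classifier reads the backbone's sequence output at `backbone.start_token_index` (index 0, the BOS position), then applies dropout and a tanh `pooled_dense` layer before the logits. A NumPy sketch of that path (illustrative only; the random kernel stands in for the trained `pooled_dense` weights):

```python
import numpy as np

batch, seq_len, hidden_dim = 2, 12, 256
sequence_output = np.random.randn(batch, seq_len, hidden_dim)
start_token_index = 0  # set in `ESMBackbone.__init__` above

# Slice out the BOS-position embedding for each sequence in the batch.
pooled = sequence_output[:, start_token_index, :]  # (batch, hidden_dim)

# Stand-in for the tanh "pooler" dense layer (`pooled_dense`).
kernel = np.random.randn(hidden_dim, hidden_dim) * 0.02
pooled = np.tanh(pooled @ kernel)
print(pooled.shape)  # (2, 256)
```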