keras-hub-nightly 0.22.0.dev202508100425__py3-none-any.whl → 0.22.0.dev202508120417__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -199,6 +199,22 @@ from keras_hub.src.models.electra.electra_backbone import (
 from keras_hub.src.models.electra.electra_tokenizer import (
     ElectraTokenizer as ElectraTokenizer,
 )
+from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESM2Backbone
+from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESMBackbone
+from keras_hub.src.models.esm.esm_classifier import (
+    ESMProteinClassifier as ESMProteinClassifier,
+)
+from keras_hub.src.models.esm.esm_classifier_preprocessor import (
+    ESMProteinClassifierPreprocessor as ESMProteinClassifierPreprocessor,
+)
+from keras_hub.src.models.esm.esm_masked_plm import (
+    ESMMaskedPLM as ESM2MaskedPLM,
+)
+from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM as ESMMaskedPLM
+from keras_hub.src.models.esm.esm_masked_plm_preprocessor import (
+    ESMMaskedPLMPreprocessor as ESMMaskedPLMPreprocessor,
+)
+from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer as ESMTokenizer
 from keras_hub.src.models.f_net.f_net_backbone import (
     FNetBackbone as FNetBackbone,
 )
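The hunk above re-exports the new ESM symbols from the public `keras_hub.models` namespace. As a quick sanity check (not part of the diff), the sketch below assumes a keras-hub build containing this change and verifies that the aliased names resolve to the same classes, as the paired imports imply.

```python
# Sketch only: assumes the nightly build on the right-hand side of this diff.
import keras_hub

# Both names are re-exported from the same backbone / masked-LM classes above.
assert keras_hub.models.ESM2Backbone is keras_hub.models.ESMBackbone
assert keras_hub.models.ESM2MaskedPLM is keras_hub.models.ESMMaskedPLM

# The tokenizer and preprocessors are exported under a single name each.
print(keras_hub.models.ESMTokenizer)
print(keras_hub.models.ESMProteinClassifierPreprocessor)
```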
@@ -0,0 +1,95 @@
+import keras
+from keras import ops
+from packaging import version
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.roformer_v2.roformer_v2_attention import (
+    RoformerAttention,
+)
+
+
+class ESMRotaryEmbedding(RotaryEmbedding):
+    def _compute_cos_sin_embedding(self, x, position=1):
+        dim = x.shape[-1]
+        inv_freq = self.scaling_factor / (
+            self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim)
+        )
+        t = ops.arange(x.shape[position], dtype=x.dtype)
+        freqs = ops.outer(t, inv_freq)
+        emb = ops.concatenate((freqs, freqs), axis=-1)
+
+        cos_emb = ops.cos(emb)[None, :, None, :]
+        sin_emb = ops.sin(emb)[None, :, None, :]
+        return cos_emb, sin_emb
+
+    def call(self, q, k, position=1):
+        cos_emb, sin_emb = self._compute_cos_sin_embedding(q, position)
+
+        return (
+            self.apply_rotary_pos_emb(q, cos_emb, sin_emb),
+            self.apply_rotary_pos_emb(k, cos_emb, sin_emb),
+        )
+
+    def rotate_half(self, x):
+        x1, x2 = ops.split(x, 2, -1)
+        return ops.concatenate((-x2, x1), axis=-1)
+
+    def apply_rotary_pos_emb(self, x, cos, sin):
+        cos = cos[:, : x.shape[1], :, :]
+        sin = sin[:, : x.shape[1], :, :]
+
+        return (x * cos) + (self.rotate_half(x) * sin)
+
+
+class EsmSelfAttention(RoformerAttention):
+    """Multi-head self-attention as used by ESM2.
+
+    Follows the HuggingFace implementation. The computation is exactly the
+    same as RoFormer attention; only the calculation of the rotary
+    embeddings differs.
+    """
+
+    def __init__(self, use_rotary=True, **kwargs):
+        super().__init__(**kwargs)
+        self.use_rotary = use_rotary
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        if self.use_rotary:
+            self.rotary_embedding_layer = ESMRotaryEmbedding(
+                max_wavelength=self.max_wavelength, dtype=self.dtype_policy
+            )
+            self.rotary_embedding_layer.build([])
+
+    def call(self, x, attention_mask=None):
+        qw = self.q_dense(x)
+        kw = self.k_dense(x)
+        vw = self.v_dense(x)
+
+        b, s = ops.shape(qw)[:2]
+        qw = ops.reshape(qw, (b, s, self.heads, self.head_size))
+        kw = ops.reshape(kw, (b, s, self.heads, self.head_size))
+        vw = ops.reshape(vw, (b, s, self.heads, self.head_size))
+
+        if self.use_rotary:
+            qw, kw = self.rotary_embedding_layer(qw, kw)
+        if version.parse(keras.__version__) < version.parse("3.6"):
+            raise ValueError("Please make sure your Keras version is >=3.6.")
+        flash_attention = keras.config.is_flash_attention_enabled()
+        attention_mask = ops.reshape(attention_mask, [b, 1, s, 1])
+        if keras.config.backend() == "torch":
+            attention_mask = ops.repeat(attention_mask, s, -1)
+            attention_mask = ops.transpose(attention_mask, [0, 1, 3, 2])
+        o = ops.dot_product_attention(
+            qw, kw, vw, mask=attention_mask, flash_attention=flash_attention
+        )
+        return self.o_dense(ops.reshape(o, [b, s, -1]))
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "use_rotary": self.use_rotary,
+            }
+        )
+        return config
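For orientation, here is a minimal, hedged sketch of how the rotary helper above behaves on query/key tensors shaped `(batch, seq, heads, head_dim)`. The import path is an assumption (the diff viewer omits file names); only the class and its call signature come from the hunk above.

```python
import keras

# Assumed module path; the diff does not show the new file's name.
from keras_hub.src.models.esm.esm_attention import ESMRotaryEmbedding

rotary = ESMRotaryEmbedding(max_wavelength=10000)
q = keras.random.normal((1, 8, 2, 16))  # (batch, seq, heads, head_dim)
k = keras.random.normal((1, 8, 2, 16))
# Positions are read from axis 1 of `q`; both tensors get the same rotation.
q_rot, k_rot = rotary(q, k)
print(q_rot.shape, k_rot.shape)  # shapes are unchanged: (1, 8, 2, 16)
```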
@@ -0,0 +1,229 @@
+import keras
+from keras import activations
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.esm.esm_encoder import ESMEncoder
+
+
+def esm2_kernel_initializer(stddev=0.02):
+    return keras.initializers.TruncatedNormal(stddev=stddev)
+
+
+@keras_hub_export(
+    ["keras_hub.models.ESM2Backbone", "keras_hub.models.ESMBackbone"]
+)
+class ESMBackbone(Backbone):
+    """An ESM and ESM2 encoder network.
+
+    This class implements a bi-directional Transformer-based encoder as
+    described in ["ESM"](https://github.com/facebookresearch/esm).
+
+    The default constructor gives a fully customizable, randomly initialized
+    ESM2 encoder with any number of layers, heads, and embedding dims. To
+    load preset architectures and weights, use the `from_preset()` constructor.
+
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_heads: int. The number of attention heads for each transformer.
+            The hidden size must be divisible by the number of attention heads.
+        hidden_dim: int. The size of the transformer encoding and pooler layers.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            a two-layer feedforward network for each transformer.
+        dropout: float. Dropout probability for the Transformer encoder.
+            Defaults to `0.1`.
+        use_pre_layer_norm: bool. If `True`, apply layer normalization to the
+            embeddings before they enter the transformer blocks.
+            Defaults to `False`.
+        max_sequence_length: int. The maximum sequence length that this encoder
+            can consume. If None, `max_sequence_length` uses the value from
+            sequence length. This determines the variable shape for positional
+            embeddings.
+        position_embedding_type: str. The position embedding type to use.
+            One of `"absolute"` or `"rotary"`. Use `"absolute"` for ESM1 and
+            `"rotary"` for ESM2. Defaults to `"rotary"`.
+        max_wavelength: int. The maximum angular wavelength of the sine/cosine
+            curves, used for rotary embeddings.
+            Defaults to `10000`.
+        activation: string or `keras.activations`. The activation to use for
+            the transformer.
+            Defaults to `"gelu"`.
+        pad_token_id: int. The padding token id. Normally 0, but set to 1 in
+            the ESM2 models.
+            Defaults to `0`.
+        dtype: None or str or keras.mixed_precision.DTypePolicy. The dtype to
+            use for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype.
+
+    Examples:
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+    }
+
+    # Pretrained ESM2 encoder.
+    model = keras_hub.models.ESM2Backbone.from_preset("hf://facebook/esm2_t6_8M_UR50D")
+    model(input_data)
+
+    # Randomly initialized ESM2 encoder with a custom config.
+    model = keras_hub.models.ESM2Backbone(
+        vocabulary_size=30552,
+        num_layers=4,
+        num_heads=4,
+        hidden_dim=256,
+        intermediate_dim=512,
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        intermediate_dim,
+        use_bias=True,
+        activation="gelu",
+        dropout=0.1,
+        dtype=None,
+        max_sequence_length=1024,
+        max_wavelength=10000,
+        layer_norm_eps=1e-12,
+        use_pre_layer_norm=False,
+        position_embedding_type="rotary",
+        pad_token_id=0,
+        **kwargs,
+    ):
+        if position_embedding_type not in (
+            "rotary",
+            "absolute",
+        ):
+            raise ValueError(
+                '`position_embedding_type` must be either `"rotary"` or '
+                '`"absolute"`. Received '
+                f"position_embedding_type={position_embedding_type}."
+            )
+        head_size = hidden_dim // num_heads
+        # === Layers ===
+        self.token_embedding = keras.layers.Embedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            embeddings_initializer=esm2_kernel_initializer(),
+            dtype=dtype,
+            name="token_embedding",
+        )
+        if position_embedding_type == "absolute":
+            self.position_embedding = PositionEmbedding(
+                initializer=esm2_kernel_initializer(),
+                sequence_length=max_sequence_length,
+                dtype=dtype,
+                name="position_embedding",
+            )
+            self.embeddings_add = keras.layers.Add(
+                dtype=dtype,
+                name="embeddings_add",
+            )
+
+        self.output_layer_norm = keras.layers.LayerNormalization(
+            epsilon=layer_norm_eps,
+            dtype=dtype,
+            name="output_layer_norm",
+        )
+        if use_pre_layer_norm:
+            self.emb_layer_norm = keras.layers.LayerNormalization(
+                epsilon=layer_norm_eps,
+                dtype=dtype,
+                name="emb_layer_norm",
+            )
+        self.transformer_layers = []
+        for i in range(num_layers):
+            layer = ESMEncoder(
+                heads=num_heads,
+                head_size=head_size,
+                intermediate_size=intermediate_dim,
+                use_bias=use_bias,
+                max_wavelength=max_wavelength,
+                dropout=dropout,
+                activation=activation,
+                kernel_initializer=esm2_kernel_initializer(),
+                layer_norm_eps=layer_norm_eps,
+                dtype=dtype,
+                use_rotary=position_embedding_type == "rotary",
+                name=f"transformer_layer_{i}",
+            )
+            self.transformer_layers.append(layer)
+
+        # === Functional Model ===
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+
+        attention_mask = keras.ops.not_equal(token_id_input, pad_token_id)
+
+        token_vector = self.token_embedding(token_id_input)
+        if position_embedding_type == "absolute":
+            position_vector = self.position_embedding(
+                token_vector, start_index=pad_token_id
+            )
+            x = self.embeddings_add([token_vector, position_vector])
+        else:
+            x = token_vector
+        if use_pre_layer_norm:
+            x = self.emb_layer_norm(x)
+        for transformer_layer in self.transformer_layers:
+            x = transformer_layer(x, attention_mask=attention_mask)
+        output = self.output_layer_norm(x)
+        super().__init__(
+            inputs={
+                "token_ids": token_id_input,
+            },
+            outputs=output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.vocabulary_size = vocabulary_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.dropout = dropout
+        self.max_wavelength = max_wavelength
+        self.head_size = head_size
+        self.activation = activations.get(activation)
+        self.use_bias = use_bias
+        self.start_token_index = 0
+        self.layer_norm_eps = layer_norm_eps
+        self.max_sequence_length = max_sequence_length
+        self.use_pre_layer_norm = use_pre_layer_norm
+        self.position_embedding_type = position_embedding_type
+        self.pad_token_id = pad_token_id
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_wavelength": self.max_wavelength,
+                "use_bias": self.use_bias,
+                "activation": activations.serialize(self.activation),
+                "layer_norm_eps": self.layer_norm_eps,
+                "use_pre_layer_norm": self.use_pre_layer_norm,
+                "position_embedding_type": self.position_embedding_type,
+                "max_sequence_length": self.max_sequence_length,
+                "pad_token_id": self.pad_token_id,
+            }
+        )
+        return config
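Beyond the docstring examples above, the constructor also supports ESM1-style `"absolute"` position embeddings and a non-zero `pad_token_id`; padding positions are excluded from attention via the `not_equal` mask built in the functional graph. A small hedged sketch with illustrative hyperparameters:

```python
import numpy as np
import keras_hub

backbone = keras_hub.models.ESMBackbone(
    vocabulary_size=64,
    num_layers=2,
    num_heads=2,
    hidden_dim=32,
    intermediate_dim=64,
    position_embedding_type="absolute",  # ESM1-style; "rotary" is the default
    max_sequence_length=128,
    pad_token_id=1,  # ESM2 checkpoints pad with id 1, per the docstring above
)
# Positions whose token id equals `pad_token_id` are masked out internally.
token_ids = np.array([[0, 5, 6, 7, 2, 1, 1, 1]], dtype="int32")
hidden_states = backbone({"token_ids": token_ids})
print(hidden_states.shape)  # (1, 8, 32)
```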
@@ -0,0 +1,184 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.esm.esm_backbone import ESMBackbone
+from keras_hub.src.models.esm.esm_backbone import esm2_kernel_initializer
+from keras_hub.src.models.esm.esm_classifier_preprocessor import (
+    ESMProteinClassifierPreprocessor,
+)
+from keras_hub.src.models.text_classifier import TextClassifier
+
+
+@keras_hub_export("keras_hub.models.ESMProteinClassifier")
+class ESMProteinClassifier(TextClassifier):
+    """An end-to-end ESM model for classification tasks.
+
+    This model attaches a classification head to
+    `keras_hub.models.ESMBackbone`, mapping from the backbone outputs
+    to logits suitable for a classification task. For usage of this model with
+    pre-trained weights, use the `from_preset()` constructor.
+
+    This model can optionally be configured with a `preprocessor` layer, in
+    which case it will automatically apply preprocessing to raw inputs during
+    `fit()`, `predict()`, and `evaluate()`. This is done by default when
+    creating the model with `from_preset()`.
+
+    Args:
+        backbone: A `keras_hub.models.ESMBackbone` instance.
+        num_classes: int. Number of classes to predict.
+        preprocessor: A `keras_hub.models.ESMProteinClassifierPreprocessor`
+            or `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
+        activation: Optional `str` or callable. The
+            activation function to use on the model outputs. Set
+            `activation="softmax"` to return output probabilities.
+            Defaults to `None`.
+        dropout: float. The dropout probability value, applied after the dense
+            layer.
+
+    Examples:
+
+    Raw string data.
+    ```python
+    features = ["The quick brown fox jumped.", "I forgot my homework."]
+    labels = [0, 3]
+
+    # Pretrained classifier.
+    classifier = keras_hub.models.ESMProteinClassifier.from_preset(
+        "hf://facebook/esm2_t6_8M_UR50D",
+        num_classes=4,
+    )
+    classifier.fit(x=features, y=labels, batch_size=2)
+    classifier.predict(x=features, batch_size=2)
+
+    # Re-compile (e.g., with a new learning rate).
+    classifier.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+        jit_compile=True,
+    )
+    # Access backbone programmatically (e.g., to change `trainable`).
+    classifier.backbone.trainable = False
+    # Fit again.
+    classifier.fit(x=features, y=labels, batch_size=2)
+    ```
+
+    Preprocessed integer data.
+    ```python
+    features = {
+        "token_ids": np.ones(shape=(2, 12), dtype="int32"),
+    }
+    labels = [0, 3]
+
+    # Pretrained classifier without preprocessing.
+    classifier = keras_hub.models.ESMProteinClassifier.from_preset(
+        "hf://facebook/esm2_t6_8M_UR50D",
+        num_classes=4,
+        preprocessor=None,
+    )
+    classifier.fit(x=features, y=labels, batch_size=2)
+    ```
+
+    Custom backbone and vocabulary.
+    ```python
+    features = ["The quick brown fox jumped.", "I forgot my homework."]
+    labels = [0, 3]
+
+    vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+    vocab += ["The", "quick", "brown", "fox", "jumped", "."]
+    tokenizer = keras_hub.models.ESMTokenizer(
+        vocabulary=vocab,
+    )
+    preprocessor = keras_hub.models.ESMProteinClassifierPreprocessor(
+        tokenizer=tokenizer,
+        sequence_length=128,
+    )
+    backbone = keras_hub.models.ESMBackbone(
+        vocabulary_size=30552,
+        num_layers=4,
+        num_heads=4,
+        hidden_dim=256,
+        intermediate_dim=512,
+        max_wavelength=128,
+    )
+    classifier = keras_hub.models.ESMProteinClassifier(
+        backbone=backbone,
+        preprocessor=preprocessor,
+        num_classes=4,
+    )
+    classifier.fit(x=features, y=labels, batch_size=2)
+    ```
+    """
+
+    backbone_cls = ESMBackbone
+    preprocessor_cls = ESMProteinClassifierPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        preprocessor=None,
+        activation=None,
+        hidden_dim=None,
+        dropout=0.0,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+        self.pooled_dropout = keras.layers.Dropout(
+            dropout,
+            dtype=backbone.dtype_policy,
+            name="pooled_dropout",
+        )
+        hidden_dim = hidden_dim or backbone.hidden_dim
+        self.pooled_dense = keras.layers.Dense(
+            hidden_dim,
+            activation="tanh",
+            dtype=backbone.dtype_policy,
+            name="pooled_dense",
+        )
+        self.output_dropout = keras.layers.Dropout(
+            dropout,
+            dtype=backbone.dtype_policy,
+            name="output_dropout",
+        )
+        self.output_dense = keras.layers.Dense(
+            num_classes,
+            kernel_initializer=esm2_kernel_initializer(),
+            activation=activation,
+            dtype=backbone.dtype_policy,
+            name="logits",
+        )
+
+        # === Functional Model ===
+        inputs = backbone.input
+        x = backbone(inputs)[:, backbone.start_token_index, :]
+        x = self.pooled_dropout(x)
+        x = self.pooled_dense(x)
+        x = self.output_dropout(x)
+        outputs = self.output_dense(x)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.num_classes = num_classes
+        self.activation = keras.activations.get(activation)
+        self.hidden_dim = hidden_dim
+        self.dropout = dropout
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "activation": keras.activations.serialize(self.activation),
+                "hidden_dim": self.hidden_dim,
+                "dropout": self.dropout,
+            }
+        )
+        return config
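The docstring examples reuse generic English sentences; for a protein task the raw inputs would be amino-acid strings. A hedged end-to-end sketch (the sequences are illustrative; the preset name is the one already used in the docstring):

```python
import keras_hub

features = [
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",  # illustrative sequences only
    "MEEPQSDPSVEPPLSQETFSDLWKLLPEN",
]
labels = [0, 1]

# `from_preset()` attaches the matching preprocessor by default, so raw
# strings can be passed straight to `fit()` and `predict()`.
classifier = keras_hub.models.ESMProteinClassifier.from_preset(
    "hf://facebook/esm2_t6_8M_UR50D",
    num_classes=2,
)
classifier.fit(x=features, y=labels, batch_size=2)
logits = classifier.predict(x=features, batch_size=2)  # shape (2, 2)
```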
@@ -0,0 +1,135 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.esm.esm_backbone import ESMBackbone
+from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer
+from keras_hub.src.models.text_classifier_preprocessor import (
+    TextClassifierPreprocessor,
+)
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+@keras_hub_export("keras_hub.models.ESMProteinClassifierPreprocessor")
+class ESMProteinClassifierPreprocessor(TextClassifierPreprocessor):
+    """An ESM preprocessing layer which tokenizes and packs inputs.
+
+    This preprocessing layer will do three things:
+
+    1. Tokenize any number of input segments using the `tokenizer`.
+    2. Pack the inputs together using a `keras_hub.layers.StartEndPacker`
+        with the appropriate start, end and pad tokens.
+    3. Construct a dictionary with the key `"token_ids"`, that can be passed
+        directly to an ESM model.
+
+    This layer can be used directly with `tf.data.Dataset.map` to preprocess
+    string data in the `(x, y, sample_weight)` format used by
+    `keras.Model.fit`.
+
+    Args:
+        tokenizer: A `keras_hub.models.ESMTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        truncate: string. The algorithm to truncate a list of batched segments
+            to fit within `sequence_length`. The value can be either
+            `round_robin` or `waterfall`:
+            - `"round_robin"`: Available space is assigned one token at a
+                time in a round-robin fashion to the inputs that still need
+                some, until the limit is reached.
+            - `"waterfall"`: The allocation of the budget is done using a
+                "waterfall" algorithm that allocates quota in a
+                left-to-right manner and fills up the buckets until we run
+                out of budget. It supports an arbitrary number of segments.
+
+    Call arguments:
+        x: A tensor of single string sequences, or a tuple of multiple
+            tensor sequences to be packed together. Inputs may be batched or
+            unbatched. For single sequences, raw python inputs will be converted
+            to tensors. For multiple sequences, pass tensors directly.
+        y: Any label data. Will be passed through unaltered.
+        sample_weight: Any label weight data. Will be passed through unaltered.
+
+    Examples:
+
+    Directly calling the layer on data.
+    ```python
+    preprocessor = keras_hub.models.ESMProteinClassifierPreprocessor.from_preset(
+        "hf://facebook/esm2_t6_8M_UR50D"
+    )
+
+    # Tokenize and pack a single sentence.
+    preprocessor("The quick brown fox jumped.")
+
+    # Tokenize a batch of single sentences.
+    preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
+
+    # Preprocess a batch of sentence pairs.
+    # When handling multiple sequences, always convert to tensors first!
+    first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+    second = tf.constant(["The fox tripped.", "Oh look, a whale."])
+    preprocessor((first, second))
+
+    # Custom vocabulary.
+    vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+    vocab += ["The", "quick", "brown", "fox", "jumped", "."]
+    tokenizer = keras_hub.models.ESMTokenizer(vocabulary=vocab)
+    preprocessor = keras_hub.models.ESMProteinClassifierPreprocessor(tokenizer)
+    preprocessor("The quick brown fox jumped.")
+    ```
+
+    Mapping with `tf.data.Dataset`.
+    ```python
+    preprocessor = keras_hub.models.ESMProteinClassifierPreprocessor.from_preset(
+        "hf://facebook/esm2_t6_8M_UR50D"
+    )
+
+    first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+    second = tf.constant(["The fox tripped.", "Oh look, a whale."])
+    label = tf.constant([1, 1])
+
+    # Map labeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices((first, label))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map unlabeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices(first)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map labeled sentence pairs.
+    ds = tf.data.Dataset.from_tensor_slices(((first, second), label))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map unlabeled sentence pairs.
+    ds = tf.data.Dataset.from_tensor_slices((first, second))
+    # Watch out for tf.data's default unpacking of tuples here!
+    # Best to invoke the `preprocessor` directly in this case.
+    ds = ds.map(
+        lambda first, second: preprocessor(x=(first, second)),
+        num_parallel_calls=tf.data.AUTOTUNE,
+    )
+    ```
+    """
+
+    backbone_cls = ESMBackbone
+    tokenizer_cls = ESMTokenizer
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        # Defer packer creation to `build()` so that we can be sure tokenizer
+        # assets have loaded when restoring a saved model.
+        self.packer = StartEndPacker(
+            start_value=self.tokenizer.start_token_id,
+            end_value=self.tokenizer.end_token_id,
+            pad_value=self.tokenizer.pad_token_id,
+            sequence_length=self.sequence_length,
+        )
+
+    @preprocessing_function
+    def call(self, x, y=None, sample_weight=None):
+        x = self.tokenizer(x)
+        token_ids = self.packer(x)
+        x = {
+            "token_ids": token_ids,
+        }
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
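Unlike the BERT-style preprocessors, the `call` above emits only a `"token_ids"` entry, which matches the single input of `ESMBackbone`. A hedged sketch of the round trip (assuming `from_preset` forwards `sequence_length` to the constructor, as elsewhere in keras-hub; the sequence string is illustrative):

```python
import keras_hub

preprocessor = keras_hub.models.ESMProteinClassifierPreprocessor.from_preset(
    "hf://facebook/esm2_t6_8M_UR50D",
    sequence_length=64,
)
# With no labels, the preprocessor returns just the feature dict.
x = preprocessor(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"])
print(list(x.keys()))  # ['token_ids']

backbone = keras_hub.models.ESMBackbone.from_preset(
    "hf://facebook/esm2_t6_8M_UR50D"
)
hidden_states = backbone(x)
print(hidden_states.shape)  # (1, 64, hidden_dim of the preset)
```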