keras-hub 0.24.0.dev0__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/models/__init__.py +12 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +188 -14
- keras_hub/src/models/esm/esm_attention.py +11 -4
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +8 -3
- keras_hub/src/models/gemma3/gemma3_presets.py +12 -0
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +20 -8
- keras_hub/src/models/gpt_oss/__init__.py +5 -0
- keras_hub/src/models/gpt_oss/gpt_oss_attention.py +330 -0
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +221 -0
- keras_hub/src/models/gpt_oss/gpt_oss_causal_lm.py +284 -0
- keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_preprocessor.py +79 -0
- keras_hub/src/models/gpt_oss/gpt_oss_decoder.py +444 -0
- keras_hub/src/models/gpt_oss/gpt_oss_layer_norm.py +34 -0
- keras_hub/src/models/gpt_oss/gpt_oss_presets.py +51 -0
- keras_hub/src/models/gpt_oss/gpt_oss_tokenizer.py +39 -0
- keras_hub/src/models/llama3/llama3_presets.py +1 -1
- keras_hub/src/models/parseq/parseq_decoder.py +21 -9
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +1 -1
- keras_hub/src/utils/transformers/convert_gemma3.py +353 -0
- keras_hub/src/utils/transformers/convert_gpt_oss.py +302 -0
- keras_hub/src/utils/transformers/preset_loader.py +12 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dist-info}/METADATA +1 -1
- {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dist-info}/RECORD +27 -16
- {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dist-info}/WHEEL +0 -0
- {keras_hub-0.24.0.dev0.dist-info → keras_hub-0.25.0.dist-info}/top_level.txt +0 -0
keras_hub/models/__init__.py
CHANGED
@@ -340,6 +340,18 @@ from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import (
 from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
     GPTNeoXTokenizer as GPTNeoXTokenizer,
 )
+from keras_hub.src.models.gpt_oss.gpt_oss_backbone import (
+    GptOssBackbone as GptOssBackbone,
+)
+from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm import (
+    GptOssCausalLM as GptOssCausalLM,
+)
+from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm_preprocessor import (
+    GptOssCausalLMPreprocessor as GptOssCausalLMPreprocessor,
+)
+from keras_hub.src.models.gpt_oss.gpt_oss_tokenizer import (
+    GptOssTokenizer as GptOssTokenizer,
+)
 from keras_hub.src.models.hgnetv2.hgnetv2_backbone import (
     HGNetV2Backbone as HGNetV2Backbone,
 )
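This release adds a complete GPT-OSS model family (backbone, causal LM, preprocessor, tokenizer, presets, checkpoint converter) and exports it from keras_hub.models. A minimal usage sketch of the new symbols; the preset name below is hypothetical, the real handles live in gpt_oss_presets.py:

import keras_hub

# Load the full generation task. "gpt_oss_20b_en" is a placeholder name,
# not confirmed by this diff.
causal_lm = keras_hub.models.GptOssCausalLM.from_preset("gpt_oss_20b_en")
causal_lm.generate("The capital of France is", max_length=30)

# The backbone alone is also exported for feature extraction.
backbone = keras_hub.models.GptOssBackbone.from_preset("gpt_oss_20b_en")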
keras_hub/src/layers/modeling/rotary_embedding.py
CHANGED
@@ -1,4 +1,5 @@
 import keras
+import numpy as np
 from keras import ops

 from keras_hub.src.api_export import keras_hub_export
@@ -25,6 +26,17 @@ class RotaryEmbedding(keras.layers.Layer):
             curves.
         scaling_factor: float. The scaling factor used to scale positions of
             the tokens.
+        rope_type: str. The type of RoPE scaling to apply. Supported types:
+            "linear", "dynamic", "yarn". Defaults to "linear".
+        beta_fast: float. Beta fast parameter for YaRN scaling. Only used
+            when rope_type="yarn". Defaults to 32.0.
+        beta_slow: float. Beta slow parameter for YaRN scaling. Only used
+            when rope_type="yarn". Defaults to 1.0.
+        original_max_position_embeddings: int. Original maximum position
+            embeddings for YaRN scaling. Only used when rope_type="yarn".
+            Defaults to 4096.
+        truncate: bool. Whether to apply truncation for YaRN scaling. Only
+            used when rope_type="yarn". Defaults to False.
         sequence_axis: int. Sequence axis in the input tensor.
         feature_axis: int. Feature axis in the input tensor.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
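A minimal construction sketch for the new arguments (values are illustrative, not from the diff):

import numpy as np
import keras_hub

# Extend a model trained with 4096 positions by a factor of 8 using YaRN.
rope = keras_hub.layers.RotaryEmbedding(
    max_wavelength=10000,
    scaling_factor=8.0,
    rope_type="yarn",
    beta_fast=32.0,
    beta_slow=1.0,
    original_max_position_embeddings=4096,
    truncate=False,
)
x = np.random.rand(1, 16, 8, 64).astype("float32")  # (batch, seq, heads, dim)
y = rope(x)  # same shape, with rotary position information applied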
@@ -69,6 +81,11 @@ class RotaryEmbedding(keras.layers.Layer):
         self,
         max_wavelength=10000,
         scaling_factor=1.0,
+        rope_type="linear",
+        beta_fast=32.0,
+        beta_slow=1.0,
+        original_max_position_embeddings=4096,
+        truncate=False,
         sequence_axis=1,
         feature_axis=-1,
         **kwargs,
@@ -78,24 +95,70 @@ class RotaryEmbedding(keras.layers.Layer):
         self.sequence_axis = sequence_axis
         self.feature_axis = feature_axis
         self.scaling_factor = scaling_factor
+        self.rope_type = rope_type
+
+        # YaRN-specific parameters (only used when rope_type="yarn")
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.truncate = truncate
         self.built = True

+    def _normalize_axes(self, input_shape):
+        """Normalize and validate axis indices for the given input shape."""
+        rank = len(input_shape)
+
+        # Normalize negative indices.
+        sequence_axis = self.sequence_axis
+        feature_axis = self.feature_axis
+
+        if sequence_axis < 0:
+            sequence_axis += rank
+        if feature_axis < 0:
+            feature_axis += rank
+
+        if sequence_axis < 0 or sequence_axis >= rank:
+            raise ValueError(
+                f"sequence_axis {self.sequence_axis} "
+                f"is out of range for input with rank {rank}"
+            )
+        if feature_axis < 0 or feature_axis >= rank:
+            raise ValueError(
+                f"feature_axis {self.feature_axis} "
+                f"is out of range for input with rank {rank}"
+            )
+        if sequence_axis == feature_axis:
+            raise ValueError("sequence_axis and feature_axis must be different")
+
+        return sequence_axis, feature_axis
+
+    def _validate_rotary_dimension(self, rotary_dim):
+        if rotary_dim % 2 != 0:
+            raise ValueError(
+                f"Rotary dimension must be even, got {rotary_dim}. "
+                "The rotary embedding splits the feature dimension "
+                "into two halves. Consider using a different feature "
+                "dimension or padding."
+            )
+
     def call(self, inputs, start_index=0, positions=None):
+        input_shape = ops.shape(inputs)
+        sequence_axis, feature_axis = self._normalize_axes(input_shape)
+
+        rotary_dim = input_shape[feature_axis]
+        self._validate_rotary_dimension(rotary_dim)
+
         # Take care of unbatched `positions`.
         if positions is not None:
             if len(ops.shape(positions)) == 1:
                 positions = ops.expand_dims(positions, axis=0)

-        inputs = ops.moveaxis(
-            inputs, (self.feature_axis, self.sequence_axis), (-1, 1)
-        )
+        inputs = ops.moveaxis(inputs, (feature_axis, sequence_axis), (-1, 1))
         cos_emb, sin_emb = self._compute_cos_sin_embedding(
             inputs, start_index, positions
         )
         output = self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
-        return ops.moveaxis(
-            output, (-1, 1), (self.feature_axis, self.sequence_axis)
-        )
+        return ops.moveaxis(output, (-1, 1), (feature_axis, sequence_axis))

     def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
         x1, x2 = ops.split(tensor, 2, axis=-1)
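For reference, a pure-Python sketch of the axis arithmetic `_normalize_axes` performs (not the layer code itself):

def normalize_axis(axis, rank):
    # Resolve negative indices against the input rank, then bounds-check.
    if axis < 0:
        axis += rank
    if not 0 <= axis < rank:
        raise ValueError(f"axis out of range for rank {rank}")
    return axis

# For a (batch, seq, heads, dim) input of rank 4:
assert normalize_axis(-1, 4) == 3  # feature_axis=-1 -> last axis
assert normalize_axis(1, 4) == 1   # sequence_axis=1 unchanged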
@@ -113,19 +176,35 @@
         return positions + ops.cast(start_index, dtype="float32")

     def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
+        """Compute cos & sin RoPE embeddings with optional YaRN scaling.
+        Uses tensor ops only to remain JIT/backend-friendly.
+        """
         batch_axis = 0
-        feature_axis = len(inputs.shape) - 1
         sequence_axis = 1
+        feature_axis = len(inputs.shape) - 1

         rotary_dim = ops.shape(inputs)[feature_axis]
         inverse_freq = self._get_inverse_freq(rotary_dim)

         if positions is None:
             positions = self._compute_positions(inputs, start_index)
-            positions = ops.expand_dims(positions, axis=batch_axis)
+            positions = ops.expand_dims(
+                positions, axis=batch_axis
+            )  # shape (1, seq_len)
         else:
             positions = ops.cast(positions, "float32")
-
+            if len(ops.shape(positions)) == 1:
+                positions = ops.expand_dims(positions, axis=batch_axis)
+
+        if (
+            self.rope_type == "yarn"
+            and self.truncate
+            and self.original_max_position_embeddings is not None
+        ):
+            positions = ops.minimum(
+                positions,
+                ops.cast(self.original_max_position_embeddings, "float32"),
+            )

         freq = ops.einsum("bi,j->bij", positions, inverse_freq)
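A small numpy illustration (not from the diff) of the truncate clamp applied to positions above:

import numpy as np

# With truncate=True and original_max_position_embeddings=4096, positions
# beyond the original training window are clamped before angles are computed.
positions = np.array([0.0, 4095.0, 4096.0, 8191.0])
clamped = np.minimum(positions, 4096.0)
print(clamped)  # [   0. 4095. 4096. 4096.]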
@@ -140,15 +219,103 @@

         cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype)
         sin_emb = ops.cast(ops.sin(embedding), self.compute_dtype)
+
+        if self.rope_type == "yarn":
+            # YaRN temperature scaling
+            factor = ops.add(
+                ops.multiply(
+                    ops.cast(0.1, self.compute_dtype),
+                    ops.log(ops.cast(self.scaling_factor, self.compute_dtype)),
+                ),
+                ops.cast(1.0, self.compute_dtype),
+            )
+            cos_emb = cos_emb * factor
+            sin_emb = sin_emb * factor
         return cos_emb, sin_emb

     def _get_inverse_freq(self, rotary_dim):
-        freq_range = ops.divide(
-            ops.arange(0, rotary_dim, 2, dtype="float32"),
-            ops.cast(rotary_dim, "float32"),
+        """Return inverse frequencies."""
+        idx = ops.arange(0, rotary_dim, 2, dtype="float32")
+        denom = ops.cast(rotary_dim, "float32")
+        freq_range = idx / denom
+        inv = ops.power(ops.cast(self.max_wavelength, "float32"), -freq_range)
+
+        if self.rope_type == "linear":
+            return inv / ops.cast(self.scaling_factor, "float32")
+        elif self.rope_type == "dynamic":
+            exponent = ops.cast(rotary_dim, "float32") / ops.cast(
+                max(1, rotary_dim - 2), "float32"
+            )
+            return inv / ops.power(
+                ops.cast(self.scaling_factor, "float32"), exponent
+            )
+        elif self.rope_type == "yarn":
+            return self._get_yarn_inverse_freq(rotary_dim)
+        else:
+            return inv
+
+    def _get_yarn_inverse_freq(self, rotary_dim):
+        # Get the base (rope_theta equivalent) from max_wavelength
+        base = ops.cast(self.max_wavelength, "float32")
+
+        # Compute base frequencies: base ** (idx / dim)
+        idx = ops.arange(0, rotary_dim, 2, dtype="float32")
+        pos_freqs = ops.power(base, idx / ops.cast(rotary_dim, "float32"))
+
+        # Compute interpolation and extrapolation frequencies
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (
+            ops.cast(self.scaling_factor, "float32") * pos_freqs
         )
-        inverse_freq = 1.0 / (self.max_wavelength**freq_range)
-        return inverse_freq
+
+        # Find correction range
+        beta_fast = ops.cast(self.beta_fast, "float32")
+        beta_slow = ops.cast(self.beta_slow, "float32")
+
+        # Find correction dimensions for beta_fast and beta_slow
+        def find_correction_dim_tensor(num_rotations, dim):
+            max_pos = ops.cast(self.original_max_position_embeddings, "float32")
+            return (dim * ops.log(max_pos / (num_rotations * 2 * np.pi))) / (
+                2 * ops.log(base)
+            )
+
+        low = find_correction_dim_tensor(
+            beta_fast, ops.cast(rotary_dim, "float32")
+        )
+        high = find_correction_dim_tensor(
+            beta_slow, ops.cast(rotary_dim, "float32")
+        )
+
+        # Apply truncation if specified
+        if self.truncate:
+            low = ops.floor(low)
+            high = ops.ceil(high)
+
+        # Clamp to valid range
+        low = ops.maximum(low, ops.cast(0, "float32"))
+        high = ops.minimum(high, ops.cast(rotary_dim // 2 - 1, "float32"))
+
+        # Linear ramp function
+        dim_half = rotary_dim // 2
+        idx_half = ops.arange(0, dim_half, dtype="float32")
+
+        # Prevent singularity
+        diff = high - low
+        diff = ops.maximum(diff, ops.cast(0.001, "float32"))
+
+        linear_func = (idx_half - low) / diff
+        ramp_func = ops.clip(linear_func, 0, 1)
+
+        # Apply the ramp to get extrapolation factor
+        inv_freq_extrapolation_factor = 1 - ramp_func
+
+        # Combine interpolation and extrapolation
+        scaled_inverse_freq = (
+            inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+            + inv_freq_extrapolation * inv_freq_extrapolation_factor
+        )
+
+        return scaled_inverse_freq

     def get_config(self):
         config = super().get_config()
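For intuition, the YaRN attention-temperature factor applied to cos/sin above is 0.1 * ln(scaling_factor) + 1.0; a quick numeric check:

import math

# Factor is 1.0 (no scaling) at s=1 and grows slowly with the scale factor.
for s in (1.0, 4.0, 8.0, 16.0):
    print(s, round(0.1 * math.log(s) + 1.0, 4))
# 1.0 -> 1.0, 4.0 -> 1.1386, 8.0 -> 1.2079, 16.0 -> 1.2773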
@@ -156,6 +323,13 @@ class RotaryEmbedding(keras.layers.Layer):
             {
                 "max_wavelength": self.max_wavelength,
                 "scaling_factor": self.scaling_factor,
+                "rope_type": self.rope_type,
+                "beta_fast": self.beta_fast,
+                "beta_slow": self.beta_slow,
+                "original_max_position_embeddings": (
+                    self.original_max_position_embeddings
+                ),
+                "truncate": self.truncate,
                 "sequence_axis": self.sequence_axis,
                 "feature_axis": self.feature_axis,
             }
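Because the new arguments are now tracked in get_config, a configured layer should survive a config round trip; a sketch, assuming reconstruction via from_config:

import keras_hub

layer = keras_hub.layers.RotaryEmbedding(
    rope_type="yarn", scaling_factor=8.0, truncate=True
)
restored = keras_hub.layers.RotaryEmbedding.from_config(layer.get_config())
assert restored.rope_type == "yarn" and restored.truncate is True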
keras_hub/src/models/esm/esm_attention.py
CHANGED
@@ -14,7 +14,8 @@ class ESMRotaryEmbedding(RotaryEmbedding):
         inv_freq = self.scaling_factor / (
             self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim)
         )
-        t = ops.arange(x.shape[position], dtype=x.dtype)
+        # Use ops.shape for dynamic shape compatibility with TFLite
+        t = ops.arange(ops.shape(x)[position], dtype=x.dtype)
         freqs = ops.outer(t, inv_freq)
         emb = ops.concatenate((freqs, freqs), axis=-1)

@@ -32,11 +33,17 @@ class ESMRotaryEmbedding(RotaryEmbedding):

     def rotate_half(self, x):
         x1, x2 = ops.split(x, 2, -1)
-        return ops.concatenate((-x2, x1), axis=-1)
+        # Avoid `ops.concatenate` to prevent XLA compilation issues on JAX
+        # backend. Use stack + reshape approach from base RotaryEmbedding.
+        half_rot_x = ops.stack((-x2, x1), axis=-2)
+        half_rot_x = ops.reshape(half_rot_x, ops.shape(x))
+        return half_rot_x

     def apply_rotary_pos_emb(self, x, cos, sin):
-        cos = cos[:, : x.shape[1], :, :]
-        sin = sin[:, : x.shape[1], :, :]
+        # Use ops.shape for dynamic shape compatibility with TFLite
+        seq_len = ops.shape(x)[1]
+        cos = cos[:, :seq_len, :, :]
+        sin = sin[:, :seq_len, :, :]

         return (x * cos) + (self.rotate_half(x) * sin)
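A quick numpy check (not from the diff) that the stack-plus-reshape form computes the same rotate-half as the concatenate it replaces:

import numpy as np

x = np.arange(8.0).reshape(1, 8)  # (..., dim) with dim=8
x1, x2 = np.split(x, 2, axis=-1)

concat = np.concatenate((-x2, x1), axis=-1)               # old formulation
stacked = np.stack((-x2, x1), axis=-2).reshape(x.shape)   # new formulation

assert np.array_equal(concat, stacked)
print(stacked)  # [[-4. -5. -6. -7.  0.  1.  2.  3.]]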
keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py
CHANGED
@@ -283,9 +283,14 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
         # is `None`.
         self.text_only_model = self.image_converter is None

-        self.image_placeholder = self.tokenizer.image_placeholder
-        self.start_of_image_token = self.tokenizer.start_of_image_token
-        self.end_of_image_token = self.tokenizer.end_of_image_token
+        if self.text_only_model:
+            self.image_placeholder = None
+            self.start_of_image_token = None
+            self.end_of_image_token = None
+        else:
+            self.image_placeholder = self.tokenizer.image_placeholder
+            self.start_of_image_token = self.tokenizer.start_of_image_token
+            self.end_of_image_token = self.tokenizer.end_of_image_token

     def build(self, input_shape):
         # Defer packer creation to `build()` so that we can be sure tokenizer
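Sketch of the resulting behavior, assuming a tokenizer instance is in scope and the remaining constructor arguments keep their defaults:

# With no image converter the preprocessor runs in text-only mode and the
# image-token attributes are simply None instead of tokenizer lookups.
preprocessor = Gemma3CausalLMPreprocessor(
    tokenizer=tokenizer,   # a Gemma3Tokenizer instance (assumed in scope)
    image_converter=None,  # -> text_only_model == True
)
assert preprocessor.text_only_model
assert preprocessor.start_of_image_token is None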
keras_hub/src/models/gemma3/gemma3_presets.py
CHANGED
@@ -220,4 +220,16 @@ backbone_presets = {
         },
         "kaggle_handle": "kaggle://keras/medgemma/keras/medgemma_instruct_27b_text/1",
     },
+    "function_gemma_instruct_270m": {
+        "metadata": {
+            "description": (
+                "A 270 million parameter text-only model based on Gemma 3. "
+                "This model is trained specifically for function calling "
+                "improvements."
+            ),
+            "params": 268098176,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/function-gemma/keras/function_gemma_instruct_270m/1",
+    },
 }
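A loading sketch for the new preset via the usual task API; the preset name comes straight from the diff, the generate call and its arguments are illustrative:

import keras_hub

gemma = keras_hub.models.Gemma3CausalLM.from_preset(
    "function_gemma_instruct_270m"
)
gemma.generate("What tools do you need to book a flight?", max_length=64)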
keras_hub/src/models/gemma3/gemma3_tokenizer.py
CHANGED
@@ -77,20 +77,32 @@ class Gemma3Tokenizer(SentencePieceTokenizer):

     backbone_cls = Gemma3Backbone

-    def __init__(self, proto, **kwargs):
+    def __init__(self, proto, has_vision_tokens=True, **kwargs):
         # Add special tokens.

+        self.has_vision_tokens = has_vision_tokens
         # The usual tokens.
         self._add_special_token("<bos>", "start_token")
         self._add_special_token("<eos>", "end_token")
         self._add_special_token("<pad>", "pad_token")

-        # Image placeholder token.
-        self._add_special_token("<img>", "image_placeholder")
-
-        # Some tokens which are used in the preprocessor. We need to keep them
-        # here so that the preprocessor works with tf.data.
-        self._add_special_token("<start_of_image>", "start_of_image_token")
-        self._add_special_token("<end_of_image>", "end_of_image_token")
+        if has_vision_tokens:
+            # Image placeholder token.
+            self._add_special_token("<img>", "image_placeholder")
+            # Some tokens which are used in the preprocessor.
+            # We need to keep them
+            # here so that the preprocessor works with tf.data.
+            self._add_special_token("<start_of_image>", "start_of_image_token")
+            self._add_special_token("<end_of_image>", "end_of_image_token")
+        else:
+            # For text-only, skip assigning token IDs or set to -1
+            self.start_of_image_token_id = -1
+            self.image_placeholder_token_id = -1
+            self.end_of_image_token_id = -1

         super().__init__(proto=proto, **kwargs)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"has_vision_tokens": self.has_vision_tokens})
+        return config