ctranslate2-4.7.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,797 @@
+ """Declares specification of the Transformer model."""
+
+ from typing import Optional, Tuple, Union
+
+ import numpy as np
+
+ from ctranslate2.specs import attention_spec, common_spec, model_spec
+
+
+ class TransformerEncoderSpec(model_spec.LayerSpec):
+     def __init__(
+         self,
+         num_layers: int,
+         num_heads: int,
+         pre_norm: bool = True,
+         no_final_norm: bool = False,
+         activation: common_spec.Activation = common_spec.Activation.RELU,
+         num_source_embeddings: int = 1,
+         embeddings_merge: common_spec.EmbeddingsMerge = common_spec.EmbeddingsMerge.CONCAT,
+         layernorm_embedding: bool = False,
+         relative_position: bool = False,
+         relative_attention_bias: bool = False,
+         ffn_glu: bool = False,
+         rms_norm: bool = False,
+         multi_query_attention: bool = False,
+         num_heads_kv: Optional[int] = None,
+         head_dim: Optional[int] = None,
+         rotary_dim: Optional[int] = None,
+         rotary_interleave: bool = True,
+         rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
+         rotary_scaling_factor: float = 1,
+         rotary_base: float = 10000,
+         sliding_window: Optional[int] = None,
+         qk_norm: Optional[bool] = False,
+         pre_post_layer_norm: bool = False,
+     ):
+         """Initializes a Transformer encoder specification.
+
+         Args:
+             num_layers: Number of layers.
+             num_heads: Number of attention heads.
+             pre_norm: Enable the pre-norm Transformer architecture.
+             no_final_norm: Disable the final layer norm in the pre-norm architecture.
+             activation: Activation to apply in the feed-forward network.
+             num_source_embeddings: Number of source embeddings.
+             embeddings_merge: When :obj:`num_source_embeddings` > 1, specify how the
+                 embeddings are merged.
+             layernorm_embedding: Apply layer normalization after the embedding layer.
+             relative_position: Use relative position representations in the self-attention
+                 layers as described in https://arxiv.org/abs/1803.02155.
+             relative_attention_bias: Use relative attention bias in the self-attention
+                 layers as described in the T5 paper https://arxiv.org/abs/1910.10683.
+             ffn_glu: Use gated linear units in the FFN layers as described in
+                 https://arxiv.org/abs/2002.05202.
+             rms_norm: Use the root mean square layer normalization.
+             multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
+             num_heads_kv: Number of attention heads for the key and value.
+             head_dim: Number of dimensions per attention head.
+             rotary_dim: Apply rotary embeddings to these first N dimensions. If 0, rotary
+                 embeddings are applied to all dimensions.
+             rotary_interleave: Interleave the head dimensions when rotary embeddings are applied.
+                 Otherwise the head dimensions are sliced in half.
+             rotary_scaling_type: Type of RoPE scaling.
+             rotary_scaling_factor: Factor used in the RoPE scaling.
+             rotary_base: The base period of the rotary embeddings.
+             sliding_window: Max sequence length to retain in KV Cache.
+             qk_norm: Apply layer normalization to the query and key projections.
+             pre_post_layer_norm: Add post layer norm for each pre norm layer.
+         """
+
+         if multi_query_attention:
+             if num_heads_kv is not None and num_heads_kv != 1:
+                 raise ValueError(
+                     "Enabling multi_query_attention implies num_heads_kv=1"
+                 )
+             num_heads_kv = 1
+
+         self.multi_query_attention = multi_query_attention
+         self.num_heads = np.dtype("int16").type(num_heads)
+         self.pre_norm = pre_norm
+         self.activation = np.dtype("int8").type(activation)
+         self.embeddings_merge = np.dtype("int8").type(embeddings_merge)
+         self.embeddings = [
+             common_spec.EmbeddingsSpec() for _ in range(num_source_embeddings)
+         ]
+         self.scale_embeddings = True
+         if not relative_position and not relative_attention_bias:
+             self.position_encodings = PositionEncoderSpec()
+         if pre_norm and not no_final_norm:
+             self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
+         if layernorm_embedding:
+             self.layernorm_embedding = common_spec.LayerNormSpec(rms_norm=rms_norm)
+         if sliding_window is not None:
+             self.sliding_window = np.dtype("int32").type(sliding_window)
+
+         self.layer = [
+             TransformerEncoderLayerSpec(
+                 relative_position=relative_position,
+                 relative_attention_bias=relative_attention_bias,
+                 ffn_glu=ffn_glu,
+                 rms_norm=rms_norm,
+                 num_heads_kv=num_heads_kv,
+                 head_dim=head_dim,
+                 rotary_dim=rotary_dim,
+                 rotary_interleave=rotary_interleave,
+                 rotary_scaling_type=rotary_scaling_type,
+                 rotary_scaling_factor=rotary_scaling_factor,
+                 rotary_base=rotary_base,
+                 qk_norm=qk_norm,
+                 pre_post_layer_norm=pre_post_layer_norm,
+             )
+             for _ in range(num_layers)
+         ]
+
+
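The constructor above only records the architecture; a converter attaches the actual weights later. As a minimal sketch of how this spec might be instantiated (assuming the ctranslate2 package is installed; the sizes and flags are illustrative and use only arguments from the signature above):

from ctranslate2.specs import common_spec, transformer_spec

# Hypothetical 6-layer, 8-head encoder with RMSNorm and a gated FFN.
encoder = transformer_spec.TransformerEncoderSpec(
    num_layers=6,
    num_heads=8,
    activation=common_spec.Activation.RELU,
    ffn_glu=True,
    rms_norm=True,
)
# Each entry in encoder.layer is a TransformerEncoderLayerSpec whose
# weight placeholders a converter fills in before serialization.
print(len(encoder.layer))  # 6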
+ class TransformerDecoderSpec(model_spec.LayerSpec):
+     def __init__(
+         self,
+         num_layers: int,
+         num_heads: int,
+         pre_norm: bool = True,
+         activation: common_spec.Activation = common_spec.Activation.RELU,
+         layernorm_embedding: bool = False,
+         with_encoder_attention: bool = True,
+         no_final_norm: bool = False,
+         project_in_out: bool = False,
+         relative_position: bool = False,
+         relative_attention_bias: bool = False,
+         alignment_layer: int = -1,
+         alignment_heads: int = 1,
+         ffn_glu: bool = False,
+         rms_norm: bool = False,
+         alibi: bool = False,
+         alibi_use_positive_positions: bool = False,
+         scale_alibi: bool = False,
+         rotary_dim: Optional[int] = None,
+         rotary_interleave: bool = True,
+         rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
+         rotary_scaling_factor: float = 1,
+         rotary_base: float = 10000,
+         original_max_position_embeddings: int = 0,
+         max_position_embeddings: int = 0,
+         parallel_residual: bool = False,
+         shared_layer_norm: bool = False,
+         pre_post_layer_norm: bool = False,
+         multi_query_attention: bool = False,
+         num_heads_kv: Optional[int] = None,
+         head_dim: Optional[int] = None,
+         sliding_window: Optional[int] = None,
+         quant_type: Optional[common_spec.Quantization] = None,
+         quant_group_size: Optional[int] = None,
+         quant_bits: Optional[int] = None,
+         qk_norm: bool = False,
+         external_pre_post_encoder_layers: Optional[bool] = False,
+     ):
+         """Initializes a Transformer decoder specification.
+
+         Args:
+             num_layers: Number of layers.
+             num_heads: Number of attention heads.
+             pre_norm: Enable the pre-norm Transformer architecture.
+             activation: Activation to apply in the feed-forward network.
+             layernorm_embedding: Apply layer normalization after the embedding layer.
+             with_encoder_attention: Enable the encoder attention sublayers.
+             no_final_norm: Disable the final layer norm in the pre-norm architecture.
+             project_in_out: Add linear transformations after the embedding layer and before
+                 the final layer.
+             relative_position: Use relative position representations in the self-attention
+                 layers as described in https://arxiv.org/abs/1803.02155.
+             relative_attention_bias: Use relative attention bias in the self-attention
+                 layers as described in the T5 paper https://arxiv.org/abs/1910.10683.
+             alignment_layer: Layer index selected for alignment.
+             alignment_heads: Number of attention heads selected for alignment.
+             ffn_glu: Use gated linear units in the FFN layers as described in
+                 https://arxiv.org/abs/2002.05202.
+             rms_norm: Use the root mean square layer normalization.
+             alibi: Use attention with linear biases.
+             alibi_use_positive_positions: Use positive positions in the ALiBi definition.
+             scale_alibi: Apply the dot product scale factor to ALiBi.
+             rotary_dim: Apply rotary embeddings to these first N dimensions. If 0, rotary
+                 embeddings are applied to all dimensions.
+             rotary_interleave: Interleave the head dimensions when rotary embeddings are applied.
+                 Otherwise the head dimensions are sliced in half.
+             rotary_scaling_type: Type of RoPE scaling.
+             rotary_scaling_factor: Factor used in the RoPE scaling.
+             rotary_base: The base period of the rotary embeddings.
+             original_max_position_embeddings: The original max position embeddings
+                 for Su rope embeddings
+             max_position_embeddings: The max position embeddings for Su rope embeddings
+             parallel_residual: Use parallel residual connections in each layer block, as used
+                 by the GPT-J and GPT-NeoX models.
+             shared_layer_norm: When using parallel residual, share the input and post
+                 attention layer norms.
+             pre_post_layer_norm: Add post layer norm for each pre norm layer
+             multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
+             num_heads_kv: Number of attention heads for the key and value.
+             sliding_window: Max sequence length to retain in KV Cache.
+             quant_type: quantization type used (like awq... for lower bit quantization)
+             quant_group_size: group size of the lower bit quantization
+             quant_bits: number of bit of the quantization (ex: 4bit)
+             external_pre_post_encoder_layers: if the encoder attention pre and processing
+                 is done outside the attention.
+         """
+
+         self._config = dict()
+         if parallel_residual:
+             if not pre_norm:
+                 raise ValueError("The GPT-J block expects a pre-norm architecture")
+             if with_encoder_attention:
+                 raise ValueError("The GPT-J block does not have cross attention")
+
+         if multi_query_attention:
+             if num_heads_kv is not None and num_heads_kv != 1:
+                 raise ValueError(
+                     "Enabling multi_query_attention implies num_heads_kv=1"
+                 )
+             num_heads_kv = 1
+
+         self.num_heads = np.dtype("int16").type(num_heads)
+         self.pre_norm = pre_norm
+         self.activation = np.dtype("int8").type(activation)
+         self.alignment_layer = np.dtype("int16").type(alignment_layer)
+         self.alignment_heads = np.dtype("int16").type(alignment_heads)
+         self.embeddings = common_spec.EmbeddingsSpec()
+         self.scale_embeddings = True
+         self.scale_outputs = model_spec.OPTIONAL
+         self.alibi = alibi
+         self.alibi_use_positive_positions = alibi_use_positive_positions
+         self.scale_alibi = scale_alibi
+         if sliding_window is not None:
+             self.sliding_window = np.dtype("int32").type(sliding_window)
+         if (
+             not relative_position
+             and not relative_attention_bias
+             and not alibi
+             and rotary_dim is None
+         ):
+             self.position_encodings = PositionEncoderSpec()
+         if pre_norm and not no_final_norm:
+             self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
+         if layernorm_embedding:
+             self.layernorm_embedding = common_spec.LayerNormSpec(rms_norm=rms_norm)
+         self.projection = common_spec.LinearSpec()
+         self.layer = [
+             TransformerDecoderLayerSpec(
+                 with_encoder_attention=with_encoder_attention,
+                 relative_position=relative_position,
+                 relative_attention_bias=relative_attention_bias,
+                 ffn_glu=ffn_glu,
+                 rms_norm=rms_norm,
+                 rotary_dim=rotary_dim,
+                 rotary_interleave=rotary_interleave,
+                 rotary_scaling_type=rotary_scaling_type,
+                 rotary_scaling_factor=rotary_scaling_factor,
+                 rotary_base=rotary_base,
+                 original_max_position_embeddings=original_max_position_embeddings,
+                 max_position_embeddings=max_position_embeddings,
+                 parallel_residual=parallel_residual,
+                 shared_layer_norm=shared_layer_norm,
+                 pre_post_layer_norm=pre_post_layer_norm,
+                 num_heads_kv=num_heads_kv,
+                 head_dim=head_dim,
+                 sliding_window=sliding_window,
+                 qk_norm=qk_norm,
+                 external_pre_post_encoder_layers=external_pre_post_encoder_layers,
+             )
+             for _ in range(num_layers)
+         ]
+         self.start_from_zero_embedding = False
+         self._config["multi_query_attention"] = multi_query_attention or (
+             num_heads_kv != num_heads
+         )
+
+         if project_in_out:
+             self.project_in = common_spec.LinearSpec()
+             self.project_out = common_spec.LinearSpec()
+
+         if quant_type:
+             self._config["quantization_type"] = quant_type
+             self._config["quantization_bits"] = quant_bits
+             self._config["quantization_group_size"] = quant_group_size
+
+     @property
+     def config(self):
+         return self._config
+
+
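A decoder spec is built the same way, and the `config` property above is how converter-facing settings (multi-query/grouped-query attention, quantization metadata) are surfaced. A minimal sketch with illustrative values, assuming the package is installed and using only parameters from the signature above:

from ctranslate2.specs import transformer_spec

# Hypothetical decoder-only stack with grouped-query attention
# (8 KV heads for 32 query heads), rotary embeddings, and RMSNorm.
decoder = transformer_spec.TransformerDecoderSpec(
    num_layers=32,
    num_heads=32,
    with_encoder_attention=False,
    ffn_glu=True,
    rms_norm=True,
    rotary_dim=0,
    rotary_interleave=False,
    num_heads_kv=8,
)
# num_heads_kv != num_heads, so the config reports the grouped/multi-query path.
print(decoder.config)  # {'multi_query_attention': True}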
+ class TransformerEncoderLayerSpec(model_spec.LayerSpec):
+     def __init__(
+         self,
+         relative_position=False,
+         relative_attention_bias=False,
+         ffn_glu=False,
+         rms_norm=False,
+         num_heads_kv=None,
+         head_dim=None,
+         sliding_window=None,
+         rotary_dim: Optional[int] = None,
+         rotary_interleave: bool = True,
+         rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
+         rotary_scaling_factor: float = 1,
+         rotary_base: float = 10000,
+         qk_norm=False,
+         pre_post_layer_norm: bool = False,
+     ):
+         self.self_attention = attention_spec.MultiHeadAttentionSpec(
+             self_attention=True,
+             relative_position=relative_position,
+             relative_attention_bias=relative_attention_bias,
+             rms_norm=rms_norm,
+             num_heads_kv=num_heads_kv,
+             head_dim=head_dim,
+             sliding_window=sliding_window,
+             rotary_dim=rotary_dim,
+             rotary_interleave=rotary_interleave,
+             rotary_scaling_type=rotary_scaling_type,
+             rotary_scaling_factor=rotary_scaling_factor,
+             rotary_base=rotary_base,
+             qk_norm=qk_norm,
+         )
+         self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
+
+         if pre_post_layer_norm:
+             self.input_layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
+             self.post_attention_layer_norm = common_spec.LayerNormSpec(
+                 rms_norm=rms_norm
+             )
+             self.pre_feedforward_layer_norm = common_spec.LayerNormSpec(
+                 rms_norm=rms_norm
+             )
+             self.post_feedforward_layer_norm = common_spec.LayerNormSpec(
+                 rms_norm=rms_norm
+             )
+
+             delattr(self.self_attention, "layer_norm")
+             delattr(self.ffn, "layer_norm")
+
+
+ class TransformerDecoderLayerSpec(model_spec.LayerSpec):
+     def __init__(
+         self,
+         with_encoder_attention=True,
+         relative_position=False,
+         relative_attention_bias=False,
+         ffn_glu=False,
+         rms_norm=False,
+         rotary_dim=None,
+         rotary_interleave=True,
+         rotary_scaling_type=None,
+         rotary_scaling_factor=1,
+         rotary_base=10000,
+         original_max_position_embeddings=0,
+         max_position_embeddings=0,
+         parallel_residual=False,
+         shared_layer_norm=False,
+         pre_post_layer_norm=False,
+         num_heads_kv=None,
+         head_dim=None,
+         sliding_window=None,
+         qk_norm=False,
+         external_pre_post_encoder_layers=False,
+     ):
+         self.self_attention = attention_spec.MultiHeadAttentionSpec(
+             self_attention=True,
+             relative_position=relative_position,
+             relative_attention_bias=relative_attention_bias,
+             rms_norm=rms_norm,
+             rotary_dim=rotary_dim,
+             rotary_interleave=rotary_interleave,
+             rotary_scaling_type=rotary_scaling_type,
+             rotary_scaling_factor=rotary_scaling_factor,
+             rotary_base=rotary_base,
+             original_max_position_embeddings=original_max_position_embeddings,
+             max_position_embeddings=max_position_embeddings,
+             num_heads_kv=num_heads_kv,
+             head_dim=head_dim,
+             sliding_window=sliding_window,
+             qk_norm=qk_norm,
+         )
+
+         if with_encoder_attention:
+             self.attention = attention_spec.MultiHeadAttentionSpec(
+                 rms_norm=rms_norm,
+                 num_heads_kv=num_heads_kv,
+                 head_dim=head_dim,
+                 sliding_window=sliding_window,
+                 qk_norm=qk_norm,
+                 has_norm=external_pre_post_encoder_layers is False,
+             )
+
+         self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
+
+         if parallel_residual:
+             if shared_layer_norm:
+                 self.shared_layer_norm = common_spec.LayerNormSpec()
+             else:
+                 self.input_layer_norm = common_spec.LayerNormSpec()
+                 self.post_attention_layer_norm = common_spec.LayerNormSpec()
+
+             delattr(self.self_attention, "layer_norm")
+             delattr(self.ffn, "layer_norm")
+
+         if pre_post_layer_norm:
+             # Self-attention layer norms
+             self.input_layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
+             self.post_attention_layer_norm = common_spec.LayerNormSpec(
+                 rms_norm=rms_norm
+             )
+
+             if with_encoder_attention and external_pre_post_encoder_layers:
+                 self.external_post_encoder_attention_layer_norm = (
+                     common_spec.LayerNormSpec(rms_norm=rms_norm)
+                 )
+                 self.external_pre_encoder_attention_layer_norm = (
+                     common_spec.LayerNormSpec(rms_norm=rms_norm)
+                 )
+
+             # Feed-forward layer norms
+             self.pre_feedforward_layer_norm = common_spec.LayerNormSpec(
+                 rms_norm=rms_norm
+             )
+             self.post_feedforward_layer_norm = common_spec.LayerNormSpec(
+                 rms_norm=rms_norm
+             )
+
+             delattr(self.self_attention, "layer_norm")
+             delattr(self.ffn, "layer_norm")
+
+
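The layer specs above add or delete layer-norm attributes depending on the block style, which is what converters rely on when mapping weight names. A small sketch of that behaviour, assuming the package is installed (attribute defaults of MultiHeadAttentionSpec are as in the installed package):

from ctranslate2.specs import transformer_spec

# Default pre-norm decoder layer: each sublayer owns its input layer norm.
layer = transformer_spec.TransformerDecoderLayerSpec()
print(hasattr(layer.self_attention, "layer_norm"))  # True

# GPT-J style block: one shared norm, sublayer norms are removed.
parallel = transformer_spec.TransformerDecoderLayerSpec(
    with_encoder_attention=False,
    parallel_residual=True,
    shared_layer_norm=True,
)
print(hasattr(parallel, "shared_layer_norm"))  # True
print(hasattr(parallel.ffn, "layer_norm"))     # False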
+ class FeedForwardSpec(model_spec.LayerSpec):
+     def __init__(self, glu=False, rms_norm=False):
+         self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
+         self.linear_0 = common_spec.LinearSpec()
+         self.linear_1 = common_spec.LinearSpec()
+         if glu:
+             self.linear_0_noact = common_spec.LinearSpec()
+
+
+ class PositionEncoderSpec(model_spec.LayerSpec):
+     def __init__(self):
+         self.encodings = model_spec.OPTIONAL
+
+
+ class TransformerConfig(model_spec.SequenceToSequenceModelConfig):
+     """Configuration for Transformer models."""
+
+     def __init__(self, layer_norm_epsilon: Optional[float] = None, **kwargs):
+         """Initializes the configuration for Transformer models.
+
+         Args:
+             layer_norm_epsilon: The layer norm epsilon value.
+             **kwargs: Additional configuration.
+         """
+         super().__init__(layer_norm_epsilon=layer_norm_epsilon, **kwargs)
+
+
+ class TransformerSpec(model_spec.SequenceToSequenceModelSpec):
+     """Describes a Transformer model.
+
+     The specification is invariant to hidden dimensions but requires to
+     explicitly set the number of layers and attention heads.
+     """
+
+     def __init__(
+         self, encoder: TransformerEncoderSpec, decoder: TransformerDecoderSpec
+     ):
+         """Initializes a Transformer model specification.
+
+         Args:
+             encoder: The encoder specification.
+             decoder: The decoder specification.
+         """
+         if not isinstance(encoder, TransformerEncoderSpec):
+             raise TypeError("encoder argument must be a TransformerEncoderSpec")
+         if not isinstance(decoder, TransformerDecoderSpec):
+             raise TypeError("decoder argument must be a TransformerDecoderSpec")
+
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self._config.add_attribute(
+             "multi_query_attention", self.encoder.multi_query_attention
+         )
+
+     @classmethod
+     def from_config(
+         cls,
+         num_layers: Union[int, Tuple[int, int]],
+         num_heads: int,
+         with_relative_position: bool = False,
+         pre_norm: bool = True,
+         no_final_norm: bool = False,
+         activation: common_spec.Activation = common_spec.Activation.RELU,
+         alignment_layer: int = -1,
+         alignment_heads: int = 1,
+         num_source_embeddings: int = 1,
+         embeddings_merge: common_spec.EmbeddingsMerge = common_spec.EmbeddingsMerge.CONCAT,
+         layernorm_embedding: bool = False,
+         relative_attention_bias: bool = False,
+         ffn_glu: bool = False,
+         rms_norm: bool = False,
+         multi_query_attention: bool = False,
+     ):
+         """Creates a Transformer model specification.
+
+         Args:
+             num_layers: Number of encoder and decoder layers, or a 2-tuple if the
+                 number is different.
+             num_heads: Number of attention heads.
+             with_relative_position: Use relative position representations in the self-attention
+                 layers as described in https://arxiv.org/abs/1803.02155.
+             pre_norm: Enable the pre-norm Transformer architecture.
+             no_final_norm: Disable the final layer norm in the pre-norm architecture.
+             activation: Activation to apply in the feed-forward network.
+             alignment_layer: Layer index selected for alignment.
+             alignment_heads: Number of attention heads selected for alignment.
+             num_source_embeddings: Number of source embeddings.
+             embeddings_merge: When :obj:`num_source_embeddings` > 1, specify how the
+                 embeddings are merged.
+             layernorm_embedding: Apply layer normalization after the embedding layer.
+             relative_attention_bias: Use relative attention bias in the self-attention
+                 layers as described in the T5 paper https://arxiv.org/abs/1910.10683.
+             ffn_glu: Use gated linear units in the FFN layer as described in
+                 https://arxiv.org/abs/2002.05202.
+             rms_norm: Use the root mean square layer normalization.
+             multi_query_attention: Use multi-query attention.
+         """
+         if isinstance(num_layers, (list, tuple)):
+             num_encoder_layers, num_decoder_layers = num_layers
+         else:
+             num_encoder_layers, num_decoder_layers = num_layers, num_layers
+
+         encoder = TransformerEncoderSpec(
+             num_encoder_layers,
+             num_heads,
+             pre_norm=pre_norm,
+             no_final_norm=no_final_norm,
+             activation=activation,
+             num_source_embeddings=num_source_embeddings,
+             embeddings_merge=embeddings_merge,
+             layernorm_embedding=layernorm_embedding,
+             relative_position=with_relative_position,
+             relative_attention_bias=relative_attention_bias,
+             ffn_glu=ffn_glu,
+             rms_norm=rms_norm,
+             multi_query_attention=multi_query_attention,
+         )
+
+         decoder = TransformerDecoderSpec(
+             num_decoder_layers,
+             num_heads,
+             pre_norm=pre_norm,
+             no_final_norm=no_final_norm,
+             activation=activation,
+             layernorm_embedding=layernorm_embedding,
+             relative_position=with_relative_position,
+             relative_attention_bias=relative_attention_bias,
+             alignment_layer=alignment_layer,
+             alignment_heads=alignment_heads,
+             ffn_glu=ffn_glu,
+             rms_norm=rms_norm,
+             multi_query_attention=multi_query_attention,
+         )
+
+         return cls(encoder, decoder)
+
+     @property
+     def name(self):
+         return "TransformerSpec"
+
+     @property
+     def revision(self):
+         return 7
+
+     def get_default_config(self):
+         return TransformerConfig()
+
+     def get_source_vocabulary_size(self):
+         return [spec.weight.shape[0] for spec in self.encoder.embeddings]
+
+     def get_target_vocabulary_size(self):
+         return self.decoder.embeddings.weight.shape[0]
+
+
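`from_config` is the usual entry point for sequence-to-sequence converters. A minimal sketch of building a base Transformer specification this way, assuming the package is installed and with illustrative sizes:

from ctranslate2.specs import common_spec, transformer_spec

# Classic base Transformer: 6 encoder and 6 decoder layers, 8 heads.
spec = transformer_spec.TransformerSpec.from_config(
    num_layers=(6, 6),  # or a single int when both sides match
    num_heads=8,
    pre_norm=True,
    activation=common_spec.Activation.RELU,
)
print(spec.name, spec.revision)  # TransformerSpec 7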
+ class TransformerDecoderModelConfig(model_spec.LanguageModelConfig):
+     """Configuration for Transformer decoder models."""
+
+     def __init__(self, layer_norm_epsilon: Optional[float] = None, **kwargs):
+         """Initializes the configuration for Transformer decoder models.
+
+         Args:
+             layer_norm_epsilon: The layer norm epsilon value.
+             **kwargs: Additional configuration.
+         """
+         super().__init__(layer_norm_epsilon=layer_norm_epsilon, **kwargs)
+
+
+ class TransformerDecoderModelSpec(model_spec.LanguageModelSpec):
+     """Describes a Transformer decoder model (e.g. GPT-2)."""
+
+     def __init__(self, decoder: TransformerDecoderSpec):
+         """Initializes a Transformer decoder model specification.
+
+         Args:
+             decoder: The decoder specification.
+         """
+         if not isinstance(decoder, TransformerDecoderSpec):
+             raise TypeError("decoder argument must be a TransformerDecoderSpec")
+
+         super().__init__()
+         self.decoder = decoder
+         for key, value in self.decoder.config.items():
+             self._config.add_attribute(key, value)
+
+     @classmethod
+     def from_config(
+         cls,
+         num_layers: int,
+         num_heads: int,
+         pre_norm: bool = True,
+         activation: common_spec.Activation = common_spec.Activation.RELU,
+         layernorm_embedding: bool = False,
+         no_final_norm: bool = False,
+         project_in_out: bool = False,
+         with_relative_position: bool = False,
+         ffn_glu: bool = False,
+         rms_norm: bool = False,
+         alibi: bool = False,
+         alibi_use_positive_positions: bool = False,
+         scale_alibi: bool = False,
+         rotary_dim: Optional[int] = None,
+         rotary_interleave: bool = True,
+         rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
+         rotary_scaling_factor: float = 1,
+         rotary_base: float = 10000,
+         original_max_position_embeddings: int = 0,
+         max_position_embeddings: int = 0,
+         parallel_residual: bool = False,
+         shared_layer_norm: bool = False,
+         pre_post_layer_norm: bool = False,
+         multi_query_attention: bool = False,
+         num_heads_kv: Optional[int] = None,
+         head_dim: Optional[int] = None,
+         sliding_window: Optional[int] = None,
+         quant_type: Optional[common_spec.Quantization] = None,
+         quant_group_size: Optional[int] = None,
+         quant_bits: Optional[int] = None,
+         qk_norm: bool = False,
+     ):
+         """Creates a Transformer decoder model specification.
+
+         Args:
+             num_layers: Number of decoder layers.
+             num_heads: Number of attention heads.
+             pre_norm: Enable the pre-norm Transformer architecture.
+             activation: Activation to apply in the feed-forward network.
+             layernorm_embedding: Apply layer normalization after the embedding layer.
+             no_final_norm: Do not apply layer normalization after the last decoder block.
+             project_in_out: Add a linear layer after the embedding layer and another one
+                 before the final output projection.
+             with_relative_position: Enable relative position representations modules.
+             ffn_glu: Use gated linear units in the FFN layers as described in
+                 https://arxiv.org/abs/2002.05202.
+             rms_norm: Use the root mean square layer normalization.
+             alibi: Use attention with linear biases.
+             alibi_use_positive_positions: Use positive positions in the ALiBi definition.
+             scale_alibi: Apply the dot product scale factor to ALiBi.
+             rotary_dim: Apply rotary embeddings to these first N dimensions. If 0, rotary
+                 embeddings are applied to all dimensions.
+             rotary_interleave: Interleave the head dimensions when rotary embeddings are applied.
+                 Otherwise the head dimensions are sliced in half.
+             rotary_scaling_type: Type of RoPE scaling.
+             rotary_scaling_factor: Factor used in the RoPE scaling.
+             rotary_base: The base period of the rotary embeddings.
+             original_max_position_embeddings: The original max position embeddings
+                 for Su rope embeddings
+             max_position_embeddings: The max position embeddings for Su rope embeddings
+             parallel_residual: Use parallel residual connections in each layer block, as used
+                 by the GPT-J and GPT-NeoX models.
+             shared_layer_norm: When using parallel residual, share the input and post
+                 attention layer norms.
+             pre_post_layer_norm: add post layer norm for each pre norm layer
+             multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
+             num_heads_kv: Number of attention heads for the key and value.
+             head_dim: Number of head
+             sliding_window: max sequence length to retain KV cache
+             quant_type: quantization type used (like awq... for lower bit quantization)
+             quant_group_size: group size of the lower bit quantization
+             quant_bits: number of bit of the quantization (ex: 4bit)
+         """
+         decoder = TransformerDecoderSpec(
+             num_layers,
+             num_heads,
+             pre_norm=pre_norm,
+             activation=activation,
+             layernorm_embedding=layernorm_embedding,
+             with_encoder_attention=False,
+             no_final_norm=no_final_norm,
+             project_in_out=project_in_out,
+             relative_position=with_relative_position,
+             ffn_glu=ffn_glu,
+             rms_norm=rms_norm,
+             alibi=alibi,
+             alibi_use_positive_positions=alibi_use_positive_positions,
+             scale_alibi=scale_alibi,
+             rotary_dim=rotary_dim,
+             rotary_interleave=rotary_interleave,
+             rotary_scaling_type=rotary_scaling_type,
+             rotary_scaling_factor=rotary_scaling_factor,
+             rotary_base=rotary_base,
+             original_max_position_embeddings=original_max_position_embeddings,
+             max_position_embeddings=max_position_embeddings,
+             parallel_residual=parallel_residual,
+             shared_layer_norm=shared_layer_norm,
+             pre_post_layer_norm=pre_post_layer_norm,
+             multi_query_attention=multi_query_attention,
+             num_heads_kv=num_heads_kv,
+             head_dim=head_dim,
+             sliding_window=sliding_window,
+             quant_type=quant_type,
+             quant_group_size=quant_group_size,
+             quant_bits=quant_bits,
+             qk_norm=qk_norm,
+         )
+
+         return cls(decoder)
+
+     @property
+     def name(self):
+         return "TransformerDecoderSpec"
+
+     @property
+     def revision(self):
+         return 8
+
+     def get_default_config(self):
+         return TransformerDecoderModelConfig()
+
+     def get_vocabulary_size(self):
+         return self.decoder.embeddings.weight.shape[0]
+
+
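Decoder-only converters (GPT-style and Llama-style models) typically go through `from_config` above. A sketch with illustrative values, assuming the package is installed and using only parameters from the signature shown:

from ctranslate2.specs import common_spec, transformer_spec

# Hypothetical Llama-style decoder: RMSNorm, gated FFN, rotary embeddings,
# and grouped-query attention via num_heads_kv.
spec = transformer_spec.TransformerDecoderModelSpec.from_config(
    num_layers=32,
    num_heads=32,
    pre_norm=True,
    activation=common_spec.Activation.SWISH,
    ffn_glu=True,
    rms_norm=True,
    rotary_dim=0,
    rotary_interleave=False,
    num_heads_kv=8,
)
print(spec.name, spec.revision)  # TransformerDecoderSpec 8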
+ class TransformerEncoderModelConfig(model_spec.LanguageModelConfig):
+     """Configuration for Transformer encoder models."""
+
+     def __init__(self, layer_norm_epsilon: Optional[float] = None, **kwargs):
+         """Initializes the configuration for Transformer encoder models.
+
+         Args:
+             layer_norm_epsilon: The layer norm epsilon value.
+             **kwargs: Additional configuration.
+         """
+         super().__init__(layer_norm_epsilon=layer_norm_epsilon, **kwargs)
+
+
+ class TransformerEncoderModelSpec(model_spec.LanguageModelSpec):
+     """Describes a Transformer encoder model (e.g. BERT)."""
+
+     def __init__(
+         self,
+         encoder: TransformerEncoderSpec,
+         pooling_layer: bool = False,
+         pooling_activation: common_spec.Activation = common_spec.Activation.Tanh,
+     ):
+         """Initializes a Transformer encoder model specification.
+
+         Args:
+             encoder: The encoder specification.
+             pooling_layer: Add the pooling layer.
+             pooling_activation: The activation to apply after the pooling layer.
+         """
+         if not isinstance(encoder, TransformerEncoderSpec):
+             raise TypeError("encoder argument must be a TransformerEncoderSpec")
+
+         super().__init__()
+         self.encoder = encoder
+         self._config.add_attribute(
+             "multi_query_attention", self.encoder.multi_query_attention
+         )
+
+         if pooling_layer:
+             self.pooler_dense = common_spec.LinearSpec()
+             self.pooler_activation = np.dtype("int8").type(pooling_activation)
+
+     @property
+     def name(self):
+         return "TransformerEncoderSpec"
+
+     @property
+     def revision(self):
+         return 1
+
+     def get_default_config(self):
+         return TransformerEncoderModelConfig()
+
+     def get_vocabulary_size(self):
+         return self.encoder.embeddings[0].weight.shape[0]
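Finally, a BERT-style encoder model is assembled from an encoder spec plus the optional pooling head. A minimal sketch, assuming the package is installed and with illustrative sizes:

from ctranslate2.specs import common_spec, transformer_spec

# Hypothetical BERT-base-like encoder with a Tanh pooling head.
encoder = transformer_spec.TransformerEncoderSpec(
    num_layers=12,
    num_heads=12,
    pre_norm=False,
    layernorm_embedding=True,
    activation=common_spec.Activation.GELU,
)
spec = transformer_spec.TransformerEncoderModelSpec(
    encoder,
    pooling_layer=True,
    pooling_activation=common_spec.Activation.Tanh,
)
print(spec.name, spec.revision)  # TransformerEncoderSpec 1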