nmn 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1650 @@
1
+ # Install required packages:
2
+ !pip install -Uq tiktoken grain matplotlib datasets
3
+ # pip install tensorflow-cpu wandb
4
+ # pip install --upgrade jax jaxlib flax grain
5
+ # pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
6
+
7
+ import jax
8
+ jax.devices()
9
+
10
+ # Download TinyStories dataset (can be replaced with HuggingFace datasets)
11
+ # wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
12
+
13
+ import jax
14
+ import jax.numpy as jnp
15
+
16
+ from jax.sharding import Mesh, PartitionSpec as P, NamedSharding # For data and model parallelism (explained in more detail later)
17
+ from jax.experimental import mesh_utils
18
+
19
+ import flax.nnx as nnx
20
+ import optax
21
+
22
+ from dataclasses import dataclass
23
+ import grain.python as pygrain
24
+ import pandas as pd
25
+ import tiktoken
26
+ import time
27
+ from typing import Optional, Dict, List, Union
28
+ import warnings
29
+
30
+ # Hugging Face datasets integration
31
+ try:
32
+ from datasets import load_dataset, DatasetDict
33
+ import datasets
34
+ HF_DATASETS_AVAILABLE = True
35
+ except ImportError:
36
+ HF_DATASETS_AVAILABLE = False
37
+ warnings.warn("Hugging Face datasets not available. Install with: pip install datasets")
38
+
39
+ # Create a `Mesh` object representing TPU device arrangement.
40
+ mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
41
+
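+ # NOTE (editor's sketch, not part of the original file): the (4, 2) mesh above assumes
+ # exactly 8 accelerator devices (4 along 'batch', 2 along 'model'). On a host with a
+ # different device count, a compatible mesh could be built along a single axis instead, e.g.:
+ # mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(), 1)), ('batch', 'model'))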
42
+ tokenizer = tiktoken.get_encoding("gpt2")
43
+
44
+ # Dataset Configuration
45
+ @dataclass
46
+ class DatasetConfig:
47
+ """Configuration for dataset loading and preprocessing"""
48
+ name: str = "roneneldan/TinyStories" # Default dataset
49
+ subset: Optional[str] = None # For datasets with multiple subsets
50
+ split: str = "train"
51
+ text_column: str = "text"
52
+ streaming: bool = False # Use streaming for large datasets
53
+ cache_dir: Optional[str] = None
54
+ trust_remote_code: bool = False
55
+
56
+ # Text preprocessing options
57
+ separator: str = "<|endoftext|>" # Token to separate documents
58
+ min_length: int = 10 # Minimum text length to include
59
+ max_length: Optional[int] = None # Maximum text length before truncation
60
+
61
+ # File-based dataset options (for local files)
62
+ file_path: Optional[str] = None
63
+ file_type: str = "txt" # txt, json, csv, parquet
64
+
65
+ # Predefined dataset configurations
66
+ DATASET_CONFIGS = {
67
+ "tinystories": DatasetConfig(
68
+ name="roneneldan/TinyStories",
69
+ text_column="text",
70
+ separator="<|endoftext|>"
71
+ ),
72
+ "wikitext": DatasetConfig(
73
+ name="Salesforce/wikitext",
74
+ subset="wikitext-2-raw-v1",
75
+ text_column="text",
76
+ separator="\n\n"
77
+ ),
78
+ "openwebtext": DatasetConfig(
79
+ name="Skylion007/openwebtext",
80
+ text_column="text",
81
+ streaming=True, # Large dataset, use streaming
82
+ separator="<|endoftext|>"
83
+ ),
84
+ "bookscorpus": DatasetConfig(
85
+ name="bookcorpus/bookcorpus",
86
+ text_column="text",
87
+ trust_remote_code=True,
88
+ separator="<|endoftext|>"
89
+ ),
90
+ "c4": DatasetConfig(
91
+ name="allenai/c4",
92
+ subset="en",
93
+ text_column="text",
94
+ streaming=True,
95
+ separator="<|endoftext|>"
96
+ ),
97
+ "tiny_shakespeare": DatasetConfig(
98
+ name="tiny_shakespeare",
99
+ text_column="text",
100
+ separator="\n\n"
101
+ ),
102
+ "gutenberg": DatasetConfig(
103
+ name="sedthh/gutenberg_english",
104
+ text_column="text",
105
+ separator="<|endoftext|>"
106
+ ),
107
+ "pile": DatasetConfig(
108
+ name="EleutherAI/pile",
109
+ text_column="text",
110
+ streaming=True,
111
+ separator="<|endoftext|>"
112
+ ),
113
+ "common_crawl": DatasetConfig(
114
+ name="oscar",
115
+ subset="unshuffled_deduplicated_en",
116
+ text_column="text",
117
+ streaming=True,
118
+ separator="<|endoftext|>"
119
+ ),
120
+ "local_file": DatasetConfig(
121
+ name="local",
122
+ file_path="TinyStories-train.txt",
123
+ file_type="txt",
124
+ separator="<|endoftext|>"
125
+ ),
126
+ "fineweb": DatasetConfig(
127
+ name="VisionTheta/fineweb-100B",
128
+ file_type="txt",
129
+ streaming=True,
130
+ separator="<|endoftext|>"
131
+ ),
132
+
133
+ }
134
+
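+ # Example (editor's sketch): a new corpus can be registered by adding another entry to
+ # DATASET_CONFIGS. The file path and column name below are placeholders, not real data.
+ # DATASET_CONFIGS["my_corpus"] = DatasetConfig(
+ #     name="local",
+ #     file_path="my_corpus.json",
+ #     file_type="json",
+ #     text_column="text",
+ # )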
135
+ def load_huggingface_dataset(config: DatasetConfig) -> List[str]:
136
+ """Load dataset from Hugging Face Hub"""
137
+ if not HF_DATASETS_AVAILABLE:
138
+ raise ImportError("Hugging Face datasets not available. Install with: pip install datasets")
139
+
140
+ print(f"Loading dataset: {config.name}")
141
+ if config.subset:
142
+ print(f" Subset: {config.subset}")
143
+
144
+ try:
145
+ # Load dataset
146
+ load_kwargs = {
147
+ "path": config.name,
148
+ "split": config.split,
149
+ "streaming": config.streaming,
150
+ "trust_remote_code": config.trust_remote_code
151
+ }
152
+
153
+ if config.subset:
154
+ load_kwargs["name"] = config.subset
155
+ if config.cache_dir:
156
+ load_kwargs["cache_dir"] = config.cache_dir
157
+
158
+ dataset = load_dataset(**load_kwargs)
159
+
160
+ print(f"Dataset loaded successfully. Processing text...")
161
+
162
+ # Extract text
163
+ texts = []
164
+ count = 0
165
+ max_samples = 50000 if config.streaming else None # Limit for streaming datasets
166
+
167
+ for item in dataset:
168
+ if max_samples and count >= max_samples:
169
+ break
170
+
171
+ # Extract text from item
172
+ if config.text_column in item:
173
+ text = item[config.text_column]
174
+ elif isinstance(item, str):
175
+ text = item
176
+ else:
177
+ print(f"Warning: Text column '{config.text_column}' not found in item: {item.keys()}")
178
+ continue
179
+
180
+ # Filter by length
181
+ if len(text) < config.min_length:
182
+ continue
183
+ if config.max_length and len(text) > config.max_length:
184
+ text = text[:config.max_length]
185
+
186
+ texts.append(text)
187
+ count += 1
188
+
189
+ if count % 10000 == 0:
190
+ print(f" Processed {count} samples...")
191
+
192
+ print(f"Dataset processing complete. Total samples: {len(texts)}")
193
+ return texts
194
+
195
+ except Exception as e:
196
+ print(f"Error loading dataset {config.name}: {e}")
197
+ print("Falling back to local file if available...")
198
+ return []
199
+
200
+ def load_local_file(config: DatasetConfig) -> List[str]:
201
+ """Load dataset from local file"""
202
+ if not config.file_path:
203
+ raise ValueError("file_path must be specified for local datasets")
204
+
205
+ print(f"Loading local file: {config.file_path}")
206
+
207
+ try:
208
+ if config.file_type == "txt":
209
+ with open(config.file_path, 'r', encoding='utf-8') as f:
210
+ text = f.read()
211
+
212
+ # Split text by separator
213
+ if config.separator in text:
214
+ texts = text.split(config.separator)
215
+ texts = [t.strip() + config.separator for t in texts if t.strip()]
216
+ else:
217
+ # Split by paragraphs if no separator found
218
+ texts = [p.strip() for p in text.split('\n\n') if p.strip()]
219
+
220
+ elif config.file_type == "json":
221
+ import json
222
+ with open(config.file_path, 'r', encoding='utf-8') as f:
223
+ data = json.load(f)
224
+
225
+ if isinstance(data, list):
226
+ texts = [item[config.text_column] if isinstance(item, dict) else str(item) for item in data]
227
+ else:
228
+ texts = [data[config.text_column]] if isinstance(data, dict) else [str(data)]
229
+
230
+ elif config.file_type == "csv":
231
+ import csv
232
+ texts = []
233
+ with open(config.file_path, 'r', encoding='utf-8') as f:
234
+ reader = csv.DictReader(f)
235
+ for row in reader:
236
+ if config.text_column in row:
237
+ texts.append(row[config.text_column])
238
+
239
+ else:
240
+ raise ValueError(f"Unsupported file type: {config.file_type}")
241
+
242
+ # Filter by length
243
+ filtered_texts = []
244
+ for text in texts:
245
+ if len(text) >= config.min_length:
246
+ if config.max_length and len(text) > config.max_length:
247
+ text = text[:config.max_length]
248
+ filtered_texts.append(text)
249
+
250
+ print(f"Local file loaded successfully. Total samples: {len(filtered_texts)}")
251
+ return filtered_texts
252
+
253
+ except Exception as e:
254
+ print(f"Error loading local file: {e}")
255
+ return []
256
+
257
+ def get_dataset_info(dataset_name: str) -> Dict:
258
+ """Get information about available datasets"""
259
+ if dataset_name in DATASET_CONFIGS:
260
+ config = DATASET_CONFIGS[dataset_name]
261
+ return {
262
+ "name": config.name,
263
+ "subset": config.subset,
264
+ "text_column": config.text_column,
265
+ "separator": config.separator,
266
+ "streaming": config.streaming,
267
+ "description": f"Predefined configuration for {dataset_name}"
268
+ }
269
+ else:
270
+ return {"error": f"Dataset '{dataset_name}' not found in predefined configurations"}
271
+
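+ # Example (editor's sketch): get_dataset_info("wikitext") returns the predefined fields,
+ # e.g. {"name": "Salesforce/wikitext", "subset": "wikitext-2-raw-v1", "text_column": "text", ...}.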
272
+ def list_available_datasets() -> List[str]:
273
+ """List all available dataset configurations"""
274
+ return list(DATASET_CONFIGS.keys())
275
+
276
+ """Neural-Matter Network Definition"""
277
+
278
+ # Commented out IPython magic to ensure Python compatibility.
279
+ # Copyright 2024 The Flax Authors.
280
+ #
281
+ # Licensed under the Apache License, Version 2.0 (the "License");
282
+ # you may not use this file except in compliance with the License.
283
+ # You may obtain a copy of the License at
284
+ #
285
+ # http://www.apache.org/licenses/LICENSE-2.0
286
+ #
287
+ # Unless required by applicable law or agreed to in writing, software
288
+ # distributed under the License is distributed on an "AS IS" BASIS,
289
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
290
+ # See the License for the specific language governing permissions and
291
+ # limitations under the License.
292
+
293
+ """Attention core modules for Flax."""
294
+
295
+ from __future__ import annotations
296
+
297
+ import functools
298
+ from typing import Any, Callable, Optional
299
+ import typing as tp
300
+
301
+ import jax
302
+ import jax.numpy as jnp
303
+ from jax import lax, random
304
+
305
+ from flax import nnx
306
+ from flax.nnx import rnglib
307
+ from flax.nnx.module import Module, first_from
308
+ from flax.nnx.nn import initializers
309
+ from flax.nnx.nn.dtypes import promote_dtype
310
+ from flax.nnx.nn.linear import (
311
+ LinearGeneral,
312
+ default_kernel_init,
313
+ )
314
+ from flax.nnx.nn.normalization import LayerNorm
315
+ from flax.typing import (
316
+ Dtype,
317
+ Shape,
318
+ Initializer,
319
+ PrecisionLike,
320
+ DotGeneralT,
321
+ )
322
+
323
+
324
+
325
+ def yat_attention_weights(
326
+ query: Array,
327
+ key: Array,
328
+ bias: Optional[Array] = None,
329
+ mask: Optional[Array] = None,
330
+ broadcast_dropout: bool = True,
331
+ dropout_rng: Optional[Array] = None,
332
+ dropout_rate: float = 0.0,
333
+ deterministic: bool = False,
334
+ dtype: Optional[Dtype] = None,
335
+ precision: PrecisionLike = None,
336
+ module: Optional[Module] = None,
337
+ epsilon: float = 1e-5,
338
+ ):
339
+ """Computes attention weights using YatNMN distance-based calculation."""
340
+ query, key = promote_dtype((query, key), dtype=dtype)
341
+ dtype = query.dtype
342
+
343
+ assert query.ndim == key.ndim, 'q, k must have same rank.'
344
+ assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
345
+ assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
346
+ assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
347
+
348
+ # YatNMN-style attention calculation using the cleaner approach
349
+ # query shape: [..., q_length, num_heads, head_dim]
350
+ # key shape: [..., kv_length, num_heads, head_dim]
351
+
352
+ # Calculate dot product attention scores
353
+ attn = jnp.einsum('...qhd,...khd->...hqk', query, key, precision=precision)
354
+ squared_dot_product = jnp.square(attn)
355
+
356
+ # Calculate norms
357
+ q_norm = jnp.sum(jnp.square(query), axis=-1, keepdims=True) # [..., q_length, num_heads, 1]
358
+ k_norm = jnp.sum(jnp.square(key), axis=-1, keepdims=True) # [..., kv_length, num_heads, 1]
359
+ # Rearrange the norms to the attention layout [..., num_heads, q_length, kv_length] so that
+ # the pairwise sum ||q_i||^2 + ||k_j||^2 broadcasts over every (query, key) position pair.
+ q_norm = jnp.swapaxes(q_norm, -3, -2) # [..., num_heads, q_length, 1]
+ k_norm = jnp.swapaxes(jnp.swapaxes(k_norm, -3, -2), -2, -1) # [..., num_heads, 1, kv_length]
+ qk_norm_sum = q_norm + k_norm # [..., num_heads, q_length, kv_length]
+
+ # Calculate squared distances: ||q||² + ||k||² - 2*(q·k)²
+ squared_dist = qk_norm_sum - 2.0 * squared_dot_product
369
+
370
+ # YatNMN attention scores: (q·k)² / (squared_distance + ε)
371
+ attn_weights = squared_dot_product / (squared_dist + epsilon)
372
+
373
+ # apply attention bias: masking, dropout, proximity bias, etc.
374
+ if bias is not None:
375
+ attn_weights = attn_weights + bias
376
+ # apply attention mask
377
+ if mask is not None:
378
+ big_neg = jnp.finfo(dtype).min
379
+ attn_weights = jnp.where(mask, attn_weights, big_neg)
380
+
381
+ # normalize the attention weights
382
+ attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
383
+
384
+ if module:
385
+ module.sow(nnx.Intermediate, 'attention_weights', attn_weights)
386
+
387
+ # apply attention dropout
388
+ if not deterministic and dropout_rate > 0.0:
389
+ keep_prob = 1.0 - dropout_rate
390
+ if broadcast_dropout:
391
+ # dropout is broadcast across the batch + head dimensions
392
+ dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
393
+ keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
394
+ else:
395
+ keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
396
+ multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
397
+ attn_weights = attn_weights * multiplier
398
+
399
+ return attn_weights
400
+
401
+
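+ # Example (editor's sketch): with batch 2, 16 query/key positions, 4 heads and head_dim 8,
+ # the returned weights have the usual attention layout [batch, heads, q_length, kv_length]:
+ # q = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 4, 8))
+ # k = jax.random.normal(jax.random.PRNGKey(1), (2, 16, 4, 8))
+ # w = yat_attention_weights(q, k) # -> shape (2, 4, 16, 16), rows normalized by softmax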
402
+ def yat_attention(
403
+ query: Array,
404
+ key: Array,
405
+ value: Array,
406
+ bias: Optional[Array] = None,
407
+ mask: Optional[Array] = None,
408
+ broadcast_dropout: bool = True,
409
+ dropout_rng: Optional[Array] = None,
410
+ dropout_rate: float = 0.0,
411
+ deterministic: bool = False,
412
+ dtype: Optional[Dtype] = None,
413
+ precision: PrecisionLike = None,
414
+ module: Optional[Module] = None,
415
+ epsilon: float = 1e-5,
416
+ ):
417
+ """Computes attention using YatNMN distance-based calculation."""
418
+ query, key, value = promote_dtype((query, key, value), dtype=dtype)
419
+ dtype = query.dtype
420
+ assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
421
+ assert (
422
+ query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
423
+ ), 'q, k, v batch dims must match.'
424
+ assert (
425
+ query.shape[-2] == key.shape[-2] == value.shape[-2]
426
+ ), 'q, k, v num_heads must match.'
427
+ assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
428
+
429
+ # compute attention weights using YatNMN
430
+ attn_weights = yat_attention_weights(
431
+ query,
432
+ key,
433
+ bias,
434
+ mask,
435
+ broadcast_dropout,
436
+ dropout_rng,
437
+ dropout_rate,
438
+ deterministic,
439
+ dtype,
440
+ precision,
441
+ module,
442
+ epsilon,
443
+ )
444
+
445
+ # return weighted sum over values for each query position
446
+ return jnp.einsum(
447
+ '...hqk,...khd->...qhd', attn_weights, value, precision=precision
448
+ )
449
+
450
+ Array = jax.Array # type alias used in the annotations above (resolved lazily thanks to `from __future__ import annotations`)
451
+
452
+ # Add YatNMN class implementation
453
+ default_bias_init = initializers.zeros_init()
454
+ default_alpha_init = initializers.ones_init()
455
+
456
+ class YatNMN(Module):
457
+ """A linear transformation with custom distance-based computation."""
458
+
459
+ def __init__(
460
+ self,
461
+ in_features: int,
462
+ out_features: int,
463
+ *,
464
+ use_bias: bool = True,
465
+ use_alpha: bool = True,
466
+ dtype: Optional[Dtype] = None,
467
+ param_dtype: Dtype = jnp.float32,
468
+ precision: PrecisionLike = None,
469
+ kernel_init: Initializer = default_kernel_init,
470
+ bias_init: Initializer = default_bias_init,
471
+ alpha_init: Initializer = default_alpha_init,
472
+ dot_general: DotGeneralT = lax.dot_general,
473
+ rngs: rnglib.Rngs,
474
+ epsilon: float = 1e-5,
475
+ ):
476
+
477
+ kernel_key = rngs.params()
478
+ self.kernel = nnx.Param(
479
+ kernel_init(kernel_key, (in_features, out_features), param_dtype)
480
+ )
481
+ self.bias: nnx.Param[jax.Array] | None
482
+ if use_bias:
483
+ bias_key = rngs.params()
484
+ self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
485
+ else:
486
+ self.bias = None
487
+
488
+ self.alpha: nnx.Param[jax.Array] | None
489
+ if use_alpha:
490
+ alpha_key = rngs.params()
491
+ self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
492
+ else:
493
+ self.alpha = None
494
+
495
+ self.in_features = in_features
496
+ self.out_features = out_features
497
+ self.use_bias = use_bias
498
+ self.use_alpha = use_alpha
499
+ self.dtype = dtype
500
+ self.param_dtype = param_dtype
501
+ self.precision = precision
502
+ self.kernel_init = kernel_init
503
+ self.bias_init = bias_init
504
+ self.dot_general = dot_general
505
+ self.epsilon = epsilon
506
+
507
+ def __call__(self, inputs: Array) -> Array:
508
+ """Applies YatNMN transformation to inputs."""
509
+ kernel = self.kernel.value
510
+ bias = self.bias.value if self.bias is not None else None
511
+ alpha = self.alpha.value if self.alpha is not None else None
512
+
513
+ y = self.dot_general(
514
+ inputs,
515
+ kernel,
516
+ (((inputs.ndim - 1,), (0,)), ((), ())),
517
+ precision=self.precision,
518
+ )
519
+
520
+ inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
521
+ kernel_squared_sum = jnp.sum(kernel**2, axis=0, keepdims=True)
522
+ distances = inputs_squared_sum + kernel_squared_sum - 2 * y
523
+
524
+ # Element-wise operation
525
+ y = y ** 2 / (distances + self.epsilon)
526
+
527
+ if bias is not None:
528
+ y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
529
+
530
+ if alpha is not None:
531
+ scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
532
+ y = y * scale
533
+
534
+ return y
535
+
536
+
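+ # Example (editor's sketch): a YatNMN layer maps (..., in_features) -> (..., out_features)
+ # via y = (x·w)^2 / (||x - w||^2 + eps), optionally scaled by the alpha term above:
+ # layer = YatNMN(in_features=8, out_features=16, rngs=nnx.Rngs(0))
+ # out = layer(jnp.ones((2, 8))) # -> shape (2, 16)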
537
+ def dot_product_attention_weights(
538
+ query: Array,
539
+ key: Array,
540
+ bias: Optional[Array] = None,
541
+ mask: Optional[Array] = None,
542
+ broadcast_dropout: bool = True,
543
+ dropout_rng: Optional[Array] = None,
544
+ dropout_rate: float = 0.0,
545
+ deterministic: bool = False,
546
+ dtype: Optional[Dtype] = None,
547
+ precision: PrecisionLike = None,
548
+ module: Optional[Module] = None,
549
+ ):
550
+ """Computes dot-product attention weights given query and key.
551
+
552
+ Used by :func:`dot_product_attention`, which is what you'll most likely use.
553
+ But if you want access to the attention weights for introspection, then
554
+ you can directly call this function and call einsum yourself.
555
+
556
+ Args:
557
+ query: queries for calculating attention with shape of `[batch..., q_length,
558
+ num_heads, qk_depth_per_head]`.
559
+ key: keys for calculating attention with shape of `[batch..., kv_length,
560
+ num_heads, qk_depth_per_head]`.
561
+ bias: bias for the attention weights. This should be broadcastable to the
562
+ shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
563
+ incorporating causal masks, padding masks, proximity bias, etc.
564
+ mask: mask for the attention weights. This should be broadcastable to the
565
+ shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
566
+ incorporating causal masks. Attention weights are masked out if their
567
+ corresponding mask value is `False`.
568
+ broadcast_dropout: bool: use a broadcasted dropout along batch dims.
569
+ dropout_rng: JAX PRNGKey: to be used for dropout
570
+ dropout_rate: dropout rate
571
+ deterministic: bool, deterministic or not (to apply dropout)
572
+ dtype: the dtype of the computation (default: infer from inputs and params)
573
+ precision: numerical precision of the computation see `jax.lax.Precision`
574
+ for details.
575
+ module: the Module that will sow the attention weights into the
576
+ ``nnx.Intermediate`` collection. If ``module`` is None, the attention
577
+ weights will not be sowed.
578
+
579
+ Returns:
580
+ Output of shape `[batch..., num_heads, q_length, kv_length]`.
581
+ """
582
+ query, key = promote_dtype((query, key), dtype=dtype) # type: ignore[bad-unpacking]
583
+ dtype = query.dtype
584
+
585
+ assert query.ndim == key.ndim, 'q, k must have same rank.'
586
+ assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
587
+ assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
588
+ assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
589
+
590
+ # calculate attention matrix
591
+ depth = query.shape[-1]
592
+ query = query / jnp.sqrt(depth).astype(dtype)
593
+ # attn weight shape is (batch..., num_heads, q_length, kv_length)
594
+ attn_weights = jnp.einsum(
595
+ '...qhd,...khd->...hqk', query, key, precision=precision
596
+ )
597
+
598
+ # apply attention bias: masking, dropout, proximity bias, etc.
599
+ if bias is not None:
600
+ attn_weights = attn_weights + bias
601
+ # apply attention mask
602
+ if mask is not None:
603
+ big_neg = jnp.finfo(dtype).min
604
+ attn_weights = jnp.where(mask, attn_weights, big_neg)
605
+
606
+ # normalize the attention weights
607
+ attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
608
+
609
+ if module:
610
+ module.sow(nnx.Intermediate, 'attention_weights', attn_weights)
611
+
612
+ # apply attention dropout
613
+ if not deterministic and dropout_rate > 0.0:
614
+ keep_prob = 1.0 - dropout_rate
615
+ if broadcast_dropout:
616
+ # dropout is broadcast across the batch + head dimensions
617
+ dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
618
+ keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) # type: ignore
619
+ else:
620
+ keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) # type: ignore
621
+ multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
622
+ attn_weights = attn_weights * multiplier
623
+
624
+ return attn_weights
625
+
626
+
627
+ def dot_product_attention(
628
+ query: Array,
629
+ key: Array,
630
+ value: Array,
631
+ bias: Optional[Array] = None,
632
+ mask: Optional[Array] = None,
633
+ broadcast_dropout: bool = True,
634
+ dropout_rng: Optional[Array] = None,
635
+ dropout_rate: float = 0.0,
636
+ deterministic: bool = False,
637
+ dtype: Optional[Dtype] = None,
638
+ precision: PrecisionLike = None,
639
+ module: Optional[Module] = None,
640
+ ):
641
+ """Computes dot-product attention given query, key, and value.
642
+
643
+ This is the core function for applying attention based on
644
+ https://arxiv.org/abs/1706.03762. It calculates the attention weights given
645
+ query and key and combines the values using the attention weights.
646
+
647
+ .. note::
648
+ ``query``, ``key``, ``value`` needn't have any batch dimensions.
649
+
650
+ Args:
651
+ query: queries for calculating attention with shape of ``[batch..., q_length,
652
+ num_heads, qk_depth_per_head]``.
653
+ key: keys for calculating attention with shape of ``[batch..., kv_length,
654
+ num_heads, qk_depth_per_head]``.
655
+ value: values to be used in attention with shape of ``[batch..., kv_length,
656
+ num_heads, v_depth_per_head]``.
657
+ bias: bias for the attention weights. This should be broadcastable to the
658
+ shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
659
+ incorporating causal masks, padding masks, proximity bias, etc.
660
+ mask: mask for the attention weights. This should be broadcastable to the
661
+ shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
662
+ incorporating causal masks. Attention weights are masked out if their
663
+ corresponding mask value is `False`.
664
+ broadcast_dropout: bool: use a broadcasted dropout along batch dims.
665
+ dropout_rng: JAX PRNGKey: to be used for dropout
666
+ dropout_rate: dropout rate
667
+ deterministic: bool, deterministic or not (to apply dropout)
668
+ dtype: the dtype of the computation (default: infer from inputs)
669
+ precision: numerical precision of the computation see `jax.lax.Precision`
670
+ for details.
671
+ module: the Module that will sow the attention weights into the
672
+ ``nnx.Intermediate`` collection. If ``module`` is None, the attention
673
+ weights will not be sowed.
674
+
675
+ Returns:
676
+ Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
677
+ """
678
+ query, key, value = promote_dtype((query, key, value), dtype=dtype) # type: ignore[bad-unpacking]
679
+ dtype = query.dtype
680
+ assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
681
+ assert (
682
+ query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
683
+ ), 'q, k, v batch dims must match.'
684
+ assert (
685
+ query.shape[-2] == key.shape[-2] == value.shape[-2]
686
+ ), 'q, k, v num_heads must match.'
687
+ assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
688
+
689
+ # compute attention weights
690
+ attn_weights = dot_product_attention_weights(
691
+ query,
692
+ key,
693
+ bias,
694
+ mask,
695
+ broadcast_dropout,
696
+ dropout_rng,
697
+ dropout_rate,
698
+ deterministic,
699
+ dtype,
700
+ precision,
701
+ module,
702
+ )
703
+
704
+ # return weighted sum over values for each query position
705
+ return jnp.einsum(
706
+ '...hqk,...khd->...qhd', attn_weights, value, precision=precision
707
+ )
708
+
709
+
710
+ class MultiHeadAttention(Module):
711
+ """Multi-head attention.
712
+
713
+ Example usage::
714
+
715
+ >>> import flax.linen as nn
716
+ >>> import jax
717
+
718
+ >>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16)
719
+ >>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
720
+ >>> shape = (4, 3, 2, 5)
721
+ >>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
722
+ >>> variables = layer.init(jax.random.key(0), q)
723
+
724
+ >>> # different inputs for inputs_q, inputs_k and inputs_v
725
+ >>> out = layer.apply(variables, q, k, v)
726
+ >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
727
+ >>> out = layer.apply(variables, q, k)
728
+ >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
729
+ >>> out = layer.apply(variables, q)
730
+
731
+ >>> attention_kwargs = dict(
732
+ ... num_heads=8,
733
+ ... qkv_features=16,
734
+ ... kernel_init=nn.initializers.ones,
735
+ ... bias_init=nn.initializers.zeros,
736
+ ... dropout_rate=0.5,
737
+ ... deterministic=False,
738
+ ... )
739
+ >>> class Module(nn.Module):
740
+ ... attention_kwargs: dict
741
+ ...
742
+ ... @nn.compact
743
+ ... def __call__(self, x, dropout_rng=None):
744
+ ... out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
745
+ ... out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
746
+ ... return out1, out2
747
+ >>> module = Module(attention_kwargs)
748
+ >>> variables = module.init({'params': key1, 'dropout': key2}, q)
749
+
750
+ >>> # out1 and out2 are different.
751
+ >>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
752
+ >>> # out3 and out4 are different.
753
+ >>> # out1 and out3 are different. out2 and out4 are different.
754
+ >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
755
+ >>> # out1 and out2 are the same.
756
+ >>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
757
+ >>> # out1 and out2 are the same as out3 and out4.
758
+ >>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
759
+ >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
760
+
761
+ Attributes:
762
+ num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
763
+ should be divisible by the number of heads.
764
+ dtype: the dtype of the computation (default: infer from inputs and params)
765
+ param_dtype: the dtype passed to parameter initializers (default: float32)
766
+ qkv_features: dimension of the key, query, and value.
767
+ out_features: dimension of the last projection
768
+ broadcast_dropout: bool: use a broadcasted dropout along batch dims.
769
+ dropout_rate: dropout rate
770
+ deterministic: if false, the attention weight is masked randomly using
771
+ dropout, whereas if true, the attention weights are deterministic.
772
+ precision: numerical precision of the computation see `jax.lax.Precision`
773
+ for details.
774
+ kernel_init: initializer for the kernel of the Dense layers.
775
+ out_kernel_init: optional initializer for the kernel of the output Dense layer,
776
+ if None, the kernel_init is used.
777
+ bias_init: initializer for the bias of the Dense layers.
778
+ out_bias_init: optional initializer for the bias of the output Dense layer,
779
+ if None, the bias_init is used.
780
+ use_bias: bool: whether pointwise QKVO dense transforms use bias.
781
+ attention_fn: `yat_attention` (default) or a compatible function. Accepts query,
782
+ key, value, and returns output of shape ``[bs, dim1, dim2, ..., dimN,
783
+ num_heads, value_channels]``.
784
+ decode: whether to prepare and use an autoregressive cache.
785
+ normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442).
786
+ """
787
+
788
+ def __init__(
789
+ self,
790
+ num_heads: int,
791
+ in_features: int,
792
+ qkv_features: int | None = None,
793
+ out_features: int | None = None,
794
+ *,
795
+ dtype: Dtype | None = None,
796
+ param_dtype: Dtype = jnp.float32,
797
+ broadcast_dropout: bool = True,
798
+ dropout_rate: float = 0.0,
799
+ deterministic: bool | None = None,
800
+ precision: PrecisionLike = None,
801
+ kernel_init: Initializer = default_kernel_init,
802
+ out_kernel_init: Initializer | None = None,
803
+ bias_init: Initializer = initializers.zeros_init(),
804
+ out_bias_init: Initializer | None = None,
805
+ use_bias: bool = True,
806
+ attention_fn: Callable[..., Array] = yat_attention,
807
+ decode: bool | None = None,
808
+ normalize_qk: bool = False,
809
+ # Deprecated, will be removed.
810
+ qkv_dot_general: DotGeneralT | None = None,
811
+ out_dot_general: DotGeneralT | None = None,
812
+ qkv_dot_general_cls: Any = None,
813
+ out_dot_general_cls: Any = None,
814
+ rngs: rnglib.Rngs,
815
+ epsilon: float = 1e-5,
816
+ ):
817
+ self.num_heads = num_heads
818
+ self.in_features = in_features
819
+ self.qkv_features = (
820
+ qkv_features if qkv_features is not None else in_features
821
+ )
822
+ self.out_features = (
823
+ out_features if out_features is not None else in_features
824
+ )
825
+ self.dtype = dtype
826
+ self.param_dtype = param_dtype
827
+ self.broadcast_dropout = broadcast_dropout
828
+ self.dropout_rate = dropout_rate
829
+ self.deterministic = deterministic
830
+ self.precision = precision
831
+ self.kernel_init = kernel_init
832
+ self.out_kernel_init = out_kernel_init
833
+ self.bias_init = bias_init
834
+ self.out_bias_init = out_bias_init
835
+ self.use_bias = use_bias
836
+ self.attention_fn = attention_fn
837
+ self.decode = decode
838
+ self.normalize_qk = normalize_qk
839
+ self.qkv_dot_general = qkv_dot_general
840
+ self.out_dot_general = out_dot_general
841
+ self.qkv_dot_general_cls = qkv_dot_general_cls
842
+ self.out_dot_general_cls = out_dot_general_cls
843
+ self.epsilon = epsilon
844
+
845
+ if self.qkv_features % self.num_heads != 0:
846
+ raise ValueError(
847
+ f'Memory dimension ({self.qkv_features}) must be divisible by '
848
+ f"'num_heads' heads ({self.num_heads})."
849
+ )
850
+
851
+ self.head_dim = self.qkv_features // self.num_heads
852
+
853
+ # Replace LinearGeneral with YatNMN for query, key, value projections
854
+ yat_linear = functools.partial(
855
+ YatNMN,
856
+ in_features=self.in_features,
857
+ out_features=self.qkv_features, # Output total features, will reshape later
858
+ dtype=self.dtype,
859
+ param_dtype=self.param_dtype,
860
+ kernel_init=self.kernel_init,
861
+ bias_init=self.bias_init,
862
+ use_bias=self.use_bias,
863
+ precision=self.precision,
864
+ epsilon=self.epsilon,
865
+ )
866
+
867
+ # project inputs_q to multi-headed q/k/v
868
+ # dimensions will be reshaped to [batch..., length, n_heads, n_features_per_head]
869
+ self.query = yat_linear(rngs=rngs)
870
+ self.key = yat_linear(rngs=rngs)
871
+ self.value = yat_linear(rngs=rngs)
872
+
873
+ self.query_ln: LayerNorm | None
874
+ self.key_ln: LayerNorm | None
875
+ if self.normalize_qk:
876
+ # Normalizing query and key projections stabilizes training with higher
877
+ # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
878
+ self.query_ln = LayerNorm(
879
+ self.head_dim,
880
+ use_bias=False,
881
+ dtype=self.dtype,
882
+ param_dtype=self.param_dtype,
883
+ rngs=rngs,
884
+ )
885
+ self.key_ln = LayerNorm(
886
+ self.head_dim,
887
+ use_bias=False,
888
+ dtype=self.dtype,
889
+ param_dtype=self.param_dtype,
890
+ rngs=rngs,
891
+ )
892
+ else:
893
+ self.query_ln = None
894
+ self.key_ln = None
895
+
896
+ # Note: unlike the stock Flax attention layer, there is no output projection (`self.out`).
897
+ self.rngs = rngs if dropout_rate > 0.0 else None
898
+
899
+ self.cached_key: nnx.Cache[Array] | None = None
900
+ self.cached_value: nnx.Cache[Array] | None = None
901
+ self.cache_index: nnx.Cache[Array] | None = None
902
+
903
+ def __call__(
904
+ self,
905
+ inputs_q: Array,
906
+ inputs_k: Array | None = None,
907
+ inputs_v: Array | None = None,
908
+ *,
909
+ mask: Array | None = None,
910
+ deterministic: bool | None = None,
911
+ rngs: rnglib.Rngs | None = None,
912
+ sow_weights: bool = False,
913
+ decode: bool | None = None,
914
+ ):
915
+ """Applies multi-head dot product attention on the input data.
916
+
917
+ Projects the inputs into multi-headed query, key, and value vectors,
918
+ applies dot-product attention and project the results to an output vector.
919
+
920
+ If both inputs_k and inputs_v are None, they will both copy the value of
921
+ inputs_q (self attention).
922
+ If only inputs_v is None, it will copy the value of inputs_k.
923
+
924
+ Args:
925
+ inputs_q: input queries of shape `[batch_sizes..., length, features]`.
926
+ inputs_k: key of shape `[batch_sizes..., length, features]`. If None,
927
+ inputs_k will copy the value of inputs_q.
928
+ inputs_v: values of shape `[batch_sizes..., length, features]`. If None,
929
+ inputs_v will copy the value of inputs_k.
930
+ mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
931
+ key/value_length]`. Attention weights are masked out if their
932
+ corresponding mask value is `False`.
933
+ deterministic: if false, the attention weight is masked randomly using
934
+ dropout, whereas if true, the attention weights are deterministic.
935
+ rngs: container for random number generators to generate the dropout
936
+ mask when `deterministic` is False. The `rngs` container should have a
937
+ `dropout` key.
938
+ sow_weights: if ``True``, the attention weights are sowed into the
939
+ 'intermediates' collection.
940
+
941
+ Returns:
942
+ output of shape `[batch_sizes..., length, features]`.
943
+ """
944
+ if rngs is None:
945
+ rngs = self.rngs
946
+
947
+ if inputs_k is None:
948
+ if inputs_v is not None:
949
+ raise ValueError(
950
+ '`inputs_k` cannot be None if `inputs_v` is not None. '
951
+ 'To have both `inputs_k` and `inputs_v` be the same value, pass in the '
952
+ 'value to `inputs_k` and leave `inputs_v` as None.'
953
+ )
954
+ inputs_k = inputs_q
955
+ if inputs_v is None:
956
+ inputs_v = inputs_k
957
+
958
+ if inputs_q.shape[-1] != self.in_features:
959
+ raise ValueError(
960
+ f'Incompatible input dimension, got {inputs_q.shape[-1]} '
961
+ f'but module expects {self.in_features}.'
962
+ )
963
+
964
+ # Apply YatNMN transformations and reshape to multi-head format
965
+ query = self.query(inputs_q)
966
+ key = self.key(inputs_k)
967
+ value = self.value(inputs_v)
968
+
969
+ # Reshape from [batch..., length, qkv_features] to [batch..., length, num_heads, head_dim]
970
+ query = query.reshape(query.shape[:-1] + (self.num_heads, self.head_dim))
971
+ key = key.reshape(key.shape[:-1] + (self.num_heads, self.head_dim))
972
+ value = value.reshape(value.shape[:-1] + (self.num_heads, self.head_dim))
973
+
974
+ if self.normalize_qk:
975
+ assert self.query_ln is not None and self.key_ln is not None
976
+ # Normalizing query and key projections stabilizes training with higher
977
+ # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
978
+ query = self.query_ln(query)
979
+ key = self.key_ln(key)
980
+
981
+ # During fast autoregressive decoding, we feed one position at a time,
982
+ # and cache the keys and values step by step.
983
+ decode = first_from(
984
+ decode,
985
+ self.decode,
986
+ error_msg="""No `decode` argument was provided to MultiHeadAttention
987
+ as either a __call__ argument, class attribute, or nnx.flag.""",
988
+ )
989
+
990
+ if decode:
991
+ if (
992
+ self.cached_key is None
993
+ or self.cached_value is None
994
+ or self.cache_index is None
995
+ ):
996
+ raise ValueError(
997
+ 'Autoregressive cache not initialized, call ``init_cache`` first.'
998
+ )
999
+ (
1000
+ *batch_dims,
1001
+ max_length,
1002
+ num_heads,
1003
+ depth_per_head,
1004
+ ) = self.cached_key.value.shape
1005
+ # shape check of cached keys against query input
1006
+ expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
1007
+ if expected_shape != query.shape:
1008
+ raise ValueError(
1009
+ 'Autoregressive cache shape error, '
1010
+ f'expected query shape {expected_shape} instead got {query.shape}.'
1012
+ )
1013
+ # update key, value caches with our new 1d spatial slices
1014
+ cur_index = self.cache_index.value
1015
+ zero = jnp.array(0, dtype=lax.dtype(cur_index.dtype))
1016
+ indices = (zero,) * len(batch_dims) + (cur_index, zero, zero)
1017
+ key = lax.dynamic_update_slice(self.cached_key.value, key, indices)
1018
+ value = lax.dynamic_update_slice(self.cached_value.value, value, indices)
1019
+ self.cached_key.value = key
1020
+ self.cached_value.value = value
1021
+ self.cache_index.value += 1
1022
+ # causal mask for cached decoder self-attention:
1023
+ # our single query position should only attend to those key
1024
+ # positions that have already been generated and cached,
1025
+ # not the remaining zero elements.
1026
+ mask = combine_masks(
1027
+ mask,
1028
+ jnp.broadcast_to(
1029
+ jnp.arange(max_length) <= cur_index,
1030
+ tuple(batch_dims) + (1, 1, max_length),
1031
+ ),
1032
+ )
1033
+
1034
+ if (
1035
+ self.dropout_rate > 0.0
1036
+ ): # Require `deterministic` only if using dropout.
1037
+ deterministic = first_from(
1038
+ deterministic,
1039
+ self.deterministic,
1040
+ error_msg="""No `deterministic` argument was provided to MultiHeadAttention
1041
+ as either a __call__ argument, class attribute, or nnx.flag.""",
1042
+ )
1043
+ if not deterministic:
1044
+ if rngs is None:
1045
+ raise ValueError(
1046
+ "'rngs' must be provided if 'dropout_rng' is not given."
1047
+ )
1048
+ dropout_rng = rngs.dropout()
1049
+ else:
1050
+ dropout_rng = None
1051
+ else:
1052
+ deterministic = True
1053
+ dropout_rng = None
1054
+
1055
+ # apply attention with epsilon parameter for YatNMN
1056
+ x = self.attention_fn(
1057
+ query,
1058
+ key,
1059
+ value,
1060
+ mask=mask,
1061
+ dropout_rng=dropout_rng,
1062
+ dropout_rate=self.dropout_rate,
1063
+ broadcast_dropout=self.broadcast_dropout,
1064
+ deterministic=deterministic,
1065
+ dtype=self.dtype,
1066
+ precision=self.precision,
1067
+ module=self if sow_weights else None,
1068
+ epsilon=self.epsilon, # Pass epsilon to yat_attention
1069
+ )
1070
+ # Reshape attention output back to original embedding dimension
1071
+ # from [batch..., length, num_heads, head_dim] to [batch..., length, qkv_features]
1072
+ x = x.reshape(x.shape[:-2] + (self.qkv_features,))
1073
+ return x
1074
+
1075
+ def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
1076
+ """Initializes cache for fast autoregressive decoding. When
1077
+ ``decode=True``, this method must be called first before performing
1078
+ forward inference.
1079
+
1080
+ Example usage::
1081
+
1082
+ >>> from flax import nnx
1083
+ >>> import jax.numpy as jnp
1084
+ ...
1085
+ >>> rngs = nnx.Rngs(42)
1086
+ ...
1087
+ >>> x = jnp.ones((1, 3))
1088
+ >>> model_nnx = nnx.MultiHeadAttention(
1089
+ ... num_heads=2,
1090
+ ... in_features=3,
1091
+ ... qkv_features=6,
1092
+ ... out_features=6,
1093
+ ... decode=True,
1094
+ ... rngs=rngs,
1095
+ ... )
1096
+ ...
1097
+ >>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized
1098
+ ...
1099
+ >>> model_nnx.init_cache(x.shape)
1100
+ >>> out_nnx = model_nnx(x)
1101
+ """
1102
+ cache_shape = (*input_shape[:-1], self.num_heads, self.head_dim)
1103
+ self.cached_key = nnx.Cache(jnp.zeros(cache_shape, dtype))
1104
+ self.cached_value = nnx.Cache(jnp.zeros(cache_shape, dtype))
1105
+ self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))
1106
+
1107
+
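+ # Example (editor's sketch) for the YAT attention module above (self-attention, no dropout):
+ # mha = MultiHeadAttention(num_heads=2, in_features=8, qkv_features=8,
+ #                          decode=False, rngs=nnx.Rngs(0))
+ # y = mha(jnp.ones((1, 5, 8)), decode=False) # -> shape (1, 5, 8); there is no output projection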
1108
+ # mask-making utility functions
1109
+
1110
+
1111
+ def make_attention_mask(
1112
+ query_input: Array,
1113
+ key_input: Array,
1114
+ pairwise_fn: Callable[..., Any] = jnp.multiply,
1115
+ extra_batch_dims: int = 0,
1116
+ dtype: Dtype = jnp.float32,
1117
+ ):
1118
+ """Mask-making helper for attention weights.
1119
+
1120
+ In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`), the
1121
+ attention weights will be `[batch..., heads, len_q, len_kv]` and this
1122
+ function will produce `[batch..., 1, len_q, len_kv]`.
1123
+
1124
+ Args:
1125
+ query_input: a batched, flat input of query_length size
1126
+ key_input: a batched, flat input of key_length size
1127
+ pairwise_fn: broadcasting elementwise comparison function
1128
+ extra_batch_dims: number of extra batch dims to add singleton axes for, none
1129
+ by default
1130
+ dtype: mask return dtype
1131
+
1132
+ Returns:
1133
+ A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention.
1134
+ """
1135
+ mask = pairwise_fn(
1136
+ jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2)
1137
+ )
1138
+ mask = jnp.expand_dims(mask, axis=-3)
1139
+ mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
1140
+ return mask.astype(dtype)
1141
+
1142
+
1143
+ def make_causal_mask(
1144
+ x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32
1145
+ ) -> Array:
1146
+ """Make a causal mask for self-attention.
1147
+
1148
+ In case of 1d inputs (i.e., `[batch..., len]`), the self-attention weights
1149
+ will be `[batch..., heads, len, len]` and this function will produce a
1150
+ causal mask of shape `[batch..., 1, len, len]`.
1151
+
1152
+ Args:
1153
+ x: input array of shape `[batch..., len]`
1154
+ extra_batch_dims: number of batch dims to add singleton axes for, none by
1155
+ default
1156
+ dtype: mask return dtype
1157
+
1158
+ Returns:
1159
+ A `[batch..., 1, len, len]` shaped causal mask for 1d attention.
1160
+ """
1161
+ idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
1162
+ return make_attention_mask(
1163
+ idxs,
1164
+ idxs,
1165
+ jnp.greater_equal,
1166
+ extra_batch_dims=extra_batch_dims,
1167
+ dtype=dtype,
1168
+ )
1169
+
1170
+
1171
+ def combine_masks(
1172
+ *masks: Optional[Array], dtype: Dtype = jnp.float32
1173
+ ) -> Array | None:
1174
+ """Combine attention masks.
1175
+
1176
+ Args:
1177
+ *masks: set of attention mask arguments to combine, some can be None.
1178
+ dtype: dtype for the returned mask.
1179
+
1180
+ Returns:
1181
+ Combined mask, reduced by logical and, returns None if no masks given.
1182
+ """
1183
+ masks_list = [m for m in masks if m is not None]
1184
+ if not masks_list:
1185
+ return None
1186
+ assert all(
1187
+ map(lambda x: x.ndim == masks_list[0].ndim, masks_list)
1188
+ ), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}'
1189
+ mask, *other_masks = masks_list
1190
+ for other_mask in other_masks:
1191
+ mask = jnp.logical_and(mask, other_mask)
1192
+ return mask.astype(dtype)
1193
+
1194
+
1195
+
1196
+ # Define a triangular mask for causal attention with `jax.numpy.tril` and `jax.numpy.ones`.
1197
+ def causal_attention_mask(seq_len):
1198
+ return jnp.tril(jnp.ones((seq_len, seq_len)))
1199
+
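+ # Example (editor's illustration): causal_attention_mask(4) is the lower-triangular matrix
+ # [[1, 0, 0, 0],
+ #  [1, 1, 0, 0],
+ #  [1, 1, 1, 0],
+ #  [1, 1, 1, 1]]
+ # so position i can only attend to positions j <= i.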
1200
+ class TransformerBlock(nnx.Module):
1201
+ """ A single Transformer block.
1202
+
1203
+ Each Transformer block processes input sequences via self-attention and feed-forward networks.
1204
+
1205
+ Args:
1206
+ embed_dim (int): Embedding dimensionality.
1207
+ num_heads (int): Number of attention heads.
1208
+ ff_dim (int): Dimensionality of the feed-forward network.
1209
+ rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys.
1210
+ rate (float): Dropout rate. Defaults to 0.1.
1211
+ """
1212
+ def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, *, rngs: nnx.Rngs, rate: float = 0.1):
1213
+ # Multi-Head Attention (MHA) using the YAT `MultiHeadAttention` module defined above.
1214
+ # Specifies tensor sharding (depending on the mesh configuration)
1215
+ # where we shard the weights across devices for parallel computation.
1216
+ self.mha = MultiHeadAttention(num_heads=num_heads,
1217
+ in_features=embed_dim,
1218
+ kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
1219
+ bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
1220
+ rngs=rngs)
1221
+ # The first dropout with `flax.nnx.Dropout`.
1222
+ self.dropout1 = nnx.Dropout(rate=rate)
1223
+ # The feed-forward transformation, implemented with the `YatNMN` layer defined above (note: `ff_dim` is currently unused; the block maps embed_dim -> embed_dim).
1224
+ self.nonlinear1 = YatNMN(in_features=embed_dim,
1225
+ out_features=embed_dim,
1226
+ kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
1227
+ bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
1228
+ alpha_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P(None, 'model'))),
1229
+ rngs=rngs)
1230
+ # The second dropout with `flax.nnx.Dropout`.
1231
+ self.dropout2 = nnx.Dropout(rate=rate)
1232
+
1233
+
1234
+ # Apply the Transformer block to the input sequence.
1235
+ def __call__(self, inputs, training: bool = False):
1236
+ input_shape = inputs.shape
1237
+ _, seq_len, _ = input_shape
1238
+
1239
+ # Instantiate the causal attention mask.
1240
+ mask = causal_attention_mask(seq_len)
1241
+
1242
+ # Apply Multi-Head Attention with the causal attention mask.
1243
+ attention_output = self.mha(
1244
+ inputs_q=inputs,
1245
+ mask=mask,
1246
+ decode=False
1247
+ )
1248
+ # Apply the first dropout.
1249
+ attention_output = self.dropout1(attention_output, deterministic=not training)
1250
+ # Add the first residual connection.
1251
+ out1 = inputs + attention_output
1252
+
1253
+ # The feed-forward network.
1254
+ # Apply the first linear transformation.
1255
+ ffn_output = self.nonlinear1(out1)
1256
+ # Apply the second dropout.
1257
+ ffn_output = self.dropout2(ffn_output, deterministic=not training)
1258
+ # Add the second residual connection and return the output of the Transformer block.
1259
+ return out1 + ffn_output
1260
+
1261
+ class TokenAndPositionEmbedding(nnx.Module):
1262
+ """ Combines token embeddings (words in an input sentence) with
1263
+ positional embeddings (the position of each word in a sentence).
1264
+
1265
+ Args:
1266
+ maxlen (int): Maximum sequence length.
1267
+ vocab_size (int): Vocabulary size.
1268
+ embed_dim (int): Embedding dimensionality.
1269
+ rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys.
1270
+ """
1271
+ def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, *, rngs: nnx.Rngs):
1272
+ # Initialize token embeddings (using `flax.nnx.Embed`).
1273
+ # Each unique word has an embedding vector.
1274
+ self.token_emb = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)
1275
+ # Initialize positional embeddings (using `flax.nnx.Embed`).
1276
+ self.pos_emb = nnx.Embed(num_embeddings=maxlen, features=embed_dim, rngs=rngs)
1277
+
1278
+ # Takes a token sequence (integers) and returns the combined token and positional embeddings.
1279
+ def __call__(self, x):
1280
+ # Generate a sequence of positions for the input tokens.
1281
+ positions = jnp.arange(0, x.shape[1])[None, :]
1282
+ # Look up the positional embeddings for each position in the input sequence.
1283
+ position_embedding = self.pos_emb(positions)
1284
+ # Look up the token embeddings for each token in the input sequence.
1285
+ token_embedding = self.token_emb(x)
1286
+ # Combine token and positional embeddings.
1287
+ return token_embedding + position_embedding
1288
+
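+ # Example (editor's sketch): for integer token ids of shape (batch, seq_len), the layer
+ # returns summed token + position embeddings of shape (batch, seq_len, embed_dim):
+ # emb = TokenAndPositionEmbedding(maxlen=16, vocab_size=100, embed_dim=32, rngs=nnx.Rngs(0))
+ # out = emb(jnp.zeros((2, 16), dtype=jnp.int32)) # -> shape (2, 16, 32)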
1289
+ class MiniGPT(nnx.Module):
1290
+ """ A miniGPT transformer model, inherits from `flax.nnx.Module`.
1291
+
1292
+ Args:
1293
+ maxlen (int): Maximum sequence length.
1294
+ vocab_size (int): Vocabulary size.
1295
+ embed_dim (int): Embedding dimensionality.
1296
+ num_heads (int): Number of attention heads.
1297
+ feed_forward_dim (int): Dimensionality of the feed-forward network.
1298
+ num_transformer_blocks (int): Number of transformer blocks. Each block contains attention and feed-forward networks.
1299
+ rngs (nnx.Rngs): A Flax NNX stream of JAX PRNG keys.
1300
+ """
1301
+ # Initialize miniGPT model components.
1302
+ def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, num_heads: int, feed_forward_dim: int, num_transformer_blocks: int, rngs: nnx.Rngs):
1303
+ # Initialize the `TokenAndPositionEmbedding` that combines token and positional embeddings.
1304
+ self.embedding_layer = TokenAndPositionEmbedding(
1305
+ maxlen, vocab_size, embed_dim, rngs=rngs
1306
+ )
1307
+ # Create a list of `TransformerBlock` instances.
1308
+ # Each block processes input sequences using attention and feed-forward networks.
1309
+ self.transformer_blocks = [TransformerBlock(
1310
+ embed_dim, num_heads, feed_forward_dim, rngs=rngs
1311
+ ) for _ in range(num_transformer_blocks)]
1312
+ # Initialize the output `YatNMN` layer producing logits over the vocabulary for next-token prediction.
1313
+ self.output_layer = YatNMN(in_features=embed_dim,
1314
+ out_features=vocab_size,
1315
+ kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
1316
+ bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P(None, 'model'))),
1317
+ alpha_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P(None, 'model'))),
1318
+ use_bias=False,
1319
+ rngs=rngs,
1320
+ )
1321
+
1322
+ def __call__(self, inputs, training: bool = False):
1323
+ # Pass the input tokens through the `embedding_layer` to get token embeddings.
1324
+ # Apply each transformer block sequentially to the embedded input, use the `training` flag for the behavior of `flax.nnx.Dropout`.
1325
+ x = self.embedding_layer(inputs)
1326
+ for transformer_block in self.transformer_blocks:
1327
+ x = transformer_block(x, training=training)
1328
+ # Pass the output of the transformer blocks through the output layer,
1329
+ # and obtain logits for each token in the vocabulary (for next token prediction).
1330
+ outputs = self.output_layer(x)
1331
+ return outputs
1332
+
1333
+ # Text generation.
1334
+ def generate_text(self, max_tokens: int, start_tokens: list[int], top_k: int = 10):
1335
+ # Sample the next token from a probability distribution based on
1336
+ # `logits` and `top_k` (top-k) sampling strategy.
1337
+ def sample_from(logits):
1338
+ logits, indices = jax.lax.top_k(logits, k=top_k)
1339
+ # Convert logits to probabilities (using `flax.nnx.softmax`).
1340
+ logits = nnx.softmax(logits)
1341
+ return jax.random.choice(jax.random.PRNGKey(0), indices, p=logits)
1342
+
1343
+ # Generate text one token at a time until the maximum token limit is reached (`maxlen`).
1344
+ def generate_step(start_tokens):
1345
+ pad_len = maxlen - len(start_tokens)
1346
+ # Index of the last token in the current sequence.
1347
+ sample_index = len(start_tokens) - 1
1348
+ # If the input is longer than `maxlen`, then truncate it.
1349
+ if pad_len < 0:
1350
+ x = jnp.array(start_tokens[:maxlen])
1351
+ sample_index = maxlen - 1
1352
+ # If the input is shorter than `maxlen`, then pad it (`pad_len`).
1353
+ elif pad_len > 0:
1354
+ x = jnp.array(start_tokens + [0] * pad_len)
1355
+ else:
1356
+ x = jnp.array(start_tokens)
1357
+
1358
+ # Add a batch dimension.
1359
+ x = x[None, :]
1360
+ logits = self(x)
1361
+ next_token = sample_from(logits[0][sample_index])
1362
+ return next_token
1363
+
1364
+ # Store generated tokens.
1365
+ generated = []
1366
+ # Generate tokens until the end-of-text token is encountered or the maximum token limit is reached.
1367
+ for _ in range(max_tokens):
1368
+ next_token = generate_step(start_tokens + generated)
1369
+ # Stop before emitting the '<|endoftext|>' token (stop word).
1370
+ if next_token == tokenizer.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]:
1371
+ # Stop text generation if the end-of-text token is encountered.
1372
+ break
1373
+ generated.append(int(next_token))
1374
+ # Decode the generated token IDs into text.
1375
+ return tokenizer.decode(start_tokens + generated)
1376
+
1377
+ # Creates the miniGPT model with 4 transformer blocks (hard-coded here; the `num_transformer_blocks = 3` defined below is not used).
1378
+ def create_model(rngs):
1379
+ return MiniGPT(maxlen, vocab_size, embed_dim, num_heads, feed_forward_dim, num_transformer_blocks=4, rngs=rngs)
1380
+
1381
+ vocab_size = tokenizer.n_vocab
1382
+ num_transformer_blocks = 3
1383
+ maxlen = 1024
1384
+ embed_dim = 512
1385
+ num_heads = 8
1386
+ feed_forward_dim = 512
1387
+ batch_size = 64 # You can set a bigger batch size if you use Kaggle's Cloud TPU.
1388
+ num_epochs = 5
1389
+
1390
+ """## Loading and preprocessing the data
1391
+
1392
+ Enhanced data loading with support for multiple Hugging Face datasets and local files.
1393
+ """
1394
+
1395
+ # Configuration for dataset selection
1396
+ CURRENT_DATASET = "fineweb" # Change this to use different datasets
1397
+ # Available options: tinystories, wikitext, openwebtext, bookscorpus, c4,
1398
+ # tiny_shakespeare, gutenberg, pile, common_crawl, local_file, fineweb
1399
+
1400
+ @dataclass
1401
+ class EnhancedTextDataset:
1402
+ """Enhanced TextDataset with better preprocessing and flexible tokenization"""
1403
+ data: list
1404
+ maxlen: int
1405
+ tokenizer: Any = tokenizer
1406
+ separator_token: str = "<|endoftext|>"
1407
+
1408
+ def __len__(self):
1409
+ return len(self.data)
1410
+
1411
+ def __getitem__(self, idx: int):
1412
+ text = self.data[idx]
1413
+ # Use Tiktoken for tokenization with proper handling of special tokens
1414
+ try:
1415
+ encoding = self.tokenizer.encode(
1416
+ text,
1417
+ allowed_special={self.separator_token}
1418
+ )[:self.maxlen]
1419
+ except Exception:
1420
+ # Fallback for texts without special tokens
1421
+ encoding = self.tokenizer.encode(text)[:self.maxlen]
1422
+
1423
+ # Pad to maxlen
1424
+ return encoding + [0] * (self.maxlen - len(encoding))
1425
+
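+ # Example (editor's sketch): each item is a fixed-length list of GPT-2 token ids,
+ # right-padded with 0 up to `maxlen`:
+ # ds = EnhancedTextDataset(data=["Once upon a time<|endoftext|>"], maxlen=8)
+ # ds[0] # -> list of 8 ints: the encoded text followed by zero padding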
1426
+ def load_and_preprocess_data_enhanced(
1427
+ dataset_name: str = CURRENT_DATASET,
1428
+ batch_size: int = batch_size,
1429
+ maxlen: int = maxlen,
1430
+ custom_config: Optional[DatasetConfig] = None
1431
+ ) -> pygrain.DataLoader:
1432
+ """
1433
+ Enhanced data loading function that supports multiple datasets
1434
+
1435
+ Args:
1436
+ dataset_name: Name of the dataset configuration to use
1437
+ batch_size: Batch size for data loading
1438
+ maxlen: Maximum sequence length
1439
+ custom_config: Custom dataset configuration (overrides dataset_name)
1440
+
1441
+ Returns:
1442
+ pygrain.DataLoader: Configured data loader
1443
+ """
1444
+
1445
+ print(f"=== Dataset Loading ===")
1446
+ print(f"Requested dataset: {dataset_name}")
1447
+
1448
+ # Use custom config if provided, otherwise get predefined config
1449
+ if custom_config:
1450
+ config = custom_config
1451
+ print("Using custom dataset configuration")
1452
+ elif dataset_name in DATASET_CONFIGS:
1453
+ config = DATASET_CONFIGS[dataset_name]
1454
+ print(f"Using predefined configuration for {dataset_name}")
1455
+ else:
1456
+ print(f"Dataset '{dataset_name}' not found. Available datasets:")
1457
+ for name in list_available_datasets():
1458
+ info = get_dataset_info(name)
1459
+ print(f" - {name}: {info.get('name', 'N/A')}")
1460
+
1461
+ # Fallback to local file
1462
+ print("Falling back to local file (TinyStories)")
1463
+ config = DATASET_CONFIGS["local_file"]
1464
+
1465
+ # Load the dataset
1466
+ texts = []
1467
+
1468
+ if config.name == "local" or config.file_path:
1469
+ # Load from local file
1470
+ texts = load_local_file(config)
1471
+ else:
1472
+ # Try to load from Hugging Face
1473
+ if HF_DATASETS_AVAILABLE:
1474
+ texts = load_huggingface_dataset(config)
1475
+
1476
+ # If HF loading failed or not available, try local fallback
1477
+ if not texts and dataset_name != "local_file":
1478
+ print("Attempting to load from local TinyStories file as fallback...")
1479
+ fallback_config = DATASET_CONFIGS["local_file"]
1480
+ texts = load_local_file(fallback_config)
1481
+
1482
+ if not texts:
1483
+ raise RuntimeError(f"Failed to load any data for dataset: {dataset_name}")
1484
+
1485
+ # Create enhanced dataset
1486
+ dataset = EnhancedTextDataset(
1487
+ data=texts,
1488
+ maxlen=maxlen,
1489
+ separator_token=config.separator
1490
+ )
1491
+
1492
+ print(f"Dataset created successfully:")
1493
+ print(f" - Total samples: {len(dataset)}")
1494
+ print(f" - Max sequence length: {maxlen}")
1495
+ print(f" - Separator token: {config.separator}")
1496
+
1497
+ # Create sampler and data loader
1498
+ sampler = pygrain.IndexSampler(
1499
+ len(dataset),
1500
+ shuffle=True, # Enable shuffling for better training
1501
+ seed=42,
1502
+ shard_options=pygrain.NoSharding(),
1503
+ num_epochs=num_epochs,
1504
+ )
1505
+
1506
+ dl = pygrain.DataLoader(
1507
+ data_source=dataset,
1508
+ sampler=sampler,
1509
+ operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
1510
+ )
1511
+
1512
+ print(f"Data loader created with batch size: {batch_size}")
1513
+ print("=== Dataset Loading Complete ===\n")
1514
+
1515
+ return dl
1516
+
1517
+ def switch_dataset(new_dataset: str) -> pygrain.DataLoader:
1518
+ """
1519
+ Utility function to quickly switch datasets
1520
+
1521
+ Args:
1522
+ new_dataset: Name of the new dataset to load
1523
+
1524
+ Returns:
1525
+ pygrain.DataLoader: New data loader
1526
+ """
1527
+ global CURRENT_DATASET
1528
+ CURRENT_DATASET = new_dataset
1529
+ return load_and_preprocess_data_enhanced(new_dataset, batch_size, maxlen)
1530
+
1531
+ def create_custom_dataset(
1532
+ name: str,
1533
+ subset: Optional[str] = None,
1534
+ text_column: str = "text",
1535
+ separator: str = "<|endoftext|>",
1536
+ streaming: bool = False,
1537
+ min_length: int = 10
1538
+ ) -> pygrain.DataLoader:
1539
+ """
1540
+ Create a data loader for a custom Hugging Face dataset
1541
+
1542
+ Args:
1543
+ name: Hugging Face dataset name (e.g., "username/dataset_name")
1544
+ subset: Dataset subset/configuration name
1545
+ text_column: Name of the text column in the dataset
1546
+ separator: Token to separate documents
1547
+ streaming: Whether to use streaming (for large datasets)
1548
+ min_length: Minimum text length to include
1549
+
1550
+ Returns:
1551
+ pygrain.DataLoader: Configured data loader for the custom dataset
1552
+ """
1553
+ config = DatasetConfig(
1554
+ name=name,
1555
+ subset=subset,
1556
+ text_column=text_column,
1557
+ separator=separator,
1558
+ streaming=streaming,
1559
+ min_length=min_length
1560
+ )
1561
+
1562
+ return load_and_preprocess_data_enhanced(
1563
+ dataset_name="custom",
1564
+ custom_config=config
1565
+ )
1566
+
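+ # Example (editor's sketch; the dataset name below is a placeholder, not a recommendation):
+ # custom_dl = create_custom_dataset(name="username/my_text_dataset", text_column="text",
+ #                                   streaming=True)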
1567
+ # Load the default dataset
1568
+ print("Loading default dataset...")
1569
+ text_dl = load_and_preprocess_data_enhanced()
1570
+
1571
+ # Print available datasets for reference
1572
+ print(f"Available predefined datasets: {list_available_datasets()}")
1573
+ print(f"Current dataset: {CURRENT_DATASET}")
1574
+
1575
+ """## Defining the loss function and training step function"""
1576
+
1577
+ # Defines the loss function using `optax.softmax_cross_entropy_with_integer_labels`.
1578
+ def loss_fn(model, batch):
1579
+ logits = model(batch[0])
1580
+ loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
1581
+ return loss, logits
1582
+
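+ # Here `batch[0]` and `batch[1]` hold the (batch, maxlen) input and shifted-target token ids,
+ # so `logits` has shape (batch, maxlen, vocab_size) and the cross-entropy is averaged over
+ # every position in every sequence.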
1583
+ # Define the training step with the `flax.nnx.jit` transformation decorator.
1584
+ @nnx.jit
1585
+ def train_step(model: MiniGPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
1586
+ grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
1587
+ (loss, logits), grads = grad_fn(model, batch)
1588
+ metrics.update(loss=loss, logits=logits, labels=batch[1])
1589
+ optimizer.update(grads)
1590
+
1591
+ model = create_model(rngs=nnx.Rngs(0))
1592
+ optimizer = nnx.Optimizer(model, optax.adam(1e-3))
1593
+ metrics = nnx.MultiMetric(
1594
+ loss=nnx.metrics.Average('loss'),
1595
+ )
1596
+ rng = jax.random.PRNGKey(0)
1597
+
1598
+ start_prompt = "Once upon a time, "
1599
+ start_tokens = tokenizer.encode(start_prompt)[:maxlen]
1600
+ generated_text = model.generate_text(
1601
+ maxlen, start_tokens
1602
+ )
1603
+ print(f"Initial generated text:\n{generated_text}\n")
1604
+
1605
+
1606
+ metrics_history = {
1607
+ 'train_loss': [],
1608
+ }
1609
+
1610
+ prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))
1611
+
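+ # Sketch of the shifting convention: for input tokens [t0, t1, t2, t3] the target row is
+ # [t1, t2, t3, 0], so the model learns to predict the next token at every position.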
1612
+ step = 0
1613
+ for epoch in range(num_epochs):
1614
+ start_time = time.time()
1615
+ for batch in text_dl:
1616
+ if len(batch) % len(jax.devices()) != 0:
1617
+ continue # skip the remaining elements
1618
+ input_batch = jnp.array(jnp.array(batch).T)
1619
+ target_batch = prep_target_batch(input_batch)
1620
+ train_step(model, optimizer, metrics, jax.device_put((input_batch, target_batch), NamedSharding(mesh, P('batch', None))))
1621
+
1622
+ if (step + 1) % 200 == 0:
1623
+ for metric, value in metrics.compute().items():
1624
+ metrics_history[f'train_{metric}'].append(value)
1625
+ metrics.reset()
1626
+
1627
+ elapsed_time = time.time() - start_time
1628
+ print(f"Step {step + 1}, Loss: {metrics_history['train_loss'][-1]}, Elapsed Time: {elapsed_time:.2f} seconds")
1629
+ start_time = time.time()
1630
+
1631
+ generated_text = model.generate_text(
1632
+ maxlen, start_tokens
1633
+ )
1634
+ print(f"Generated text:\n{generated_text}\n")
1635
+ step += 1
1636
+
1637
+ # Final text generation
1638
+ generated_text = model.generate_text(
1639
+ maxlen, start_tokens
1640
+ )
1641
+ print(f"Final generated text:\n{generated_text}")
1642
+
1643
+ """Visualize the training loss."""
1644
+
1645
+ import matplotlib.pyplot as plt
1646
+ plt.plot(metrics_history['train_loss'])
1647
+ plt.title('Training Loss')
1648
+ plt.xlabel('Step')
1649
+ plt.ylabel('Loss')
1650
+ plt.show()