nmn 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nmn/nnx/examples/language/mingpt.py +1650 -0
- nmn/nnx/examples/vision/cnn_cifar.py +1769 -0
- nmn/nnx/nmn.py +1 -1
- nmn/nnx/yatattention.py +764 -0
- nmn/nnx/yatconv.py +22 -2
- nmn/torch/nmn.py +2 -1
- {nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/METADATA +2 -2
- nmn-0.1.4.dist-info/RECORD +14 -0
- nmn-0.1.2.dist-info/RECORD +0 -11
- {nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/WHEEL +0 -0
- {nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1650 @@
|
|
1
|
+
# Install required packages. These are shell/notebook commands, not Python:
# run them in a terminal (or prefix with `!` in a notebook cell). The previous
# bare `!pip install ...` line was IPython magic and made this file a
# SyntaxError when executed or imported as a regular Python module.
#   pip install -Uq tiktoken grain matplotlib datasets
# pip install tensorflow-cpu wandb
# pip install --upgrade jax jaxlib flax grain
# pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

import jax

# Query the devices JAX can see. In a notebook this displays them; when run as
# a script the returned list is simply discarded.
jax.devices()
|
9
|
+
|
10
|
+
# Download TinyStories dataset (can be replaced with HuggingFace datasets)
|
11
|
+
# wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
|
12
|
+
|
13
|
+
import jax
|
14
|
+
import jax.numpy as jnp
|
15
|
+
|
16
|
+
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding # For data and model parallelism (explained in more detail later)
|
17
|
+
from jax.experimental import mesh_utils
|
18
|
+
|
19
|
+
import flax.nnx as nnx
|
20
|
+
import optax
|
21
|
+
|
22
|
+
from dataclasses import dataclass
|
23
|
+
import grain.python as pygrain
|
24
|
+
import pandas as pd
|
25
|
+
import tiktoken
|
26
|
+
import time
|
27
|
+
from typing import Optional, Dict, List, Union
|
28
|
+
import warnings
|
29
|
+
|
30
|
+
# Hugging Face datasets integration
|
31
|
+
try:
|
32
|
+
from datasets import load_dataset, DatasetDict
|
33
|
+
import datasets
|
34
|
+
HF_DATASETS_AVAILABLE = True
|
35
|
+
except ImportError:
|
36
|
+
HF_DATASETS_AVAILABLE = False
|
37
|
+
warnings.warn("Hugging Face datasets not available. Install with: pip install datasets")
|
38
|
+
|
39
|
+
# Create a `Mesh` object representing TPU device arrangement.
|
40
|
+
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
|
41
|
+
|
42
|
+
# Module-level GPT-2 byte-pair-encoding tokenizer (tiktoken's "gpt2" encoding).
# NOTE(review): presumably used later in the file to encode the corpus; that
# usage is not visible in this section.
tokenizer = tiktoken.get_encoding("gpt2")
|
43
|
+
|
44
|
+
# Dataset configuration
@dataclass
class DatasetConfig:
    """Describe where a text corpus comes from and how to clean it.

    Hugging Face Hub sources are selected with ``name``, ``subset``, ``split``,
    ``streaming``, ``cache_dir`` and ``trust_remote_code``. Local files use
    ``file_path`` and ``file_type``. The remaining fields control text
    extraction and length filtering.
    """

    # -- Hugging Face Hub source ------------------------------------------
    name: str = "roneneldan/TinyStories"   # Hub dataset id used by default
    subset: Optional[str] = None           # configuration name for multi-config datasets
    split: str = "train"
    text_column: str = "text"              # field that holds the raw text
    streaming: bool = False                # forwarded to load_dataset; handy for huge corpora
    cache_dir: Optional[str] = None
    trust_remote_code: bool = False

    # -- Text preprocessing -----------------------------------------------
    separator: str = "<|endoftext|>"       # marker placed between documents
    min_length: int = 10                   # texts shorter than this (in characters) are dropped
    max_length: Optional[int] = None       # longer texts are cut to this many characters

    # -- Local-file source ------------------------------------------------
    file_path: Optional[str] = None
    file_type: str = "txt"                 # local file format, e.g. "txt", "json", "csv"
|
64
|
+
|
65
|
+
# Ready-made dataset configurations, keyed by a short alias
# (see list_available_datasets / get_dataset_info).
DATASET_CONFIGS = dict(
    tinystories=DatasetConfig(
        name="roneneldan/TinyStories",
        text_column="text",
        separator="<|endoftext|>",
    ),
    wikitext=DatasetConfig(
        name="Salesforce/wikitext",
        subset="wikitext-2-raw-v1",
        text_column="text",
        separator="\n\n",
    ),
    # The large web corpora below are streamed instead of fully downloaded.
    openwebtext=DatasetConfig(
        name="Skylion007/openwebtext",
        text_column="text",
        streaming=True,
        separator="<|endoftext|>",
    ),
    bookscorpus=DatasetConfig(
        name="bookcorpus/bookcorpus",
        text_column="text",
        trust_remote_code=True,
        separator="<|endoftext|>",
    ),
    c4=DatasetConfig(
        name="allenai/c4",
        subset="en",
        text_column="text",
        streaming=True,
        separator="<|endoftext|>",
    ),
    tiny_shakespeare=DatasetConfig(
        name="tiny_shakespeare",
        text_column="text",
        separator="\n\n",
    ),
    gutenberg=DatasetConfig(
        name="sedthh/gutenberg_english",
        text_column="text",
        separator="<|endoftext|>",
    ),
    pile=DatasetConfig(
        name="EleutherAI/pile",
        text_column="text",
        streaming=True,
        separator="<|endoftext|>",
    ),
    common_crawl=DatasetConfig(
        name="oscar",
        subset="unshuffled_deduplicated_en",
        text_column="text",
        streaming=True,
        separator="<|endoftext|>",
    ),
    # Plain-text file on disk (e.g. the TinyStories dump fetched with wget).
    local_file=DatasetConfig(
        name="local",
        file_path="TinyStories-train.txt",
        file_type="txt",
        separator="<|endoftext|>",
    ),
    # NOTE(review): file_type is only read by load_local_file; it has no
    # effect on this Hub-backed entry.
    fineweb=DatasetConfig(
        name="VisionTheta/fineweb-100B",
        file_type="txt",
        streaming=True,
        separator="<|endoftext|>",
    ),
)
|
134
|
+
|
135
|
+
def load_huggingface_dataset(
    config: DatasetConfig,
    *,
    streaming_sample_limit: Optional[int] = 50000,
) -> List[str]:
    """Load and length-filter text samples from a Hugging Face Hub dataset.

    Args:
        config: ``name``, ``subset``, ``split``, ``streaming``, ``cache_dir`` and
            ``trust_remote_code`` are forwarded to ``datasets.load_dataset``.
            ``text_column``, ``min_length`` and ``max_length`` control extraction
            and filtering.
        streaming_sample_limit: Maximum number of kept samples when
            ``config.streaming`` is true. ``None`` or ``0`` disables the cap. It
            is ignored for non-streaming datasets. The default (50000) matches
            the previously hard-coded limit.

    Returns:
        The kept texts, each truncated to ``config.max_length`` characters when
        that is set. Rows that are plain strings are used as-is. Rows missing
        ``text_column``, and cells that are not strings (e.g. ``None``), are
        skipped. If loading or iteration fails, the error is printed and ``[]``
        is returned so the caller can fall back to another source.

    Raises:
        ImportError: if the ``datasets`` package is not installed.
    """
    from collections.abc import Mapping

    if not HF_DATASETS_AVAILABLE:
        raise ImportError("Hugging Face datasets not available. Install with: pip install datasets")

    print(f"Loading dataset: {config.name}")
    if config.subset:
        print(f" Subset: {config.subset}")

    try:
        load_kwargs = {
            "path": config.name,
            "split": config.split,
            "streaming": config.streaming,
            "trust_remote_code": config.trust_remote_code,
        }
        if config.subset:
            load_kwargs["name"] = config.subset
        if config.cache_dir:
            load_kwargs["cache_dir"] = config.cache_dir

        dataset = load_dataset(**load_kwargs)

        print("Dataset loaded successfully. Processing text...")

        max_samples = streaming_sample_limit if config.streaming else None
        texts: List[str] = []

        for item in dataset:
            if max_samples and len(texts) >= max_samples:
                break

            # Plain-string rows must be checked first. `in` on a str is a
            # substring test, and indexing a str with a column name raises
            # TypeError, which previously aborted the whole load.
            if isinstance(item, str):
                text = item
            elif isinstance(item, Mapping) and config.text_column in item:
                text = item[config.text_column]
            else:
                details = item.keys() if isinstance(item, Mapping) else type(item).__name__
                print(f"Warning: Text column '{config.text_column}' not found in item: {details}")
                continue

            # Null/non-string cells would make len() raise and discard everything.
            if not isinstance(text, str):
                continue

            # Filter by length
            if len(text) < config.min_length:
                continue
            if config.max_length and len(text) > config.max_length:
                text = text[:config.max_length]

            texts.append(text)

            if len(texts) % 10000 == 0:
                print(f" Processed {len(texts)} samples...")

        print(f"Dataset processing complete. Total samples: {len(texts)}")
        return texts

    except Exception as e:
        # Best-effort by design: report the problem and return an empty list so
        # the caller can try another source (e.g. a local file).
        print(f"Error loading dataset {config.name}: {e}")
        print("Falling back to local file if available...")
        return []
|
199
|
+
|
200
|
+
def load_local_file(config: DatasetConfig) -> List[str]:
    """Load and length-filter text samples from a local file.

    Supported ``config.file_type`` values:

    * ``"txt"``: if ``config.separator`` occurs in the file, the file is split on
      it and every non-empty, stripped chunk gets the separator re-appended.
      Otherwise the file is split on blank lines (``"\\n\\n"``).
    * ``"json"``: a top-level list yields one sample per element (the
      ``text_column`` of dict elements, ``str()`` of anything else). A top-level
      dict yields its ``text_column`` value. Any other value is stringified.
    * ``"csv"``: the ``text_column`` of every row that has that column.
    * ``"parquet"``: the ``text_column`` of every row. It is read with pandas,
      which needs a parquet engine such as pyarrow.

    Samples that are not strings (JSON ``null``, missing CSV cells, parquet nulls)
    or that are shorter than ``config.min_length`` characters are dropped. Longer
    ones are truncated to ``config.max_length`` when that is set.

    Returns:
        The kept samples. If the file cannot be read or parsed, or the file type
        is unsupported, the error is printed and ``[]`` is returned.

    Raises:
        ValueError: if ``config.file_path`` is not set.
    """
    if not config.file_path:
        raise ValueError("file_path must be specified for local datasets")

    print(f"Loading local file: {config.file_path}")

    try:
        if config.file_type == "txt":
            with open(config.file_path, 'r', encoding='utf-8') as f:
                raw = f.read()

            if config.separator in raw:
                # Split on the separator and re-attach it to every non-empty chunk.
                texts = [
                    chunk.strip() + config.separator
                    for chunk in raw.split(config.separator)
                    if chunk.strip()
                ]
            else:
                # No separator present: treat blank-line-separated paragraphs as samples.
                texts = [p.strip() for p in raw.split('\n\n') if p.strip()]

        elif config.file_type == "json":
            import json

            with open(config.file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if isinstance(data, list):
                # .get() lets an element that lacks the column be skipped by the
                # filter below. Indexing raised KeyError and discarded the file.
                texts = [
                    item.get(config.text_column) if isinstance(item, dict) else str(item)
                    for item in data
                ]
            elif isinstance(data, dict):
                texts = [data.get(config.text_column)]
            else:
                texts = [str(data)]

        elif config.file_type == "csv":
            import csv

            # newline='' is required by the csv module so quoted fields with
            # embedded newlines are parsed correctly.
            with open(config.file_path, 'r', encoding='utf-8', newline='') as f:
                texts = [
                    row[config.text_column]
                    for row in csv.DictReader(f)
                    if config.text_column in row
                ]

        elif config.file_type == "parquet":
            # Only the text column is loaded.
            frame = pd.read_parquet(config.file_path, columns=[config.text_column])
            texts = frame[config.text_column].tolist()

        else:
            raise ValueError(f"Unsupported file type: {config.file_type}")

        # Filter by type and length, then truncate.
        filtered_texts = []
        for text in texts:
            if not isinstance(text, str) or len(text) < config.min_length:
                continue
            if config.max_length and len(text) > config.max_length:
                text = text[:config.max_length]
            filtered_texts.append(text)

        print(f"Local file loaded successfully. Total samples: {len(filtered_texts)}")
        return filtered_texts

    except Exception as e:
        # Best-effort by design: report and return [] (this also covers the
        # unsupported-file-type ValueError raised above).
        print(f"Error loading local file: {e}")
        return []
|
256
|
+
|
257
|
+
def get_dataset_info(dataset_name: str) -> Dict:
    """Summarize one of the predefined dataset configurations.

    Args:
        dataset_name: A key of ``DATASET_CONFIGS``.

    Returns:
        A dict with the config's ``name``, ``subset``, ``text_column``,
        ``separator`` and ``streaming`` values plus a generic ``description``.
        For an unknown key, it returns a dict holding only an ``error`` message.
    """
    if dataset_name not in DATASET_CONFIGS:
        return {"error": f"Dataset '{dataset_name}' not found in predefined configurations"}

    selected = DATASET_CONFIGS[dataset_name]
    summary = {
        attr: getattr(selected, attr)
        for attr in ("name", "subset", "text_column", "separator", "streaming")
    }
    summary["description"] = f"Predefined configuration for {dataset_name}"
    return summary
|
271
|
+
|
272
|
+
def list_available_datasets() -> List[str]:
    """Return the aliases of all predefined dataset configurations, in definition order."""
    return [*DATASET_CONFIGS]
|
275
|
+
|
276
|
+
"""Neural-Matter Network Definition"""
|
277
|
+
|
278
|
+
# Commented out IPython magic to ensure Python compatibility.
|
279
|
+
# Copyright 2024 The Flax Authors.
|
280
|
+
#
|
281
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
282
|
+
# you may not use this file except in compliance with the License.
|
283
|
+
# You may obtain a copy of the License at
|
284
|
+
#
|
285
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
286
|
+
#
|
287
|
+
# Unless required by applicable law or agreed to in writing, software
|
288
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
289
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
290
|
+
# See the License for the specific language governing permissions and
|
291
|
+
# limitations under the License.
|
292
|
+
|
293
|
+
"""Attention core modules for Flax."""
|
294
|
+
|
295
|
+
from __future__ import annotations
|
296
|
+
|
297
|
+
import functools
|
298
|
+
from typing import Any, Callable, Optional
|
299
|
+
import typing as tp
|
300
|
+
|
301
|
+
import jax
|
302
|
+
import jax.numpy as jnp
|
303
|
+
from jax import lax, random
|
304
|
+
|
305
|
+
from flax import nnx
|
306
|
+
from flax.nnx import rnglib
|
307
|
+
from flax.nnx.module import Module, first_from
|
308
|
+
from flax.nnx.nn import initializers
|
309
|
+
from flax.nnx.nn.dtypes import promote_dtype
|
310
|
+
from flax.nnx.nn.linear import (
|
311
|
+
LinearGeneral,
|
312
|
+
default_kernel_init,
|
313
|
+
)
|
314
|
+
from flax.nnx.nn.normalization import LayerNorm
|
315
|
+
from flax.typing import (
|
316
|
+
Dtype,
|
317
|
+
Shape,
|
318
|
+
Initializer,
|
319
|
+
PrecisionLike,
|
320
|
+
DotGeneralT,
|
321
|
+
)
|
322
|
+
|
323
|
+
|
324
|
+
|
325
|
+
def yat_attention_weights(
    query: Array,
    key: Array,
    bias: Optional[Array] = None,
    mask: Optional[Array] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[Array] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Optional[Dtype] = None,
    precision: PrecisionLike = None,
    module: Optional[Module] = None,
    epsilon: float = 1e-5,
):
    """Computes attention weights using the YatNMN distance-based score.

    For query ``q_i`` and key ``k_j`` of the same head the raw score is

        (q_i . k_j)^2 / (||q_i - k_j||^2 + epsilon),

    with ``||q_i - k_j||^2 = ||q_i||^2 + ||k_j||^2 - 2 (q_i . k_j)``. This is the
    same form the ``YatNMN`` layer uses. The scores are then biased, masked and
    softmax-normalised over the key axis. Unlike scaled dot-product attention,
    there is no 1/sqrt(depth) scaling.

    Args:
      query: ``[batch..., q_length, num_heads, qk_depth_per_head]``.
      key: ``[batch..., kv_length, num_heads, qk_depth_per_head]``.
        ``kv_length`` may differ from ``q_length``.
      bias: optional additive bias, broadcastable to
        ``[batch..., num_heads, q_length, kv_length]``.
      mask: optional boolean mask broadcastable to the same shape. Scores where
        it is ``False`` are replaced by the dtype's most negative finite value.
      broadcast_dropout: share one dropout mask across batch and head dims.
      dropout_rng: PRNG key for dropout.
      dropout_rate: probability of dropping a weight.
      deterministic: if true, skip dropout.
      dtype: computation dtype (``None`` promotes ``query``/``key``).
      precision: ``jax.lax.Precision`` for the einsum.
      module: if given, the post-softmax weights are sown into its
        ``nnx.Intermediate`` collection as ``'attention_weights'``.
      epsilon: added to the squared distance to avoid division by zero.

    Returns:
      Attention weights of shape ``[batch..., num_heads, q_length, kv_length]``.
    """
    query, key = promote_dtype((query, key), dtype=dtype)
    dtype = query.dtype

    assert query.ndim == key.ndim, 'q, k must have same rank.'
    assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
    assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
    assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

    # q . k for every (query, key) pair: [..., num_heads, q_length, kv_length].
    dot = jnp.einsum('...qhd,...khd->...hqk', query, key, precision=precision)
    squared_dot_product = jnp.square(dot)

    # Squared norms laid out as an outer sum over (q_length, kv_length):
    #   q: [..., q_length, num_heads] -> [..., num_heads, q_length, 1]
    #   k: [..., kv_length, num_heads] -> [..., num_heads, 1, kv_length]
    # The previous code added the two norm tensors element-wise, which paired
    # q_i with k_i instead of k_j and broke whenever q_length != kv_length.
    q_sq_norm = jnp.swapaxes(jnp.sum(jnp.square(query), axis=-1), -1, -2)[..., :, None]
    k_sq_norm = jnp.swapaxes(jnp.sum(jnp.square(key), axis=-1), -1, -2)[..., None, :]

    # ||q - k||^2 = ||q||^2 + ||k||^2 - 2 (q . k). This uses the plain dot
    # product, not its square (which the previous code used by mistake).
    # Clamp at 0 because float cancellation can make it slightly negative
    # when q ~ k.
    squared_dist = jnp.maximum(q_sq_norm + k_sq_norm - 2.0 * dot, 0.0)

    # YatNMN attention scores: (q . k)^2 / (||q - k||^2 + epsilon)
    attn_weights = squared_dot_product / (squared_dist + epsilon)

    # apply attention bias: masking, dropout, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias
    # apply attention mask
    if mask is not None:
        big_neg = jnp.finfo(dtype).min
        attn_weights = jnp.where(mask, attn_weights, big_neg)

    # normalize the attention weights over the key axis
    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

    if module:
        module.sow(nnx.Intermediate, 'attention_weights', attn_weights)

    # apply attention dropout
    if not deterministic and dropout_rate > 0.0:
        keep_prob = 1.0 - dropout_rate
        if broadcast_dropout:
            # dropout is broadcast across the batch + head dimensions
            dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
            keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        else:
            keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        attn_weights = attn_weights * multiplier

    return attn_weights
|
400
|
+
|
401
|
+
|
402
|
+
def yat_attention(
    query: Array,
    key: Array,
    value: Array,
    bias: Optional[Array] = None,
    mask: Optional[Array] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[Array] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Optional[Dtype] = None,
    precision: PrecisionLike = None,
    module: Optional[Module] = None,
    epsilon: float = 1e-5,
):
    """Applies YatNMN attention and returns the attention-weighted values.

    The weights come from ``yat_attention_weights`` (same bias/mask/dropout/
    dtype/precision/module/epsilon arguments). They are then contracted with
    ``value`` over the key/value length axis.

    Args:
      query: ``[batch..., q_length, num_heads, qk_depth_per_head]``.
      key: ``[batch..., kv_length, num_heads, qk_depth_per_head]``.
      value: ``[batch..., kv_length, num_heads, v_depth_per_head]``.

    Returns:
      An array of shape ``[batch..., q_length, num_heads, v_depth_per_head]``.
    """
    query, key, value = promote_dtype((query, key, value), dtype=dtype)
    dtype = query.dtype

    # Shape sanity checks (rank, batch dims, heads, kv length).
    assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
    assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], 'q, k, v batch dims must match.'
    assert query.shape[-2] == key.shape[-2] == value.shape[-2], 'q, k, v num_heads must match.'
    assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'

    weights = yat_attention_weights(
        query,
        key,
        bias=bias,
        mask=mask,
        broadcast_dropout=broadcast_dropout,
        dropout_rng=dropout_rng,
        dropout_rate=dropout_rate,
        deterministic=deterministic,
        dtype=dtype,
        precision=precision,
        module=module,
        epsilon=epsilon,
    )

    # [..., h, q, k] x [..., k, h, d] -> [..., q, h, d]
    return jnp.einsum('...hqk,...khd->...qhd', weights, value, precision=precision)
|
449
|
+
|
450
|
+
# Short alias for JAX arrays, used in type annotations in this module.
Array = jax.Array

# Default initialisers for the YatNMN layer below. The scaling exponent
# `alpha` starts at one and the bias starts at zero.
default_alpha_init = initializers.ones_init()
default_bias_init = initializers.zeros_init()
|
455
|
+
|
456
|
+
class YatNMN(Module):
    """Dense layer whose response is a squared dot product over a squared distance.

    For an input ``x`` (last axis of size ``in_features``) and kernel column
    ``w_j``, output unit ``j`` is

        y_j = (x . w_j)^2 / (||x - w_j||^2 + epsilon)   [+ bias_j]

    optionally multiplied by ``(sqrt(out_features) / log(1 + out_features)) ** alpha``,
    where ``alpha`` is a learned scalar.

    Args:
      in_features: size of the last input axis.
      out_features: number of output units.
      use_bias: add a learned per-unit bias.
      use_alpha: learn the scalar exponent ``alpha`` of the output rescaling.
      dtype: computation dtype. The inputs and parameters are passed through
        ``promote_dtype`` with this dtype (``None`` promotes them to a common one).
      param_dtype: dtype of the stored parameters.
      precision: precision passed to ``dot_general``.
      kernel_init, bias_init, alpha_init: parameter initialisers.
      dot_general: contraction used for ``x . W``.
      rngs: RNG streams. ``rngs.params()`` supplies each initialiser key.
      epsilon: added to the squared distance to avoid division by zero.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        use_bias: bool = True,
        use_alpha: bool = True,
        dtype: Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
        kernel_init: Initializer = default_kernel_init,
        bias_init: Initializer = default_bias_init,
        alpha_init: Initializer = default_alpha_init,
        dot_general: DotGeneralT = lax.dot_general,
        rngs: rnglib.Rngs,
        epsilon: float = 1e-5,
    ):
        kernel_key = rngs.params()
        self.kernel = nnx.Param(
            kernel_init(kernel_key, (in_features, out_features), param_dtype)
        )

        self.bias: nnx.Param[jax.Array] | None
        if use_bias:
            bias_key = rngs.params()
            self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
        else:
            self.bias = None

        # Learned scalar exponent for the output rescaling in __call__.
        self.alpha: nnx.Param[jax.Array] | None
        if use_alpha:
            alpha_key = rngs.params()
            self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
        else:
            self.alpha = None

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.use_alpha = use_alpha
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.kernel_init = kernel_init
        self.bias_init = bias_init
        self.alpha_init = alpha_init
        self.dot_general = dot_general
        self.epsilon = epsilon

    def __call__(self, inputs: Array) -> Array:
        """Applies the layer to ``inputs`` of shape ``[..., in_features]``.

        Returns:
          An array of shape ``[..., out_features]``.
        """
        kernel = self.kernel.value
        bias = self.bias.value if self.bias is not None else None
        alpha = self.alpha.value if self.alpha is not None else None

        # Honour `dtype`, which was previously stored but never applied.
        # This mirrors flax's Linear.
        inputs, kernel, bias, alpha = promote_dtype(
            (inputs, kernel, bias, alpha), dtype=self.dtype
        )

        # x . w_j for every output unit: [..., out_features]
        y = self.dot_general(
            inputs,
            kernel,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )

        # ||x - w_j||^2 = ||x||^2 + ||w_j||^2 - 2 (x . w_j)
        inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
        kernel_squared_sum = jnp.sum(kernel**2, axis=0, keepdims=True)
        distances = inputs_squared_sum + kernel_squared_sum - 2 * y

        # Element-wise YAT response: (x . w_j)^2 / (||x - w_j||^2 + epsilon)
        y = y ** 2 / (distances + self.epsilon)

        if bias is not None:
            y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))

        if alpha is not None:
            scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
            y = y * scale

        return y
|
535
|
+
|
536
|
+
|
537
|
+
def dot_product_attention_weights(
    query: Array,
    key: Array,
    bias: Optional[Array] = None,
    mask: Optional[Array] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[Array] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Optional[Dtype] = None,
    precision: PrecisionLike = None,
    module: Optional[Module] = None,
):
    """Scaled dot-product attention weights, softmax-normalised over keys.

    This is the weight half of :func:`dot_product_attention`. Call it directly
    when you need the weights themselves (e.g. for inspection), and contract
    them with the values yourself.

    Args:
      query: ``[batch..., q_length, num_heads, qk_depth_per_head]``.
      key: ``[batch..., kv_length, num_heads, qk_depth_per_head]``.
      bias: optional additive bias, broadcastable to
        ``[batch..., num_heads, q_length, kv_length]`` (causal, padding or
        proximity bias, ...).
      mask: optional boolean mask broadcastable to the same shape. Logits where
        it is ``False`` are replaced by the most negative finite value of the
        computation dtype before the softmax.
      broadcast_dropout: share one dropout mask across batch and head dims.
      dropout_rng: PRNG key used for dropout.
      dropout_rate: probability of dropping an attention weight.
      deterministic: if true, dropout is skipped.
      dtype: computation dtype (``None`` infers it from ``query``/``key``).
      precision: ``jax.lax.Precision`` setting for the einsum.
      module: if given, the post-softmax weights are sown into its
        ``nnx.Intermediate`` collection under ``'attention_weights'``.

    Returns:
      Weights of shape ``[batch..., num_heads, q_length, kv_length]``.
    """
    query, key = promote_dtype((query, key), dtype=dtype)  # type: ignore[bad-unpacking]
    dtype = query.dtype

    assert query.ndim == key.ndim, 'q, k must have same rank.'
    assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
    assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
    assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

    # Scale queries by 1/sqrt(depth), then contract to [..., heads, q, k].
    scaled_query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
    logits = jnp.einsum('...qhd,...khd->...hqk', scaled_query, key, precision=precision)

    if bias is not None:
        logits = logits + bias
    if mask is not None:
        logits = jnp.where(mask, logits, jnp.finfo(dtype).min)

    weights = jax.nn.softmax(logits).astype(dtype)

    if module:
        module.sow(nnx.Intermediate, 'attention_weights', weights)

    # Nothing more to do without dropout.
    if deterministic or not dropout_rate > 0.0:
        return weights

    keep_prob = 1.0 - dropout_rate
    if broadcast_dropout:
        # A single mask is reused for every batch and head position.
        keep_shape = (1,) * (key.ndim - 2) + weights.shape[-2:]
    else:
        keep_shape = weights.shape
    keep = random.bernoulli(dropout_rng, keep_prob, keep_shape)  # type: ignore
    return weights * (keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype))
|
625
|
+
|
626
|
+
|
627
|
+
def dot_product_attention(
    query: Array,
    key: Array,
    value: Array,
    bias: Optional[Array] = None,
    mask: Optional[Array] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[Array] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Optional[Dtype] = None,
    precision: PrecisionLike = None,
    module: Optional[Module] = None,
):
    """Scaled dot-product attention (https://arxiv.org/abs/1706.03762).

    Computes attention weights from ``query`` and ``key`` with
    :func:`dot_product_attention_weights`, then returns the weighted
    combination of ``value``. The inputs need not have any batch dimensions.

    Args:
      query: ``[batch..., q_length, num_heads, qk_depth_per_head]``.
      key: ``[batch..., kv_length, num_heads, qk_depth_per_head]``.
      value: ``[batch..., kv_length, num_heads, v_depth_per_head]``.
      bias: optional additive bias, broadcastable to
        ``[batch..., num_heads, q_length, kv_length]``.
      mask: optional boolean mask broadcastable to the same shape. Positions
        where it is ``False`` are masked out.
      broadcast_dropout: share one dropout mask across batch and head dims.
      dropout_rng: PRNG key used for dropout.
      dropout_rate: probability of dropping an attention weight.
      deterministic: if true, dropout is skipped.
      dtype: computation dtype (``None`` infers it from the inputs).
      precision: ``jax.lax.Precision`` setting for the einsums.
      module: if given, the attention weights are sown into its
        ``nnx.Intermediate`` collection.

    Returns:
      An array of shape ``[batch..., q_length, num_heads, v_depth_per_head]``.
    """
    query, key, value = promote_dtype((query, key, value), dtype=dtype)  # type: ignore[bad-unpacking]
    dtype = query.dtype

    # Rank, batch-dim, head-count and kv-length consistency checks.
    assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
    assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], 'q, k, v batch dims must match.'
    assert query.shape[-2] == key.shape[-2] == value.shape[-2], 'q, k, v num_heads must match.'
    assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'

    weights = dot_product_attention_weights(
        query,
        key,
        bias=bias,
        mask=mask,
        broadcast_dropout=broadcast_dropout,
        dropout_rng=dropout_rng,
        dropout_rate=dropout_rate,
        deterministic=deterministic,
        dtype=dtype,
        precision=precision,
        module=module,
    )

    # [..., h, q, k] x [..., k, h, d] -> [..., q, h, d]
    return jnp.einsum('...hqk,...khd->...qhd', weights, value, precision=precision)
|
708
|
+
|
709
|
+
|
710
|
+
class MultiHeadAttention(Module):
|
711
|
+
"""Multi-head attention.
|
712
|
+
|
713
|
+
Example usage::
|
714
|
+
|
715
|
+
>>> import flax.linen as nn
|
716
|
+
>>> import jax
|
717
|
+
|
718
|
+
>>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16)
|
719
|
+
>>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
|
720
|
+
>>> shape = (4, 3, 2, 5)
|
721
|
+
>>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
|
722
|
+
>>> variables = layer.init(jax.random.key(0), q)
|
723
|
+
|
724
|
+
>>> # different inputs for inputs_q, inputs_k and inputs_v
|
725
|
+
>>> out = layer.apply(variables, q, k, v)
|
726
|
+
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
|
727
|
+
>>> out = layer.apply(variables, q, k)
|
728
|
+
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
|
729
|
+
>>> out = layer.apply(variables, q)
|
730
|
+
|
731
|
+
>>> attention_kwargs = dict(
|
732
|
+
... num_heads=8,
|
733
|
+
... qkv_features=16,
|
734
|
+
... kernel_init=nn.initializers.ones,
|
735
|
+
... bias_init=nn.initializers.zeros,
|
736
|
+
... dropout_rate=0.5,
|
737
|
+
... deterministic=False,
|
738
|
+
... )
|
739
|
+
>>> class Module(nn.Module):
|
740
|
+
... attention_kwargs: dict
|
741
|
+
...
|
742
|
+
... @nn.compact
|
743
|
+
... def __call__(self, x, dropout_rng=None):
|
744
|
+
... out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
|
745
|
+
... out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
|
746
|
+
... return out1, out2
|
747
|
+
>>> module = Module(attention_kwargs)
|
748
|
+
>>> variables = module.init({'params': key1, 'dropout': key2}, q)
|
749
|
+
|
750
|
+
>>> # out1 and out2 are different.
|
751
|
+
>>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
|
752
|
+
>>> # out3 and out4 are different.
|
753
|
+
>>> # out1 and out3 are different. out2 and out4 are different.
|
754
|
+
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
|
755
|
+
>>> # out1 and out2 are the same.
|
756
|
+
>>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
|
757
|
+
>>> # out1 and out2 are the same as out3 and out4.
|
758
|
+
>>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
|
759
|
+
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
|
760
|
+
|
761
|
+
Attributes:
|
762
|
+
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
|
763
|
+
should be divisible by the number of heads.
|
764
|
+
dtype: the dtype of the computation (default: infer from inputs and params)
|
765
|
+
param_dtype: the dtype passed to parameter initializers (default: float32)
|
766
|
+
qkv_features: dimension of the key, query, and value.
|
767
|
+
out_features: dimension of the last projection
|
768
|
+
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
|
769
|
+
dropout_rate: dropout rate
|
770
|
+
deterministic: if false, the attention weight is masked randomly using
|
771
|
+
dropout, whereas if true, the attention weights are deterministic.
|
772
|
+
precision: numerical precision of the computation see `jax.lax.Precision`
|
773
|
+
for details.
|
774
|
+
kernel_init: initializer for the kernel of the Dense layers.
|
775
|
+
out_kernel_init: optional initializer for the kernel of the output Dense layer,
|
776
|
+
if None, the kernel_init is used.
|
777
|
+
bias_init: initializer for the bias of the Dense layers.
|
778
|
+
out_bias_init: optional initializer for the bias of the output Dense layer,
|
779
|
+
if None, the bias_init is used.
|
780
|
+
use_bias: bool: whether pointwise QKVO dense transforms use bias.
|
781
|
+
attention_fn: dot_product_attention or compatible function. Accepts query,
|
782
|
+
key, value, and returns output of shape `[bs, dim1, dim2, ..., dimN,,
|
783
|
+
num_heads, value_channels]``
|
784
|
+
decode: whether to prepare and use an autoregressive cache.
|
785
|
+
normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442).
|
786
|
+
"""
|
787
|
+
|
788
|
+
def __init__(
    self,
    num_heads: int,
    in_features: int,
    qkv_features: int | None = None,
    out_features: int | None = None,
    *,
    dtype: Dtype | None = None,
    param_dtype: Dtype = jnp.float32,
    broadcast_dropout: bool = True,
    dropout_rate: float = 0.0,
    deterministic: bool | None = None,
    precision: PrecisionLike = None,
    kernel_init: Initializer = default_kernel_init,
    out_kernel_init: Initializer | None = None,
    bias_init: Initializer = initializers.zeros_init(),
    out_bias_init: Initializer | None = None,
    use_bias: bool = True,
    attention_fn: Callable[..., Array] = yat_attention,
    decode: bool | None = None,
    normalize_qk: bool = False,
    # Deprecated, will be removed.
    qkv_dot_general: DotGeneralT | None = None,
    out_dot_general: DotGeneralT | None = None,
    qkv_dot_general_cls: Any = None,
    out_dot_general_cls: Any = None,
    rngs: rnglib.Rngs,
    epsilon: float = 1e-5,
):
    """Build YatNMN-based query/key/value projections for multi-head attention.

    Both ``qkv_features`` and ``out_features`` default to ``in_features``.
    ``qkv_features`` must be divisible by ``num_heads``; each head then gets
    ``qkv_features // num_heads`` channels (stored as ``self.head_dim``).

    NOTE(review): there is no output projection in this module. ``out_features``,
    ``out_kernel_init``, ``out_bias_init`` and the deprecated ``*_dot_general*``
    arguments are stored on the instance but never used by the visible code.

    Raises:
      ValueError: if ``qkv_features`` is not a multiple of ``num_heads``.
    """
    # --- sizes -----------------------------------------------------------
    self.num_heads = num_heads
    self.in_features = in_features
    self.qkv_features = in_features if qkv_features is None else qkv_features
    self.out_features = in_features if out_features is None else out_features

    # --- numerics / dropout configuration --------------------------------
    self.dtype = dtype
    self.param_dtype = param_dtype
    self.broadcast_dropout = broadcast_dropout
    self.dropout_rate = dropout_rate
    self.deterministic = deterministic
    self.precision = precision

    # --- initializers and behaviour switches -----------------------------
    self.kernel_init = kernel_init
    self.out_kernel_init = out_kernel_init
    self.bias_init = bias_init
    self.out_bias_init = out_bias_init
    self.use_bias = use_bias
    self.attention_fn = attention_fn
    self.decode = decode
    self.normalize_qk = normalize_qk

    # --- deprecated arguments, kept only for signature compatibility -----
    self.qkv_dot_general = qkv_dot_general
    self.out_dot_general = out_dot_general
    self.qkv_dot_general_cls = qkv_dot_general_cls
    self.out_dot_general_cls = out_dot_general_cls
    self.epsilon = epsilon

    per_head, leftover = divmod(self.qkv_features, self.num_heads)
    if leftover:
        raise ValueError(
            f'Memory dimension ({self.qkv_features}) must be divisible by '
            f"'num_heads' heads ({self.num_heads})."
        )
    self.head_dim = per_head

    # Every projection maps in_features -> qkv_features (all heads at once);
    # __call__ later splits the last axis into (num_heads, head_dim).
    projection_kwargs = dict(
        in_features=self.in_features,
        out_features=self.qkv_features,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
        use_bias=self.use_bias,
        precision=self.precision,
        epsilon=self.epsilon,
    )
    # Construction order (query, key, value) fixes the order in which `rngs`
    # is consumed, so it is kept as-is.
    self.query = YatNMN(**projection_kwargs, rngs=rngs)
    self.key = YatNMN(**projection_kwargs, rngs=rngs)
    self.value = YatNMN(**projection_kwargs, rngs=rngs)

    self.query_ln: LayerNorm | None
    self.key_ln: LayerNorm | None
    if self.normalize_qk:
        # Per-head QK normalization (ViT-22B, http://arxiv.org/abs/2302.05442)
        # lets training tolerate larger learning rates.
        norm_kwargs = dict(
            use_bias=False,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            rngs=rngs,
        )
        self.query_ln = LayerNorm(self.head_dim, **norm_kwargs)
        self.key_ln = LayerNorm(self.head_dim, **norm_kwargs)
    else:
        self.query_ln = None
        self.key_ln = None

    # The rng container is only kept when dropout can actually fire.
    if dropout_rate > 0.0:
        self.rngs = rngs
    else:
        self.rngs = None

    # Autoregressive-decoding state, filled in by `init_cache`.
    self.cached_key: nnx.Cache[Array] | None = None
    self.cached_value: nnx.Cache[Array] | None = None
    self.cache_index: nnx.Cache[Array] | None = None
|
902
|
+
|
903
|
+
def __call__(
    self,
    inputs_q: Array,
    inputs_k: Array | None = None,
    inputs_v: Array | None = None,
    *,
    mask: Array | None = None,
    deterministic: bool | None = None,
    rngs: rnglib.Rngs | None = None,
    sow_weights: bool = False,
    decode: bool | None = None,
):
    """Applies multi-head YAT attention on the input data.

    Projects the inputs into multi-headed query, key and value vectors with
    the YatNMN projections and applies ``self.attention_fn`` to them. There is
    NO output projection: the per-head results are merged back into a single
    ``qkv_features``-wide axis and returned directly.

    If both inputs_k and inputs_v are None, they will both copy the value of
    inputs_q (self attention).
    If only inputs_v is None, it will copy the value of inputs_k.

    Args:
      inputs_q: input queries of shape `[batch_sizes..., length, features]`.
      inputs_k: key of shape `[batch_sizes..., length, features]`. If None,
        inputs_k will copy the value of inputs_q.
      inputs_v: values of shape `[batch_sizes..., length, features]`. If None,
        inputs_v will copy the value of inputs_k.
      mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
        key/value_length]`. Attention weights are masked out if their
        corresponding mask value is `False`.
      deterministic: if false, the attention weight is masked randomly using
        dropout, whereas if true, the attention weights are deterministic.
        Only consulted when ``dropout_rate > 0``.
      rngs: container for random number generators to generate the dropout
        mask when `deterministic` is False. The `rngs` container should have a
        `dropout` key. Defaults to the container stored at construction.
      sow_weights: if ``True``, the attention weights are sowed into the
        'intermediates' collection.
      decode: whether to run one step of cached autoregressive decoding; taken
        from ``self.decode`` when not given. In decode mode the query length
        must be 1 and ``init_cache`` must have been called.

    Returns:
      output of shape `[batch_sizes..., length, qkv_features]`.

    Raises:
      ValueError: on inconsistent inputs, a missing or mis-shaped decode
        cache, or missing ``rngs`` when dropout is active.
    """
    if rngs is None:
        rngs = self.rngs

    if inputs_k is None:
        if inputs_v is not None:
            raise ValueError(
                '`inputs_k` cannot be None if `inputs_v` is not None. '
                'To have both `inputs_k` and `inputs_v` be the same value, pass in the '
                'value to `inputs_k` and leave `inputs_v` as None.'
            )
        inputs_k = inputs_q
    if inputs_v is None:
        inputs_v = inputs_k

    if inputs_q.shape[-1] != self.in_features:
        raise ValueError(
            f'Incompatible input dimension, got {inputs_q.shape[-1]} '
            f'but module expects {self.in_features}.'
        )

    # Apply YatNMN transformations and reshape to multi-head format
    query = self.query(inputs_q)
    key = self.key(inputs_k)
    value = self.value(inputs_v)

    # Reshape from [batch..., length, qkv_features] to [batch..., length, num_heads, head_dim]
    query = query.reshape(query.shape[:-1] + (self.num_heads, self.head_dim))
    key = key.reshape(key.shape[:-1] + (self.num_heads, self.head_dim))
    value = value.reshape(value.shape[:-1] + (self.num_heads, self.head_dim))

    if self.normalize_qk:
        assert self.query_ln is not None and self.key_ln is not None
        # Normalizing query and key projections stabilizes training with higher
        # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
        query = self.query_ln(query)
        key = self.key_ln(key)

    # During fast autoregressive decoding, we feed one position at a time,
    # and cache the keys and values step by step.
    decode = first_from(
        decode,
        self.decode,
        error_msg="""No `decode` argument was provided to MultiHeadAttention
            as either a __call__ argument, class attribute, or nnx.flag.""",
    )

    if decode:
        if (
            self.cached_key is None
            or self.cached_value is None
            or self.cache_index is None
        ):
            raise ValueError(
                'Autoregressive cache not initialized, call ``init_cache`` first.'
            )
        (
            *batch_dims,
            max_length,
            num_heads,
            depth_per_head,
        ) = self.cached_key.value.shape
        # shape check of cached keys against query input
        expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
        if expected_shape != query.shape:
            # Previously the `%` formatting was commented out, so the message
            # contained literal '%s' placeholders instead of the shapes.
            raise ValueError(
                'Autoregressive cache shape error, '
                f'expected query shape {expected_shape} instead got {query.shape}.'
            )
        # update key, value caches with our new 1d spatial slices
        cur_index = self.cache_index.value
        zero = jnp.array(0, dtype=lax.dtype(cur_index.dtype))
        indices = (zero,) * len(batch_dims) + (cur_index, zero, zero)
        key = lax.dynamic_update_slice(self.cached_key.value, key, indices)
        value = lax.dynamic_update_slice(self.cached_value.value, value, indices)
        self.cached_key.value = key
        self.cached_value.value = value
        self.cache_index.value += 1
        # causal mask for cached decoder self-attention:
        # our single query position should only attend to those key
        # positions that have already been generated and cached
        # (indices <= cur_index, including the one written just now),
        # not the remaining zero elements.
        mask = combine_masks(
            mask,
            jnp.broadcast_to(
                jnp.arange(max_length) <= cur_index,
                tuple(batch_dims) + (1, 1, max_length),
            ),
        )

    if (
        self.dropout_rate > 0.0
    ):  # Require `deterministic` only if using dropout.
        deterministic = first_from(
            deterministic,
            self.deterministic,
            error_msg="""No `deterministic` argument was provided to MultiHeadAttention
                as either a __call__ argument, class attribute, or nnx.flag.""",
        )
        if not deterministic:
            if rngs is None:
                raise ValueError(
                    "'rngs' must be provided if 'dropout_rng' is not given."
                )
            dropout_rng = rngs.dropout()
        else:
            dropout_rng = None
    else:
        deterministic = True
        dropout_rng = None

    # apply attention with epsilon parameter for YatNMN
    x = self.attention_fn(
        query,
        key,
        value,
        mask=mask,
        dropout_rng=dropout_rng,
        dropout_rate=self.dropout_rate,
        broadcast_dropout=self.broadcast_dropout,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=self.precision,
        module=self if sow_weights else None,
        epsilon=self.epsilon,  # Pass epsilon to yat_attention
    )
    # Merge heads back into a single axis (there is no output projection):
    # [batch..., length, num_heads, head_dim] -> [batch..., length, qkv_features]
    x = x.reshape(x.shape[:-2] + (self.qkv_features,))
    return x
|
1074
|
+
|
1075
|
+
def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
    """Allocate the key/value caches needed for ``decode=True``.

    Must be called before running the module in decode mode; otherwise
    ``__call__`` raises "Autoregressive cache not initialized".

    Args:
      input_shape: shape of the module inputs; every axis except the last is
        kept, and the last one is replaced by ``(num_heads, head_dim)``.
      dtype: dtype of the zero-filled key/value caches. NOTE(review): the
        cache is updated with ``lax.dynamic_update_slice`` in ``__call__``, so
        this presumably has to match the projection output dtype — confirm.

    Example usage (shown with flax's ``nnx.MultiHeadAttention``)::

      >>> from flax import nnx
      >>> import jax.numpy as jnp
      ...
      >>> rngs = nnx.Rngs(42)
      ...
      >>> x = jnp.ones((1, 3))
      >>> model_nnx = nnx.MultiHeadAttention(
      ...   num_heads=2,
      ...   in_features=3,
      ...   qkv_features=6,
      ...   out_features=6,
      ...   decode=True,
      ...   rngs=rngs,
      ... )
      ...
      >>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized
      ...
      >>> model_nnx.init_cache(x.shape)
      >>> out_nnx = model_nnx(x)
    """
    leading_axes = tuple(input_shape[:-1])
    head_axes = (self.num_heads, self.head_dim)
    cache_shape = leading_axes + head_axes
    # Separate zero buffers for keys and values; the write position starts at 0.
    self.cached_key = nnx.Cache(jnp.zeros(cache_shape, dtype))
    self.cached_value = nnx.Cache(jnp.zeros(cache_shape, dtype))
    self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))
|
1106
|
+
|
1107
|
+
|
1108
|
+
# mask-making utility functions


def make_attention_mask(
    query_input: Array,
    key_input: Array,
    pairwise_fn: Callable[..., Any] = jnp.multiply,
    extra_batch_dims: int = 0,
    dtype: Dtype = jnp.float32,
):
    """Build a pairwise attention mask from flat query/key inputs.

    For inputs of shape ``[batch..., len_q]`` and ``[batch..., len_kv]`` the
    result has shape ``[batch..., 1, len_q, len_kv]``. The singleton axis
    stands in for the heads axis, and entry ``[..., 0, i, j]`` is
    ``pairwise_fn(query_input[..., i], key_input[..., j])``.

    Args:
      query_input: batched, flat query-side input of length len_q.
      key_input: batched, flat key-side input of length len_kv.
      pairwise_fn: broadcasting elementwise comparison function.
      extra_batch_dims: number of singleton axes to prepend (0 by default).
      dtype: dtype of the returned mask.

    Returns:
      The mask, of shape ``[1]*extra_batch_dims + [batch..., 1, len_q, len_kv]``.
    """
    # [..., len_q, 1] combined with [..., 1, len_kv] -> [..., len_q, len_kv]
    rows = jnp.expand_dims(query_input, axis=-1)
    cols = jnp.expand_dims(key_input, axis=-2)
    pairwise = pairwise_fn(rows, cols)
    # Singleton heads axis in front of the (len_q, len_kv) pair.
    with_heads = pairwise[..., None, :, :]
    # Prepend `extra_batch_dims` singleton axes (none when it is <= 0).
    leading = (1,) * extra_batch_dims
    return jnp.reshape(with_heads, leading + with_heads.shape).astype(dtype)
|
1141
|
+
|
1142
|
+
|
1143
|
+
def make_causal_mask(
    x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32
) -> Array:
    """Build a causal (lower-triangular) self-attention mask.

    Only the shape of ``x`` is used. For ``x`` of shape ``[batch..., len]`` the
    mask has shape ``[batch..., 1, len, len]``, and position ``i`` may attend
    to position ``j`` iff ``i >= j``.

    Args:
      x: input array of shape ``[batch..., len]`` (values are ignored).
      extra_batch_dims: number of singleton axes to prepend (0 by default).
      dtype: dtype of the returned mask.

    Returns:
      A ``[batch..., 1, len, len]`` shaped causal mask for 1d attention.
    """
    seq_len = x.shape[-1]
    positions = jnp.broadcast_to(jnp.arange(seq_len, dtype=jnp.int32), x.shape)
    return make_attention_mask(
        query_input=positions,
        key_input=positions,
        pairwise_fn=jnp.greater_equal,
        extra_batch_dims=extra_batch_dims,
        dtype=dtype,
    )
|
1169
|
+
|
1170
|
+
|
1171
|
+
def combine_masks(
|
1172
|
+
*masks: Optional[Array], dtype: Dtype = jnp.float32
|
1173
|
+
) -> Array | None:
|
1174
|
+
"""Combine attention masks.
|
1175
|
+
|
1176
|
+
Args:
|
1177
|
+
*masks: set of attention mask arguments to combine, some can be None.
|
1178
|
+
dtype: dtype for the returned mask.
|
1179
|
+
|
1180
|
+
Returns:
|
1181
|
+
Combined mask, reduced by logical and, returns None if no masks given.
|
1182
|
+
"""
|
1183
|
+
masks_list = [m for m in masks if m is not None]
|
1184
|
+
if not masks_list:
|
1185
|
+
return None
|
1186
|
+
assert all(
|
1187
|
+
map(lambda x: x.ndim == masks_list[0].ndim, masks_list)
|
1188
|
+
), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}'
|
1189
|
+
mask, *other_masks = masks_list
|
1190
|
+
for other_mask in other_masks:
|
1191
|
+
mask = jnp.logical_and(mask, other_mask)
|
1192
|
+
return mask.astype(dtype)
|
1193
|
+
|
1194
|
+
|
1195
|
+
|
1196
|
+
def causal_attention_mask(seq_len):
    """Return a ``(seq_len, seq_len)`` lower-triangular mask of ones.

    Entry ``[i, j]`` is 1 when ``j <= i`` (position ``i`` may attend to ``j``)
    and 0 otherwise. The dtype is the default float dtype of ``jnp.ones``.
    """
    square_of_ones = jnp.ones(shape=(seq_len, seq_len))
    return jnp.tril(square_of_ones)
|
1199
|
+
|
1200
|
+
class TransformerBlock(nnx.Module):
    """A single Transformer block built from YAT layers.

    The block applies, in order:
      1. causal self-attention (``MultiHeadAttention``), dropout, residual add;
      2. one ``YatNMN`` layer mapping ``embed_dim -> embed_dim``, dropout,
         residual add.
    There is no layer normalization anywhere in the block.

    Args:
        embed_dim (int): Embedding dimensionality.
        num_heads (int): Number of attention heads.
        ff_dim (int): Accepted for API compatibility but currently unused; the
            feed-forward path is a single ``embed_dim -> embed_dim`` YatNMN layer.
        rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys.
        rate (float): Dropout rate. Defaults to 0.1.
    """
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, *, rngs: nnx.Rngs, rate: float = 0.1):
        # Sharded initializers shared by the attention and feed-forward layers.
        # `mesh` is assumed to be a module-level jax.sharding.Mesh defined
        # elsewhere in this file.
        sharded_kernel_init = nnx.with_partitioning(
            nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))
        )
        sharded_bias_init = nnx.with_partitioning(
            nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))
        )

        # Submodules are created in the original order so `rngs` is consumed
        # identically: attention first, then the YatNMN layer.
        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            in_features=embed_dim,
            kernel_init=sharded_kernel_init,
            bias_init=sharded_bias_init,
            rngs=rngs,
        )
        # NOTE(review): the Dropout layers are built without `rngs`; depending
        # on the flax version, calling the block with training=True may need an
        # rngs source for them — confirm.
        self.dropout1 = nnx.Dropout(rate=rate)
        self.nonlinear1 = YatNMN(
            in_features=embed_dim,
            out_features=embed_dim,
            kernel_init=sharded_kernel_init,
            bias_init=sharded_bias_init,
            alpha_init=nnx.with_partitioning(
                nnx.initializers.ones_init(), NamedSharding(mesh, P(None, 'model'))
            ),
            rngs=rngs,
        )
        self.dropout2 = nnx.Dropout(rate=rate)

    def __call__(self, inputs, training: bool = False):
        """Run the block on ``inputs`` of shape ``(batch, seq_len, embed_dim)``.

        ``training`` enables dropout; attention is always causally masked.
        """
        _, seq_len, _ = inputs.shape
        deterministic = not training

        # Attention sub-layer with residual connection.
        attended = self.mha(
            inputs_q=inputs,
            mask=causal_attention_mask(seq_len),
            decode=False,
        )
        hidden = inputs + self.dropout1(attended, deterministic=deterministic)

        # Feed-forward sub-layer (single YatNMN) with residual connection.
        transformed = self.dropout2(self.nonlinear1(hidden), deterministic=deterministic)
        return hidden + transformed
|
1260
|
+
|
1261
|
+
class TokenAndPositionEmbedding(nnx.Module):
    """ Combines token embeddings (words in an input sentence) with
    positional embeddings (the position of each word in a sentence).

    Args:
        maxlen (int): Maximum sequence length.
        vocab_size (int): Vocabulary size.
        embed_dim (int): Embedding dimensionality.
        rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys.
    """
    def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, *, rngs: nnx.Rngs):
        # Initialize token embeddings (using `flax.nnx.Embed`).
        # Each unique token id has an embedding vector.
        self.token_emb = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)
        # Initialize positional embeddings (using `flax.nnx.Embed`): one vector
        # per position 0 .. maxlen-1.
        self.pos_emb = nnx.Embed(num_embeddings=maxlen, features=embed_dim, rngs=rngs)

    def __call__(self, x):
        """Return token + positional embeddings for token ids ``x``.

        Args:
            x: integer token ids; axis 1 is the sequence axis
                (presumably ``(batch, seq_len)`` — the positions are built as
                ``arange(x.shape[1])[None, :]``).

        Raises:
            ValueError: if the sequence is longer than ``maxlen``. Without this
                check the out-of-range positions would be looked up past the end
                of the positional table, which ``jnp.take`` silently fills
                (NaN for float embeddings) instead of failing.
        """
        seq_len = x.shape[1]
        max_positions = self.pos_emb.num_embeddings
        if seq_len > max_positions:
            raise ValueError(
                f'Sequence length {seq_len} exceeds the maximum of '
                f'{max_positions} positions supported by this embedding.'
            )
        # One row of positions 0 .. seq_len-1, broadcast over the batch.
        positions = jnp.arange(0, seq_len)[None, :]
        position_embedding = self.pos_emb(positions)
        token_embedding = self.token_emb(x)
        return token_embedding + position_embedding
|
1288
|
+
|
1289
|
+
class MiniGPT(nnx.Module):
    """ A miniGPT transformer model, inherits from `flax.nnx.Module`.

    Args:
        maxlen (int): Maximum sequence length.
        vocab_size (int): Vocabulary size.
        embed_dim (int): Embedding dimensionality.
        num_heads (int): Number of attention heads.
        feed_forward_dim (int): Passed to each `TransformerBlock` as `ff_dim`.
        num_transformer_blocks (int): Number of transformer blocks. Each block contains attention and feed-forward networks.
        rngs (nnx.Rngs): A Flax NNX stream of JAX PRNG keys.
    """
    # Initialize miniGPT model components.
    def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, num_heads: int, feed_forward_dim: int, num_transformer_blocks: int, rngs: nnx.Rngs):
        # `TokenAndPositionEmbedding` combines token and positional embeddings.
        self.embedding_layer = TokenAndPositionEmbedding(
            maxlen, vocab_size, embed_dim, rngs=rngs
        )
        # A list of `TransformerBlock` instances applied in sequence.
        self.transformer_blocks = [TransformerBlock(
            embed_dim, num_heads, feed_forward_dim, rngs=rngs
        ) for _ in range(num_transformer_blocks)]
        # Output layer producing logits over the vocabulary for next-token
        # prediction (a YatNMN layer without bias, not `nnx.Linear`).
        self.output_layer = YatNMN(in_features=embed_dim,
                                   out_features=vocab_size,
                                   kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                   bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P(None, 'model'))),
                                   alpha_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P(None, 'model'))),
                                   use_bias=False,
                                   rngs=rngs,
                                   )

    def __call__(self, inputs, training: bool = False):
        """Map token ids to next-token logits; `training` toggles dropout."""
        x = self.embedding_layer(inputs)
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, training=training)
        # Logits over the vocabulary for every position.
        outputs = self.output_layer(x)
        return outputs

    # Text generation.
    def generate_text(self, max_tokens: int, start_tokens: list[int], top_k: int = 10, *, rng_key=None):
        """Generate up to `max_tokens` tokens after `start_tokens` with top-k sampling.

        Uses the module-level `maxlen` and `tokenizer`. At each step the model
        sees the most recent `maxlen` tokens (right-padded with token id 0) and
        samples the next token from the softmax over the `top_k` largest logits
        at the last real position. Generation stops early at `<|endoftext|>`.

        Args:
            max_tokens: maximum number of new tokens to generate.
            start_tokens: prompt token ids.
            top_k: number of highest-scoring tokens to sample from.
            rng_key: optional JAX PRNG key. Defaults to `jax.random.PRNGKey(0)`;
                a fresh subkey is split off for every sampling step (previously
                the same fixed key was reused on every step, so the draws were
                not independent).

        Returns:
            The decoded text of `start_tokens` plus the generated tokens.
        """
        if rng_key is None:
            rng_key = jax.random.PRNGKey(0)
        # Encode the stop token once instead of on every step.
        end_of_text = tokenizer.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]

        def sample_from(logits, key):
            # Keep the top-k logits and sample among them proportionally to softmax.
            top_logits, indices = jax.lax.top_k(logits, k=top_k)
            probs = nnx.softmax(top_logits)
            return jax.random.choice(key, indices, p=probs)

        def generate_step(tokens, key):
            # Keep the most recent `maxlen` tokens so newly generated tokens stay
            # in context (previously the FIRST `maxlen` tokens were kept, freezing
            # the context once the sequence grew past `maxlen`).
            window = tokens[-maxlen:]
            # Index of the last real token in the window.
            sample_index = len(window) - 1
            padded = window + [0] * (maxlen - len(window))
            # Add a batch dimension.
            x = jnp.array(padded)[None, :]
            logits = self(x)
            return sample_from(logits[0][sample_index], key)

        generated = []
        for _ in range(max_tokens):
            rng_key, step_key = jax.random.split(rng_key)
            next_token = generate_step(start_tokens + generated, step_key)
            if next_token == end_of_text:
                # Stop text generation if the end-of-text token is encountered.
                break
            generated.append(int(next_token))
        # Decode the generated token IDs into text.
        return tokenizer.decode(start_tokens + generated)
|
1376
|
+
|
1377
|
+
def create_model(rngs, num_blocks: int = 4):
    """Create the miniGPT model.

    All other hyper-parameters are read from the module-level globals
    `maxlen`, `vocab_size`, `embed_dim`, `num_heads` and `feed_forward_dim`.

    Args:
        rngs: Flax NNX rng stream used to initialize the parameters.
        num_blocks: number of transformer blocks. Defaults to 4, the value this
            function always used. Note that it does NOT read the module-level
            `num_transformer_blocks` setting; pass it explicitly to use that.

    Returns:
        A freshly initialized `MiniGPT`.
    """
    return MiniGPT(maxlen, vocab_size, embed_dim, num_heads, feed_forward_dim, num_transformer_blocks=num_blocks, rngs=rngs)
|
1380
|
+
|
1381
|
+
# ---------------------------------------------------------------------------
# Hyper-parameters: module-level globals read by the definitions in this file.
# ---------------------------------------------------------------------------
vocab_size: int = tokenizer.n_vocab  # size of the tokenizer vocabulary
# NOTE(review): `create_model` builds 4 blocks by default and does not read
# this value.
num_transformer_blocks: int = 3
maxlen: int = 1024  # maximum sequence length, in tokens
embed_dim: int = 512
num_heads: int = 8
feed_forward_dim: int = 512
batch_size: int = 64  # You can set a bigger batch size if you use Kaggle's Cloud TPU.
num_epochs: int = 5

"""## Loading and preprocessing the data

Enhanced data loading with support for multiple Hugging Face datasets and local files.
"""

# Name of the dataset configuration loaded by default.
# NOTE(review): "fineweb" is not in the option list below; confirm it is a key
# of DATASET_CONFIGS, otherwise loading falls back to the local file.
CURRENT_DATASET: str = "fineweb"  # Change this to use different datasets
# Available options: tinystories, wikitext, openwebtext, bookscorpus, c4,
# tiny_shakespeare, gutenberg, pile, common_crawl, local_file
|
1399
|
+
|
1400
|
+
@dataclass
class EnhancedTextDataset:
    """Random-access dataset that tokenizes and pads raw text samples.

    Item ``idx`` is the token-id list for ``data[idx]``, truncated to
    ``maxlen`` tokens and right-padded with token id 0 to exactly ``maxlen``.
    """
    data: list
    maxlen: int
    # Any tokenizer exposing `encode` (tiktoken-style); defaults to the
    # module-level `tokenizer`, bound when this class is defined.
    tokenizer: object = tokenizer
    separator_token: str = "<|endoftext|>"

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        text = self.data[idx]
        try:
            # tiktoken-style call: the separator is encoded as its special token.
            encoding = self.tokenizer.encode(
                text,
                allowed_special={self.separator_token}
            )
        except TypeError:
            # Tokenizer does not accept tiktoken's `allowed_special` keyword.
            encoding = self.tokenizer.encode(text)
        except ValueError:
            # tiktoken rejects text containing special tokens other than the
            # separator. The old fallback (plain `encode(text)`) rejected the
            # same text again, so encode such tokens as ordinary text instead.
            encoding = self.tokenizer.encode(text, disallowed_special=())

        encoding = encoding[:self.maxlen]
        # Pad to maxlen
        return encoding + [0] * (self.maxlen - len(encoding))
|
1425
|
+
|
1426
|
+
def _resolve_dataset_config(
    dataset_name: str, custom_config: Optional[DatasetConfig]
):
    """Pick the config to load: custom config > predefined name > local-file fallback."""
    if custom_config is not None:
        print("Using custom dataset configuration")
        return custom_config
    if dataset_name in DATASET_CONFIGS:
        print(f"Using predefined configuration for {dataset_name}")
        return DATASET_CONFIGS[dataset_name]

    print(f"Dataset '{dataset_name}' not found. Available datasets:")
    for name in list_available_datasets():
        info = get_dataset_info(name)
        print(f" - {name}: {info.get('name', 'N/A')}")

    # Fallback to local file
    print("Falling back to local file (TinyStories)")
    return DATASET_CONFIGS["local_file"]


def _load_texts(config, dataset_name: str):
    """Load raw texts for `config`, falling back to the local TinyStories file if nothing loads."""
    texts = []
    if config.name == "local" or config.file_path:
        # Load from local file
        texts = load_local_file(config)
    elif HF_DATASETS_AVAILABLE:
        # Try to load from Hugging Face
        texts = load_huggingface_dataset(config)

    # If loading failed or HF is unavailable, try the local fallback -- unless
    # the config that just failed already *is* that fallback.
    if not texts and dataset_name != "local_file":
        fallback_config = DATASET_CONFIGS["local_file"]
        if config is not fallback_config:
            print("Attempting to load from local TinyStories file as fallback...")
            texts = load_local_file(fallback_config)
    return texts


def load_and_preprocess_data_enhanced(
    dataset_name: str = CURRENT_DATASET,
    batch_size: int = batch_size,
    maxlen: int = maxlen,
    custom_config: Optional[DatasetConfig] = None,
    epochs: Optional[int] = None,
) -> pygrain.DataLoader:
    """
    Enhanced data loading function that supports multiple datasets

    Note: the defaults for `dataset_name`, `batch_size` and `maxlen` are the
    module-level values at the time this function was defined.

    Args:
        dataset_name: Name of the dataset configuration to use
        batch_size: Batch size for data loading
        maxlen: Maximum sequence length
        custom_config: Custom dataset configuration (overrides dataset_name)
        epochs: Number of epochs for the sampler; when None, the module-level
            `num_epochs` is read at call time.

    Returns:
        pygrain.DataLoader: Configured data loader

    Raises:
        RuntimeError: if no text could be loaded, including via the fallback.
    """
    print("=== Dataset Loading ===")
    print(f"Requested dataset: {dataset_name}")

    config = _resolve_dataset_config(dataset_name, custom_config)
    texts = _load_texts(config, dataset_name)
    if not texts:
        raise RuntimeError(f"Failed to load any data for dataset: {dataset_name}")

    # Create enhanced dataset
    dataset = EnhancedTextDataset(
        data=texts,
        maxlen=maxlen,
        separator_token=config.separator
    )

    print("Dataset created successfully:")
    print(f" - Total samples: {len(dataset)}")
    print(f" - Max sequence length: {maxlen}")
    print(f" - Separator token: {config.separator}")

    # Create sampler and data loader
    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=True,  # Enable shuffling for better training
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs if epochs is None else epochs,
    )

    dl = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        # drop_remainder=True: a trailing partial batch is discarded.
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )

    print(f"Data loader created with batch size: {batch_size}")
    print("=== Dataset Loading Complete ===\n")

    return dl
|
1516
|
+
|
1517
|
+
def switch_dataset(new_dataset: str) -> pygrain.DataLoader:
    """Load a different dataset and record it as the current one.

    Args:
        new_dataset: Name of the dataset to load; passed as the first positional
            argument of ``load_and_preprocess_data_enhanced``.

    Returns:
        pygrain.DataLoader: Loader for ``new_dataset``, built with the module-level
        ``batch_size`` and ``maxlen``. The module-level ``text_dl`` is NOT rebound;
        assign the returned loader yourself if the training loop should use it.

    Raises:
        Whatever ``load_and_preprocess_data_enhanced`` raises (presumably
        RuntimeError when no data could be loaded). In that case
        ``CURRENT_DATASET`` keeps its previous value.
    """
    global CURRENT_DATASET
    # Load first: updating the global before the load would leave CURRENT_DATASET
    # naming a dataset that was never successfully loaded if the load raises.
    loader = load_and_preprocess_data_enhanced(new_dataset, batch_size, maxlen)
    CURRENT_DATASET = new_dataset
    return loader
|
1530
|
+
|
1531
|
+
def create_custom_dataset(
    name: str,
    subset: Optional[str] = None,
    text_column: str = "text",
    separator: str = "<|endoftext|>",
    streaming: bool = False,
    min_length: int = 10,
) -> pygrain.DataLoader:
    """Build a data loader for an arbitrary Hugging Face dataset.

    The arguments are packed into a ``DatasetConfig``, which is handed to
    ``load_and_preprocess_data_enhanced`` under the dataset name ``"custom"``.

    Args:
        name: Hugging Face dataset identifier, e.g. ``"username/dataset_name"``.
        subset: Optional subset/configuration name of the dataset.
        text_column: Column holding the raw text.
        separator: Token inserted between documents.
        streaming: Whether the dataset should be streamed (useful for large corpora).
        min_length: Minimum text length for a sample to be kept.

    Returns:
        pygrain.DataLoader: Loader produced by ``load_and_preprocess_data_enhanced``.
    """
    custom_config = DatasetConfig(
        name=name,
        subset=subset,
        text_column=text_column,
        separator=separator,
        streaming=streaming,
        min_length=min_length,
    )
    loader = load_and_preprocess_data_enhanced(
        dataset_name="custom",
        custom_config=custom_config,
    )
    return loader
|
1566
|
+
|
1567
|
+
# Build the loader for the default dataset up front; the training loop iterates `text_dl`.
print("Loading default dataset...")
text_dl = load_and_preprocess_data_enhanced()

# Report which predefined datasets exist and which one is active.
_dataset_choices = list_available_datasets()
print(f"Available predefined datasets: {_dataset_choices}")
print(f"Current dataset: {CURRENT_DATASET}")

"""## Defining the loss function and training step function"""
|
1576
|
+
|
1577
|
+
# Defines the loss function using `optax.softmax_cross_entropy_with_integer_labels`.
def loss_fn(model, batch):
    """Return ``(mean_loss, logits)`` for an ``(inputs, targets)`` batch.

    The logits are returned as auxiliary output so that
    ``nnx.value_and_grad(loss_fn, has_aux=True)`` can pass them on to the metrics.
    """
    inputs, targets = batch[0], batch[1]
    logits = model(inputs)
    # Cross-entropy against integer token ids, averaged over every element.
    per_position = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=targets)
    return per_position.mean(), logits
|
1582
|
+
|
1583
|
+
# Define the training step with the `flax.nnx.jit` transformation decorator.
@nnx.jit
def train_step(model: MiniGPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    """Run one optimization step on ``batch`` and accumulate metrics.

    Args:
        model: The MiniGPT model being trained.
        optimizer: ``nnx.Optimizer`` wrapping ``model``; ``optimizer.update`` applies
            the gradients to the model's parameters.
        metrics: ``nnx.MultiMetric`` accumulator, fed ``loss``, ``logits`` and ``labels``.
        batch: ``(inputs, targets)`` pair as consumed by ``loss_fn``.
    """
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    # The keyword must be `labels` (previously misspelled `lables`). With only an
    # Average('loss') metric the misspelled key was silently ignored, but metrics
    # that read labels (e.g. nnx.metrics.Accuracy) would fail or see nothing.
    metrics.update(loss=loss, logits=logits, labels=batch[1])
    optimizer.update(grads)
|
1590
|
+
|
1591
|
+
# Instantiate the model with a fixed RNG seed (create_model is defined earlier in the file).
model = create_model(rngs=nnx.Rngs(0))
# Adam with a constant learning rate of 1e-3, wrapping `model`'s parameters.
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
# Only the running mean of the loss is tracked; train_step feeds it via metrics.update(loss=...).
metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average('loss'),
)
# NOTE(review): `rng` is not referenced by any of the code that follows in this
# file; confirm nothing earlier depends on this global before removing it.
rng = jax.random.PRNGKey(0)

# Baseline sample from the untrained model, for comparison with later samples.
start_prompt = "Once upon a time, "
# Truncate the encoded prompt so it never exceeds `maxlen` tokens.
start_tokens = tokenizer.encode(start_prompt)[:maxlen]
generated_text = model.generate_text(
    maxlen, start_tokens
)
print(f"Initial generated text:\n{generated_text}\n")
|
1604
|
+
|
1605
|
+
|
1606
|
+
# Metric values appended by the training loop at every logging interval.
metrics_history = {'train_loss': []}


def _shift_targets(tokens):
    """Shift one token sequence left by one position and pad the end with id 0."""
    return jnp.concatenate([tokens[1:], jnp.array([0])])


# Batched version: applies _shift_targets independently to every row (axis 0).
prep_target_batch = jax.vmap(_shift_targets)
|
1611
|
+
|
1612
|
+
# Global step counter; it is never reset, so it keeps counting across epochs.
step = 0
# NOTE(review): the sampler behind `text_dl` is built with `num_epochs=num_epochs`,
# so one pass over `text_dl` presumably already covers every epoch. This outer loop
# would then repeat the whole schedule num_epochs times; confirm intent.
for epoch in range(num_epochs):
    start_time = time.time()
    for batch in text_dl:
        # Skip batches whose leading dimension cannot be split evenly across devices.
        # NOTE(review): len(batch) is taken *before* the transpose below, i.e. on the
        # axis that ends up unsharded; confirm which dimension was meant.
        if len(batch) % len(jax.devices()) != 0:
            continue  # skip the remaining elements
        # Transpose the collated batch. The result is sharded on axis 0 below, so it
        # is assumed to be (batch, seq); TODO confirm against the loader's layout.
        input_batch = jnp.array(jnp.array(batch).T)
        # Next-token targets: each row shifted left by one and padded with id 0.
        target_batch = prep_target_batch(input_batch)
        # Shard axis 0 over the mesh's 'batch' axis (axis 1 unsharded), then take one step.
        train_step(model, optimizer, metrics, jax.device_put((input_batch, target_batch), NamedSharding(mesh, P('batch', None))))

        # Every 200 steps: record the averaged metrics, reset them, log, and sample text.
        if (step + 1) % 200 == 0:
            for metric, value in metrics.compute().items():
                metrics_history[f'train_{metric}'].append(value)
            metrics.reset()

            elapsed_time = time.time() - start_time
            print(f"Step {step + 1}, Loss: {metrics_history['train_loss'][-1]}, Elapsed Time: {elapsed_time:.2f} seconds")
            # NOTE(review): the timer restarts *before* the sampling below, so the
            # sampling time is counted in the next interval's reported elapsed time.
            start_time = time.time()

            generated_text = model.generate_text(
                maxlen, start_tokens
            )
            print(f"Generated text:\n{generated_text}\n")
        step += 1
|
1636
|
+
|
1637
|
+
# Final text generation: sample once more from the trained model using the same prompt.
generated_text = model.generate_text(maxlen, start_tokens)
print(f"Final generated text:\n{generated_text}")

"""Visualize the training loss."""
|
1644
|
+
|
1645
|
+
import matplotlib.pyplot as plt

# Each entry of metrics_history['train_loss'] is the mean loss over one logging
# interval, appended when (step + 1) % 200 == 0. Plot each value against the step
# at which it was recorded rather than its list index, so the 'Step' axis is
# truthful. Keep `log_every` in sync with the interval used in the training loop.
log_every = 200
train_losses = metrics_history['train_loss']
logged_steps = [log_every * (i + 1) for i in range(len(train_losses))]
plt.plot(logged_steps, train_losses)
plt.title('Training Loss')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.show()
|