BwETAF 0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- BwETAF-0.1/BwETAF/__init__.py +47 -0
- BwETAF-0.1/BwETAF/_errors.py +3 -0
- BwETAF-0.1/BwETAF/_utils.py +75 -0
- BwETAF-0.1/BwETAF/independent.py +204 -0
- BwETAF-0.1/BwETAF/layers.py +81 -0
- BwETAF-0.1/BwETAF/model.py +186 -0
- BwETAF-0.1/BwETAF.egg-info/PKG-INFO +7 -0
- BwETAF-0.1/BwETAF.egg-info/SOURCES.txt +12 -0
- BwETAF-0.1/BwETAF.egg-info/dependency_links.txt +1 -0
- BwETAF-0.1/BwETAF.egg-info/requires.txt +5 -0
- BwETAF-0.1/BwETAF.egg-info/top_level.txt +1 -0
- BwETAF-0.1/PKG-INFO +7 -0
- BwETAF-0.1/setup.cfg +4 -0
- BwETAF-0.1/setup.py +20 -0
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import flax.serialization
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
from huggingface_hub import hf_hub_download, create_repo, upload_file, login
|
|
5
|
+
import os
|
|
6
|
+
from .independent import debug_state
|
|
7
|
+
from ._utils import convert_tree
|
|
8
|
+
from .model import *
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@debug_state.trace_func
def load_model(path, dtype=None):
    """Rebuild a ``ModelManager`` from a directory on disk.

    Args:
        path: Directory containing ``understanding_good_stuff.json`` (the
            constructor hyper-parameters) and ``good_stuff.pkl`` (the
            serialized parameters).
        dtype: Optional dtype every restored parameter is cast to. When
            ``None`` the restored parameters are used as deserialized.

    Returns:
        A ``ModelManager`` whose ``params`` hold the restored weights.
    """
    with open(os.path.join(path, "understanding_good_stuff.json"), "r") as f:
        data = json.load(f)

    model = ModelManager(
        data["num_heads"],
        data["attention_dim"],
        data["vocab_size"],
        data["num_blocks"],
        data["ff_dim"],
        data["dropout_rate"],
        dtype=dtype,
    )

    with open(os.path.join(path, "good_stuff.pkl"), "rb") as f:
        # The freshly initialised params serve as the structural template.
        params = flax.serialization.from_bytes(model.params, f.read())

    # Cast only on request, like ModelManager.__init__ and Optimizer.load do:
    # astype(None) would fall back to a default float dtype instead of
    # keeping whatever was stored.
    if dtype is not None:
        params = convert_tree(dtype, params)

    model.params = params
    return model
|
|
20
|
+
|
|
21
|
+
@debug_state.trace_func
def load_hf(path, dtype=None):
    """Download a model from the Hugging Face Hub and load it.

    Each known file is fetched into the local ``Loaded_model`` directory and
    the result is handed to ``load_model``.

    Args:
        path: Hub repository id to download from.
        dtype: Forwarded unchanged to ``load_model``.

    Returns:
        Whatever ``load_model("Loaded_model", dtype)`` returns.
    """
    model_repo = path
    filenames = ["understanding_good_stuff.json", "good_stuff.pkl", "make_stuff_better.pkl"]
    for i in filenames:
        try:
            print(hf_hub_download(repo_id=model_repo, filename=i, local_dir="Loaded_model"))
        except Exception:
            # Best effort: a missing file is reported and skipped. Catching
            # Exception rather than everything lets Ctrl-C / SystemExit
            # still abort the download loop.
            print(f"No {i} found")
    return load_model("Loaded_model", dtype)
|
|
31
|
+
|
|
32
|
+
@debug_state.trace_func
def push_model(repo_name, folder_path):
    """Upload the saved model files in ``folder_path`` to a Hub model repo.

    Args:
        repo_name: Repository id; created when it does not exist yet.
        folder_path: Local directory that may contain ``good_stuff.pkl``,
            ``understanding_good_stuff.json`` and ``make_stuff_better.pkl``.
            Files that are absent are skipped.
    """
    files_to_upload = ["good_stuff.pkl", "understanding_good_stuff.json", "make_stuff_better.pkl"]

    create_repo(repo_name, exist_ok=True)  # Create repo if it doesn't exist

    uploaded = []
    for file_name in files_to_upload:
        file_path = os.path.join(folder_path, file_name)
        if not os.path.isfile(file_path):  # Only upload if the file exists
            continue
        upload_file(
            path_or_fileobj=file_path,
            path_in_repo=file_name,  # Save with the same filename
            repo_id=repo_name,
            repo_type="model",
        )
        uploaded.append(file_name)

    # Report what was actually sent, not the full candidate list.
    print(f"Uploaded {uploaded} to {repo_name}")
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import optax
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
import numpy as np
|
|
5
|
+
from ._errors import *
|
|
6
|
+
import time
|
|
7
|
+
from functools import partial
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def time_it(fn, *args, **kwargs):
    """Call ``fn`` and measure how long the call took.

    Args:
        fn: Callable to invoke.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        A ``(result, seconds)`` tuple, where ``seconds`` is a float.
    """
    # perf_counter is monotonic, so the interval cannot go negative if the
    # system clock is adjusted mid-call.
    # NOTE(review): if fn dispatches work asynchronously (as JAX may), this
    # only covers the time until fn returns — confirm callers are fine with that.
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return out, elapsed
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def loss_fn(params, model, batch, rng):
    """Mean next-token negative log-likelihood for one batch.

    Args:
        params: Parameters passed to ``model.apply``.
        model: Object exposing ``apply(params, inputs, mask, rngs=...)``.
        batch: Sequence ``(inputs, mask, targets)``; ``targets`` is indexed
            as ``[:, 1:]`` and the model output as ``[:, :-1, :]``.
        rng: PRNG key supplied to the model under the ``"dropout"`` name.

    Returns:
        Scalar loss.
    """
    tokens, attn_mask, labels = batch

    # Forward pass, normalised to log-probabilities over the last axis.
    log_probs = jax.nn.log_softmax(
        model.apply(params, tokens, attn_mask, rngs={"dropout": rng})
    )

    # Autoregressive shift: the output at position t is scored against the
    # label at position t + 1, so drop the first label and the last output.
    next_labels = labels[:, 1:]
    scored_log_probs = log_probs[:, :-1, :]

    # Pick out the log-probability assigned to each correct next token.
    picked = jnp.take_along_axis(scored_log_probs, next_labels[..., None], axis=-1)

    # NOTE(review): attn_mask is only forwarded to the model; every position
    # contributes equally to this mean — confirm that is intended.
    return -picked[..., 0].mean()
|
|
34
|
+
|
|
35
|
+
def val_loss(params, loss_fn, model_struct, x, mask, y, key):
    """Evaluate ``loss_fn`` on a single ``[x, mask, y]`` batch.

    After the two rebindings below, the public ``val_loss`` is the
    jit-compiled function mapped over devices: ``x``, ``mask`` and ``y`` are
    split along their leading axis, everything else is broadcast.
    """
    batch = [x, mask, y]
    return loss_fn(params, model_struct, batch, key)


# loss_fn and model_struct (positions 1 and 2) are static for compilation.
val_loss = jax.jit(val_loss, static_argnums=(1, 2))

# Map over devices: only the three data arrays carry a device axis.
val_loss = jax.pmap(
    val_loss,
    axis_name='batch',
    in_axes=(None, None, None, 0, 0, 0, None),
    static_broadcasted_argnums=(1, 2),
)
|
|
50
|
+
def BatchTrain(params, grad_fn, model_struct, x, mask, y, key, optimizer, opt_state):
    """Run one data-parallel optimisation step.

    The published stub returned ``None``, which cannot be unpacked into the
    ``(loss, params, opt_state)`` triple that ``ModelManager.train_batch``
    expects. This is a conventional implementation matching that contract
    and the ``pmap`` configuration below.

    Args:
        params: Current parameters (broadcast to every device).
        grad_fn: Callable ``(params, model_struct, batch, key) -> (loss, grads)``,
            e.g. ``jax.value_and_grad(loss_fn)``.
        model_struct: Model definition handed through to ``grad_fn``.
        x, mask, y: This device's shard of the batch.
        key: PRNG key handed through to ``grad_fn``.
        optimizer: Object with ``update(grads, opt_state, params)`` returning
            ``(updates, new_opt_state)``.
        opt_state: Current optimizer state (broadcast to every device).

    Returns:
        ``(loss, new_params, new_opt_state)``, identical on every device.
    """
    loss, grads = grad_fn(params, model_struct, [x, mask, y], key)

    # Average across the "batch" device axis so every device applies the same
    # update; out_axes=None below requires replicated outputs.
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")

    updates, opt_state = optimizer.update(grads, opt_state, params)

    # Add the update while keeping each parameter's dtype
    # (presumably what optax.apply_updates does — verify before swapping it in).
    params = jax.tree_util.tree_map(
        lambda p, u: (p + u).astype(p.dtype), params, updates
    )
    return loss, params, opt_state


# grad_fn, model_struct and optimizer (positions 1, 2, 7) are static.
BatchTrain = jax.jit(
    BatchTrain,
    static_argnums=(1, 2, 7)
)

# Only x, mask and y are split across devices; outputs are unmapped.
BatchTrain = jax.pmap(
    BatchTrain,
    static_broadcasted_argnums=(1, 2, 7),
    in_axes=(None, None, None, 0, 0, 0, None, None, None),
    axis_name="batch",
    out_axes=None
)
|
|
66
|
+
|
|
67
|
+
def get_first(pytree):
    """Return a pytree of the same structure holding ``leaf[0]`` for every leaf."""
    def _head(leaf):
        return leaf[0]

    return jax.tree_util.tree_map(_head, pytree)
|
|
69
|
+
|
|
70
|
+
def convert_tree(dtype, pytree):
    """Cast every leaf of ``pytree`` to ``dtype`` through its ``astype`` method.

    Returns a new pytree with the same structure; every leaf is cast,
    whatever its original dtype.
    """
    def _cast(leaf):
        return leaf.astype(dtype)

    return jax.tree_util.tree_map(_cast, pytree)
|
|
72
|
+
|
|
73
|
+
@partial(jax.jit, static_argnums=(0,))
def call_model_jit(model_struct, params, x, mask, rngs):
    """Jit-compiled inference call: ``model_struct.apply`` with ``training=False``.

    ``model_struct`` is a static argument for jit (position 0); the
    remaining arguments are traced.
    """
    outputs = model_struct.apply(params, x, mask, training=False, rngs=rngs)
    return outputs
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
3
|
+
import traceback
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
from ._errors import *
|
|
7
|
+
import optax
|
|
8
|
+
from ._utils import *
|
|
9
|
+
import os
|
|
10
|
+
import flax
|
|
11
|
+
import time
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Debugger():
    """Opt-in call tracer that logs to stdout or to an append-mode file.

    Nothing is emitted until ``debug`` is true. When ``path`` is ``None``
    messages are printed; otherwise they are appended to that file.
    """

    def __init__(self, debug=False, path=None) -> None:
        self.debug = debug
        self.path = path
        # Always defined, so later code can test it instead of crashing.
        self.logfile = None
        if path is not None:
            self.logfile = open(path, "a")

    def turn_debugger_on(self):
        """Enable logging and (re)open the log file when a path is configured."""
        self.debug = True
        if self.path is None:
            # Console mode: there is no file to open.
            return
        if self.logfile is not None:
            # Avoid leaking the previous handle when re-opening.
            self.logfile.close()
        self.logfile = open(self.path, "a")

    def logger(self, message, state="DEBUG"):
        """Emit ``[state] message`` if debugging is enabled."""
        if not self.debug:
            return
        if self.path is None:
            print(f"[{state}] {message}")
            return
        if self.logfile is None:
            # path was assigned after construction; open lazily.
            self.logfile = open(self.path, "a")
        self.logfile.write(f"[{state}] {message}\n")
        self.logfile.flush()

    def trace_func(self, func):
        """Decorator logging argument types, result info, timing and errors.

        Exceptions raised by ``func`` are logged with their traceback and
        then re-raised unchanged.
        """
        import functools

        @functools.wraps(func)  # keep __name__/__doc__ of the wrapped callable
        def wrapper(*args, **kwargs):
            arg_types = [type(i) for i in args]
            # Report the types of the keyword VALUES, keyed by name.
            kwarg_types = {k: type(v) for k, v in kwargs.items()}
            self.logger(f"Calling {func.__name__} with args: {arg_types}, kwargs: {kwarg_types}", state="FUNC_CALL")
            start_time = time.perf_counter()
            try:
                out = func(*args, **kwargs)
            except Exception as e:
                self.logger(f"{func.__name__} Exception:{e}", state="ERROR")
                self.logger(traceback.format_exc(), state="TRACEBACK")
                raise
            # Qualify constructors with their class so the log is readable.
            name = f"{args[0].__class__.__name__}.__init__" if func.__name__ == "__init__" else func.__name__
            self.logger(f"{name} returned: {type(out)} with shape/info: {getattr(out, 'shape', 'uk')}, {getattr(out, 'dtype', 'uk')}, Time taken:{time.perf_counter()-start_time}s", state="FUNC_RETURN")
            return out
        return wrapper


# Shared tracer instance used by the decorators throughout the package.
debug_state = Debugger()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class Tokenization():
    """Thin wrapper around a ``tiktoken`` encoding that produces padded batches."""

    @debug_state.trace_func
    def __init__(self, vocab="gpt2") -> None:
        # Imported lazily so the package can be imported without tiktoken.
        import tiktoken
        self.stuff = tiktoken
        self.vocab = vocab
        self.enc = tiktoken.get_encoding(self.vocab)

    @debug_state.trace_func
    def tokenize(self, batch: list, workers: int, max_length: int):
        """Encode ``batch`` into fixed-width token and mask arrays.

        Args:
            batch: Texts to encode.
            workers: Thread count for the encoding pool.
            max_length: Every row is truncated/padded to this width.

        Returns:
            ``(tokens, mask)`` NumPy int32 arrays of shape
            ``(len(batch), max_length)``; ``mask`` is 1 on real tokens and
            0 on padding (pad id 0).
        """
        # Re-resolve the encoding so a changed ``self.vocab`` takes effect.
        self.enc = self.stuff.get_encoding(self.vocab)
        enc = self.enc
        pad_token = 0

        def encode_and_pad(text):
            tokens = enc.encode(text, allowed_special={'<|endoftext|>'})[:max_length]
            padded = np.full(max_length, pad_token, dtype=np.int32)
            mask = np.zeros(max_length, dtype=np.int32)
            padded[:len(tokens)] = tokens
            mask[:len(tokens)] = 1  # Mark real tokens as 1
            return padded, mask

        with ThreadPoolExecutor(max_workers=workers) as executor:
            results = list(executor.map(encode_and_pad, batch))

        if not results:
            # zip(*[]) cannot be unpacked; return correctly shaped empties.
            empty = np.zeros((0, max_length), dtype=np.int32)
            return empty, empty.copy()

        encoded_batch, mask = zip(*results)  # Split tokens and masks
        return np.array(encoded_batch), np.array(mask)

    @debug_state.trace_func
    def tokenize_(self, batch: list):
        """Encode ``batch`` without truncation.

        Rows are right-padded with id 0 (mask 0) up to the longest row, so
        texts of different token lengths still form one rectangular array.
        When all rows already have the same length nothing is padded.

        Returns:
            ``(tokens, mask)`` as int32 ``jnp`` arrays.
        """
        self.enc = self.stuff.get_encoding(self.vocab)
        enc = self.enc

        encoded = [enc.encode(text, allowed_special={'<|endoftext|>'}) for text in batch]
        width = max((len(tokens) for tokens in encoded), default=0)

        encoded_batch = np.zeros((len(encoded), width), dtype=np.int32)
        mask = np.zeros((len(encoded), width), dtype=np.int32)
        for row, tokens in enumerate(encoded):
            encoded_batch[row, :len(tokens)] = tokens
            mask[row, :len(tokens)] = 1  # Mask matches token length

        return jnp.array(encoded_batch), jnp.array(mask)

    @debug_state.trace_func
    def decode(self, tokens):
        """Decode a token sequence back to text with the current encoding."""
        return self.enc.decode(tokens)
|
|
102
|
+
|
|
103
|
+
class Flax_ds():
    """In-memory dataset that slices data into per-device batches.

    With ``x_eq_y`` true the inputs double as targets and no separate ``y``
    is stored.
    """

    @debug_state.trace_func
    def __init__(self, x_eq_y: bool) -> None:
        self.x_eq_y = x_eq_y
        self.x = None
        self.mask = None
        self.y = None
        self.batch = None

    @debug_state.trace_func
    def load_data(self, x, mask, y):
        """Store the raw arrays; ``y`` is ignored when ``x_eq_y`` is set."""
        self.x = np.array(x)
        self.mask = mask
        if not self.x_eq_y:
            self.y = np.array(y)

    @debug_state.trace_func
    def batch_it_(self, batch_size):
        """Split the loaded data into ``len(x) // batch_size`` batches.

        Each batch entry is reshaped to ``(num_devices, -1, seq_len)``.
        Trailing samples that do not fill a whole batch are dropped. The raw
        ``x``/``mask`` (and ``y``) attributes are deleted afterwards.

        Returns:
            The list of batches, each ``[x, mask]`` or ``[x, mask, y]``.
        """
        self.x = jnp.array(self.x)
        self.mask = jnp.array(self.mask)
        arrays = [self.x, self.mask]
        if not self.x_eq_y:
            self.y = jnp.array(self.y)
            arrays.append(self.y)

        seq_len = self.x.shape[1]
        n_batches = len(self.x) // batch_size
        num_devices = jax.device_count()

        # The reshape needs batch_size to be divisible by num_devices.
        self.batch = [
            [
                arr[i * batch_size:(i + 1) * batch_size].reshape(num_devices, -1, seq_len)
                for arr in arrays
            ]
            for i in range(n_batches)
        ]

        # Drop the un-batched copies to release memory.
        del self.x, self.mask
        if not self.x_eq_y:
            del self.y
        return self.batch

    def __len__(self):
        return len(self.batch)

    def stream_it(self):
        """Yield ``(x, mask, y)`` per batch; ``y`` is ``x`` when ``x_eq_y``.

        Raises:
            IncorrectDtype: On first iteration, if ``batch_it_`` has not
                been run yet.
        """
        if self.batch is None:
            # The exception used to be constructed but never raised.
            raise IncorrectDtype("Bruh... You forgot to run '.batch_it' before trying to stream it.... T~T")
        if self.x_eq_y:
            for i in self.batch:
                yield i[0], i[1], i[0]
        else:
            for i in self.batch:
                yield i[0], i[1], i[2]

    @property
    def gimme_the_data(self):
        """The batch list produced by ``batch_it_`` (``None`` before that)."""
        return self.batch
|
|
174
|
+
|
|
175
|
+
class Optimizer():
    """Bundles an optax optimizer, its decaying learning rate and its state."""

    @debug_state.trace_func
    def __init__(self, optimizer, lr, lrf, batches, epochs, params):
        """Build the schedule, the optimizer and its initial state.

        Args:
            optimizer: Factory called with the learning-rate schedule.
            lr: Initial learning rate.
            lrf: Learning rate reached after ``batches * epochs`` steps.
            batches: Number of batches per epoch.
            epochs: Number of epochs.
            params: Parameters the optimizer state is initialised for.
        """
        # Guard against a zero step count, which would divide by zero.
        total_steps = max(batches * epochs, 1)
        # Per-step factor so that lr * decay_rate ** total_steps == lrf.
        decay_rate = (lrf / lr) ** (1 / total_steps)
        self.lr_schedule = optax.exponential_decay(
            init_value=lr,
            transition_steps=1,
            decay_rate=decay_rate,
            staircase=False  # Smooth decay
        )
        self.optimizer = optimizer(self.lr_schedule)
        self.state = self.optimizer.init(params)

    @debug_state.trace_func
    def load(self, path, dtype=None):
        """Restore the state from ``<path>/make_stuff_better.pkl`` if possible.

        Best effort: on any failure the current state is kept untouched and
        a notice is printed; the reason goes to the debug log.

        Args:
            path: Directory that may contain the saved state.
            dtype: Optional dtype every restored leaf is cast to.
        """
        try:
            with open(os.path.join(path, "make_stuff_better.pkl"), "rb") as f:
                restored = flax.serialization.from_bytes(self.state, f.read())
            if dtype is not None:
                # NOTE(review): convert_tree casts every leaf; if the state
                # holds integer counters they are cast too — confirm intended.
                restored = convert_tree(dtype, restored)
        except Exception as err:
            # Exception (not a bare except) so Ctrl-C still propagates.
            debug_state.logger(f"Optimizer state not restored: {err!r}", state="ERROR")
            print("No optimizers states found")
        else:
            # Commit only after both restore and cast succeeded.
            self.state = restored
            print("Using loaded optimizer states")

    @debug_state.trace_func
    def save(self, path):
        """Serialize the state to ``<path>/make_stuff_better.pkl``."""
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "make_stuff_better.pkl"), "wb") as f:
            f.write(flax.serialization.to_bytes(self.state))
|
|
204
|
+
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import flax.linen as nn
|
|
4
|
+
|
|
5
|
+
class PosEnc(nn.Module):
    """Adds fixed sinusoidal position encodings to a 3-D input."""
    # Size of the last axis of the encoding table.
    dim: int
    # dtype the frequencies and the table are cast to.
    dtype: jnp.dtype = jnp.bfloat16

    @nn.compact
    def __call__(self, x):
        _, seq_len, _ = x.shape

        # Inverse frequencies for the even channel indices.
        # NOTE(review): they are cast to self.dtype before being multiplied
        # by the positions, so the angles are presumably low precision with
        # bfloat16. Changing this would alter the encodings existing
        # checkpoints were trained with — confirm first.
        exponents = jnp.arange(0, self.dim, 2) * (jnp.log(10000.0) / self.dim)
        inv_freq = jnp.exp(-exponents).astype(self.dtype)

        # One angle per (position, frequency) pair via broadcasting.
        angles = jnp.arange(seq_len)[:, None] * inv_freq

        # Even columns hold the sines, odd columns the cosines.
        table = jnp.zeros((seq_len, self.dim), dtype=self.dtype)
        table = table.at[:, 0::2].set(jnp.sin(angles)).astype(self.dtype)
        table = table.at[:, 1::2].set(jnp.cos(angles)).astype(self.dtype)

        # Leading axis of size 1 broadcasts the table over the batch.
        return x + table[None, :, :]
|
|
26
|
+
|
|
27
|
+
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection."""
    num_heads: int
    # Model width; must be divisible by num_heads.
    d_model: int
    dtype: jnp.dtype = jnp.bfloat16

    def setup(self):
        assert self.d_model % self.num_heads == 0, "d_model must be divisible by num_heads"
        self.depth = self.d_model // self.num_heads  # width of a single head
        init = nn.initializers.normal(stddev=0.02)
        self.qkv_dense = nn.Dense(features=3 * self.d_model, kernel_init=init, dtype=self.dtype)
        self.out_dense = nn.Dense(features=self.d_model, kernel_init=init, dtype=self.dtype)

    def __call__(self, x, mask):
        """Attend over ``x`` of shape (batch, seq_len, d_model).

        ``mask`` may be ``None``; otherwise an axis is inserted at position 1
        and entries that are falsy are replaced by -1e9 before the softmax.
        """
        batch_size, seq_len, _ = x.shape

        # One projection, then split into (3, batch, heads, seq_len, depth).
        fused = self.qkv_dense(x).reshape(batch_size, seq_len, 3, self.num_heads, self.depth)
        Q, K, V = jnp.transpose(fused, (2, 0, 3, 1, 4))

        scores = jnp.einsum("bhqd,bhkd->bhqk", Q, K) / jnp.sqrt(self.depth)

        if mask is not None:
            # Extra axis lets one mask broadcast over all heads.
            scores = jnp.where(mask[:, None, :], scores, -1e9)

        weights = jax.nn.softmax(scores, axis=-1)
        context = jnp.einsum("bhqk,bhkd->bhqd", weights, V)

        # (batch, heads, seq, depth) -> (batch, seq, heads * depth)
        merged = jnp.transpose(context, (0, 2, 1, 3)).reshape(batch_size, seq_len, self.d_model)
        return self.out_dense(merged)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention then feed-forward, each with a residual."""
    num_heads: int
    attention_dim: int
    # Hidden width of the feed-forward layer.
    ff_dim: int
    dropout_rate: float
    dtype: jnp.dtype = jnp.bfloat16

    @nn.compact
    def __call__(self, x_inp, mask, train: bool):
        # --- attention sub-layer (norm -> attention -> dropout -> residual) ---
        attn = nn.LayerNorm(dtype=self.dtype)(x_inp)
        attn = Attention(self.num_heads, self.attention_dim, dtype=self.dtype)(attn, mask)
        attn = nn.Dropout(self.dropout_rate)(attn, deterministic=not train)
        residual = attn + x_inp

        # --- feed-forward sub-layer (norm -> dense -> gelu -> dense -> residual) ---
        hidden = nn.LayerNorm(dtype=self.dtype)(residual)
        hidden = nn.Dense(
            self.ff_dim,
            kernel_init=nn.initializers.normal(stddev=0.02),
            dtype=self.dtype,
        )(hidden)
        hidden = nn.gelu(hidden)
        hidden = nn.Dense(
            self.attention_dim,
            kernel_init=nn.initializers.normal(stddev=0.02),
            dtype=self.dtype,
        )(hidden)
        return hidden + residual
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import flax.linen as nn
|
|
4
|
+
from .layers import *
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import flax.serialization
|
|
8
|
+
import numpy as np
|
|
9
|
+
from ._errors import *
|
|
10
|
+
import matplotlib.pyplot as plt
|
|
11
|
+
import numpy as np
|
|
12
|
+
from tqdm import tqdm
|
|
13
|
+
from ._utils import *
|
|
14
|
+
from .independent import *
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
rng = jax.random.PRNGKey(0)
|
|
18
|
+
|
|
19
|
+
class Model(nn.Module):
    """Decoder-only transformer whose output projection reuses the embedding matrix."""
    num_heads: int
    attention_dim: int
    vocab_size: int
    num_blocks: int
    ff_dim: int
    dropout_rate: float
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.emb = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.attention_dim,
            embedding_init=nn.initializers.normal(stddev=0.02),
            dtype=self.dtype,
        )
        # NOTE(review): no dtype is passed, so PosEnc uses its own default
        # rather than self.dtype — confirm this is intended.
        self.pos_enc = PosEnc(self.attention_dim)
        self.blocks = [
            Block(
                num_heads=self.num_heads,
                attention_dim=self.attention_dim,
                ff_dim=self.ff_dim,
                dropout_rate=self.dropout_rate,
                dtype=self.dtype,
            )
            for _ in range(self.num_blocks)
        ]

    def __call__(self, x, mask, training=True):
        """Return logits for token ids ``x`` given a 2-D padding ``mask``."""
        attn_mask = self.process_mask(mask)
        hidden = self.emb(x.astype(jnp.int32))
        hidden = self.pos_enc(hidden)
        for block in self.blocks:
            hidden = block(hidden, attn_mask, training)
        # Project back to the vocabulary with the transposed embedding table.
        return hidden @ self.emb.embedding.T

    def process_mask(self, mask):
        """Combine a (batch, seq_len) padding mask with a causal mask.

        The result has shape (batch, seq_len, seq_len) and equals
        ``causal[i, j] * mask[b, j] * mask[b, i]``.
        """
        _, seq_len = mask.shape

        # Lower-triangular matrix: position i may look at positions j <= i.
        causal = jnp.tril(jnp.ones((seq_len, seq_len)))

        padding = mask[:, None, :]  # (batch, 1, seq_len)

        # First product masks padded columns ...
        combined = causal[None, :, :] * padding
        # ... transpose, multiply and transpose back to mask padded rows.
        combined = jnp.transpose(combined, (0, 2, 1)) * padding
        return jnp.transpose(combined, (0, 2, 1))
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class ModelManager():
    """Owns a ``Model`` definition, its parameters, PRNG key and optimizer."""

    @debug_state.trace_func
    def __init__(self, num_heads, attention_dim, vocab_size, num_blocks, ff_dim, dropout_rate, dtype=None) -> None:
        """Build the model and initialise its parameters.

        Args:
            num_heads, attention_dim, vocab_size, num_blocks, ff_dim,
            dropout_rate: Forwarded to ``Model`` and stored in ``self.data``.
            dtype: Forwarded to ``Model``; when not ``None`` the initial
                parameters are also cast to it.
        """
        self.key = jax.random.PRNGKey(0)
        self.model_struct = Model(num_heads, attention_dim, vocab_size, num_blocks, ff_dim, dropout_rate, dtype)
        # Dummy (2, 11) input and all-ones mask, used only to create the params.
        self.params = self.model_struct.init(self.key, jax.random.normal(self.key, (2, 11)), jnp.ones((2, 11)))
        if dtype is not None:
            self.params = convert_tree(dtype, self.params)

        self.optimizer = None

        # Hyper-parameters written out by save_model.
        self.data = {
            "num_heads": num_heads,
            "attention_dim": attention_dim,
            "vocab_size": vocab_size,
            "num_blocks": num_blocks,
            "ff_dim": ff_dim,
            "dropout_rate": dropout_rate
        }

    def __call__(self, input, mask):
        """Run the model with ``training=False`` (draws a fresh dropout key)."""
        return self.model_struct.apply(self.params, input, mask, rngs={"dropout": self.key_bruh}, training=False)

    def jax_call(self, input, mask):
        """Same as ``__call__`` but through the jit-compiled ``call_model_jit``."""
        rngs = {"dropout": self.key}
        return call_model_jit(self.model_struct, self.params, input, mask, rngs)

    @property
    def trainable_variables(self):
        """The current parameter pytree."""
        return self.params

    @property
    def key_bruh(self):
        """Split the stored key, keep one half and return the other."""
        self.key, subkey = jax.random.split(self.key)
        return subkey

    @debug_state.trace_func
    def training_setup(self, optimizer, lr, lrf, batches, epochs, state_path="", opt_state_dtype=None):
        """Create the optimizer, try to restore its state, build the grad function.

        Returns:
            The learning-rate schedule of the new optimizer.
        """
        self.optimizer = Optimizer(optimizer, lr, lrf, batches, epochs, self.params)
        self.optimizer.load(state_path, opt_state_dtype)
        self.grad_fn = jax.value_and_grad(loss_fn)
        return self.optimizer.lr_schedule

    @debug_state.trace_func
    def train_batch(self, x, mask, y):
        """Run one ``BatchTrain`` step and store the updated params/state.

        Returns:
            ``(loss, [seconds_taken, 0, 0])``.
        """
        key = self.key_bruh
        (loss, self.params, self.optimizer.state), first_time = time_it(
            BatchTrain, self.params, self.grad_fn, self.model_struct,
            x, mask, y, key, self.optimizer.optimizer, self.optimizer.state,
        )
        return loss, [first_time, 0, 0]

    @debug_state.trace_func
    def save_model(self, name, opt_state=True):
        """Write params, hyper-parameters and (optionally) optimizer state to ``name``."""
        os.makedirs(name, exist_ok=True)
        with open(os.path.join(name, "good_stuff.pkl"), "wb") as f:
            f.write(flax.serialization.to_bytes(self.trainable_variables))

        with open(os.path.join(name, "understanding_good_stuff.json"), "w") as f:
            json.dump(self.data, f, indent=2)

        if (opt_state) and (self.optimizer is not None):
            with open(os.path.join(name, "make_stuff_better.pkl"), "wb") as f:
                f.write(flax.serialization.to_bytes(self.optimizer.state))

    @debug_state.trace_func
    def batch_it(self, x, mask, y, batch_size, x_eq_y=True):
        """Wrap the data in a ``Flax_ds`` and batch it."""
        dataset = Flax_ds(x_eq_y)
        dataset.load_data(x, mask, y)
        dataset.batch_it_(batch_size=batch_size)
        return dataset

    @debug_state.trace_func
    def train(self, x, mask, y, epochs, batch_size, optimizer, lr, lrf, val_x=None, val_mask=None, val_y=None, val_step=100, updates_in=1, avg_mem=1500, state_path=None):
        pass  # Training loop is intentionally not part of the published package.

    @debug_state.trace_func
    def summary(self):
        """Print the parameter count of each top-level module and the total."""
        def count_params(params):
            # tree_leaves flattens any nesting, not just plain dicts.
            return sum(
                leaf.size
                for leaf in jax.tree_util.tree_leaves(params)
                if hasattr(leaf, 'size')
            )

        top_level = self.trainable_variables['params']
        for i in list(top_level.keys()):
            print(f"{i} :{count_params(top_level[i]):,}")
        print("-------------------")
        print(f"Total :{count_params(top_level):,}")

    @debug_state.trace_func
    def change_precision(self, dtype):
        """Cast every parameter to ``dtype`` in place."""
        self.params = convert_tree(dtype, self.params)

    @property
    def precision(self):
        """Print the parameter dtype, or a notice when dtypes are mixed."""
        # The parameters live in self.params; there is no self.model attribute.
        type_tree = jax.tree_util.tree_map(lambda x: x.dtype, self.params)
        types = jax.tree_util.tree_leaves(type_tree)
        if len(set(types)) == 1:
            print(f"Model dtype:{types[0]}")
        else:
            print("Model contrains mixed dtypes")
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
### Test stuff for now ok?
|
|
167
|
+
### Bruh you forgot to get the better predict from Google Colab ;-;
|
|
168
|
+
|
|
169
|
+
@debug_state.trace_func
def plot(losses, num_points=100, chop_off=100):
    """Plot the running mean of ``losses`` against the batch index.

    Args:
        losses: Sequence of per-batch loss values.
        num_points: Approximate number of points to draw.
        chop_off: Number of leading batches to leave out of the plot.

    Raises:
        ValueError: If ``chop_off`` is not smaller than ``len(losses)``.
    """
    if chop_off >= len(losses):
        raise ValueError("chop_off is greater than or equal to the length of losses")

    # Cumulative mean: entry i is the average of losses[0..i].
    smoothed_losses = np.cumsum(losses) / (np.arange(len(losses)) + 1)

    # Chop in batch units FIRST (the unit the check above validates), then
    # thin out. Chopping after sampling could discard every sampled point.
    kept_losses = smoothed_losses[chop_off:]
    kept_batches = np.arange(len(losses))[chop_off:]

    interval = max(len(kept_losses) // num_points, 1)
    sampled_losses = kept_losses[::interval]
    sampled_batches = kept_batches[::interval]

    # Plot
    plt.figure(figsize=(10, 5))
    plt.plot(sampled_batches, sampled_losses, marker='o', linestyle='-')
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.title('Smoothed Loss over Batches')
    plt.show()
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
setup.py
|
|
2
|
+
BwETAF/__init__.py
|
|
3
|
+
BwETAF/_errors.py
|
|
4
|
+
BwETAF/_utils.py
|
|
5
|
+
BwETAF/independent.py
|
|
6
|
+
BwETAF/layers.py
|
|
7
|
+
BwETAF/model.py
|
|
8
|
+
BwETAF.egg-info/PKG-INFO
|
|
9
|
+
BwETAF.egg-info/SOURCES.txt
|
|
10
|
+
BwETAF.egg-info/dependency_links.txt
|
|
11
|
+
BwETAF.egg-info/requires.txt
|
|
12
|
+
BwETAF.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
BwETAF
|
BwETAF-0.1/PKG-INFO
ADDED
BwETAF-0.1/setup.cfg
ADDED
BwETAF-0.1/setup.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from setuptools import setup, find_packages
|
|
2
|
+
|
|
3
|
+
setup(
    name="BwETAF",
    version="0.1",
    packages=find_packages(),
    # Every third-party package the BwETAF modules import must be listed:
    # model.py imports matplotlib and tqdm at module level, and
    # Tokenization imports tiktoken when instantiated.
    install_requires=[
        "flax",
        "jax",
        "huggingface_hub",
        "optax",
        "numpy",
        "matplotlib",
        "tqdm",
        "tiktoken",
    ],
    description="Module to load BwETAF models (Flax)",
    author="Boring._.wicked",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
    ],
)
|