BwETAF 0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,47 @@
1
+ import flax.serialization
2
+ import os
3
+ import json
4
+ from huggingface_hub import hf_hub_download, create_repo, upload_file, login
5
+ import os
6
+ from .independent import debug_state
7
+ from ._utils import convert_tree
8
+ from .model import *
9
+
10
+
11
@debug_state.trace_func
def load_model(path, dtype=None):
    """Rebuild a ``ModelManager`` from a directory on disk.

    Args:
        path: Directory containing ``understanding_good_stuff.json`` (the
            constructor hyper-parameters) and ``good_stuff.pkl`` (the
            serialized parameters).
        dtype: Optional dtype. It is forwarded to ``ModelManager`` and, when
            not ``None``, the restored parameters are cast to it.

    Returns:
        The ``ModelManager`` with ``params`` replaced by the restored values.
    """
    with open(os.path.join(path, "understanding_good_stuff.json"), "r") as f:
        data = json.load(f)

    model = ModelManager(
        data["num_heads"],
        data["attention_dim"],
        data["vocab_size"],
        data["num_blocks"],
        data["ff_dim"],
        data["dropout_rate"],
        dtype=dtype,
    )

    with open(os.path.join(path, "good_stuff.pkl"), "rb") as f:
        # The freshly initialised params act as the structural template.
        restored = flax.serialization.from_bytes(model.params, f.read())

    # Cast only on explicit request, mirroring the ``dtype is not None`` guard
    # used by ModelManager.__init__ and Optimizer.load.
    model.params = restored if dtype is None else convert_tree(dtype, restored)
    return model
20
+
21
@debug_state.trace_func
def load_hf(path, dtype=None, local_dir="Loaded_model"):
    """Download a model's files from the Hugging Face Hub and load them.

    Args:
        path: Repository id passed to ``hf_hub_download`` as ``repo_id``.
        dtype: Optional dtype forwarded to ``load_model``.
        local_dir: Directory the files are downloaded into and then loaded
            from. Defaults to the previously hard-coded ``"Loaded_model"``.

    Returns:
        Whatever ``load_model(local_dir, dtype)`` returns.
    """
    model_repo = path
    filenames = ["understanding_good_stuff.json", "good_stuff.pkl", "make_stuff_better.pkl"]
    for filename in filenames:
        try:
            print(hf_hub_download(repo_id=model_repo, filename=filename, local_dir=local_dir))
        except Exception as err:
            # A failed download is tolerated (load_model only reads the first
            # two files), but KeyboardInterrupt/SystemExit must not be
            # swallowed and the reason should stay visible in the debug log.
            print(f"No {filename} found")
            debug_state.logger(f"Download of {filename} failed: {err!r}", state="ERROR")
    return load_model(local_dir, dtype)
31
+
32
@debug_state.trace_func
def push_model(repo_name, folder_path):
    """Upload the saved model files found in ``folder_path`` to the Hub.

    Args:
        repo_name: Repository id; it is created when it does not exist yet.
        folder_path: Local directory searched for ``good_stuff.pkl``,
            ``understanding_good_stuff.json`` and ``make_stuff_better.pkl``.

    Files that are missing locally are skipped, and the final message lists
    only the files that were actually uploaded.
    """
    files_to_upload = ["good_stuff.pkl", "understanding_good_stuff.json", "make_stuff_better.pkl"]

    create_repo(repo_name, exist_ok=True)  # Create repo if it doesn't exist

    uploaded = []
    for file_name in files_to_upload:
        file_path = os.path.join(folder_path, file_name)
        if not os.path.isfile(file_path):  # Only upload if the file exists
            continue
        upload_file(
            path_or_fileobj=file_path,
            path_in_repo=file_name,  # Save with the same filename
            repo_id=repo_name,
            repo_type="model",
        )
        uploaded.append(file_name)
    print(f"Uploaded {uploaded} to {repo_name}")
@@ -0,0 +1,3 @@
1
class IncorrectDtype(Exception):
    """Package-specific exception that is constructed with one message."""

    def __init__(self, message):
        # Hand the message to the base class so ``str(err)`` returns it.
        Exception.__init__(self, message)
@@ -0,0 +1,75 @@
1
+ import jax
2
+ import optax
3
+ import jax.numpy as jnp
4
+ import numpy as np
5
+ from ._errors import *
6
+ import time
7
+ from functools import partial
8
+
9
+
10
def time_it(fn, *args):
    """Call ``fn(*args)`` and measure how long the call took.

    Args:
        fn: Callable to invoke.
        *args: Positional arguments forwarded to ``fn``.

    Returns:
        A ``(result, seconds)`` tuple, where ``seconds`` is the elapsed time
        of the call as a float.
    """
    # perf_counter is monotonic, so the difference can never be negative even
    # if the system clock is adjusted while ``fn`` runs.
    start = time.perf_counter()
    out = fn(*args)
    return out, time.perf_counter() - start
15
+
16
+
17
def loss_fn(params, model, batch, rng):
    """Compute the next-token negative log-likelihood, ignoring masked targets.

    Args:
        params: Parameters passed to ``model.apply``.
        model: Object exposing ``apply(params, inputs, mask, rngs=...)`` that
            returns logits indexed as ``[batch, position, vocab]``.
        batch: Sequence ``(inputs, mask, targets)``; ``mask`` and ``targets``
            are indexed as ``[batch, position]``.
        rng: PRNG key handed to the model as the ``"dropout"`` stream.

    Returns:
        Scalar loss: the mean of ``-log p(target)`` over the target positions
        whose mask entry is non-zero. With an all-ones mask this equals the
        plain mean over every position.
    """
    inputs, mask, targets = batch

    logits = model.apply(params, inputs, mask, rngs={"dropout": rng})  # Forward pass
    log_probs = jax.nn.log_softmax(logits)

    # Position t predicts token t + 1: drop the first target and the last
    # prediction so the two line up.
    shifted_targets = targets[:, 1:]
    shifted_log_probs = log_probs[:, :-1, :]

    # Log-probability assigned to the correct class at every position.
    token_log_probs = jnp.take_along_axis(
        shifted_log_probs, shifted_targets[..., None], axis=-1
    )[..., 0]

    # Accumulate in at least float32 so low-precision logits do not degrade
    # the sums.
    acc_dtype = jnp.promote_types(token_log_probs.dtype, jnp.float32)
    weights = mask[:, 1:].astype(acc_dtype)  # zero weight for masked targets
    total = (token_log_probs.astype(acc_dtype) * weights).sum()
    count = jnp.maximum(weights.sum(), 1)  # avoid 0/0 on a fully masked batch

    return -total / count
34
+
35
def val_loss(params, loss_fn, model_struct, x, mask, y, key):
    """Evaluate ``loss_fn`` on one ``[x, mask, y]`` batch."""
    batch = [x, mask, y]
    return loss_fn(params, model_struct, batch, key)


# Rebind the name to the compiled, device-parallel version: jit first (with
# loss_fn and model_struct static), then pmap over the leading axis of
# x / mask / y while params and key are broadcast.
val_loss = jax.pmap(
    jax.jit(val_loss, static_argnums=(1, 2)),
    in_axes=(None, None, None, 0, 0, 0, None),
    static_broadcasted_argnums=(1, 2),
    axis_name='batch',
)
50
def BatchTrain(params, grad_fn, model_struct, x, mask, y, key, optimizer, opt_state):
    """Placeholder for the training step; the published body returns ``None``."""
    # The author deliberately left the real update step out of this release.
    return None


# Same wrapping order as the original: jit with grad_fn, model_struct and
# optimizer static, then pmap over the leading axis of x / mask / y.
BatchTrain = jax.pmap(
    jax.jit(BatchTrain, static_argnums=(1, 2, 7)),
    static_broadcasted_argnums=(1, 2, 7),
    in_axes=(None, None, None, 0, 0, 0, None, None, None),
    axis_name="batch",
    out_axes=None,
)
66
+
67
def get_first(pytree):
    """Return a tree of the same structure with every leaf replaced by ``leaf[0]``."""
    def _head(leaf):
        return leaf[0]

    return jax.tree_util.tree_map(_head, pytree)
69
+
70
def convert_tree(dtype, pytree):
    """Cast the floating-point leaves of ``pytree`` to ``dtype``.

    Args:
        dtype: Target dtype, or ``None`` to leave the tree untouched.
        pytree: Arbitrary pytree of arrays.

    Returns:
        ``pytree`` itself when ``dtype`` is ``None``; otherwise a new tree in
        which every floating-point leaf has been cast. Integer, boolean and
        non-array leaves are returned unchanged, so counters stored next to
        the weights are not turned into floats.
    """
    if dtype is None:
        return pytree

    def _cast(leaf):
        leaf_dtype = getattr(leaf, "dtype", None)
        if leaf_dtype is not None and jnp.issubdtype(leaf_dtype, jnp.floating):
            return leaf.astype(dtype)
        return leaf

    return jax.tree_util.tree_map(_cast, pytree)
72
+
73
def call_model_jit(model_struct, params, x, mask, rngs):
    """Run ``model_struct.apply`` with ``training=False``."""
    return model_struct.apply(params, x, mask, rngs=rngs, training=False)


# model_struct (argument 0) is static, so each distinct model is traced and
# compiled separately.
call_model_jit = jax.jit(call_model_jit, static_argnums=(0,))
@@ -0,0 +1,204 @@
1
+ import numpy as np
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ import traceback
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from ._errors import *
7
+ import optax
8
+ from ._utils import *
9
+ import os
10
+ import flax
11
+ import time
12
+
13
+
14
class Debugger():
    """Optional call tracer that logs to stdout or to an append-mode file.

    Attributes set in ``__init__``:
        debug: When false, ``logger`` is a no-op and traced functions are
            called directly.
        path: Log file path, or ``None`` to print to stdout.
        logfile: Open handle for ``path``, or ``None`` when not opened yet.
    """

    def __init__(self, debug=False, path=None) -> None:
        self.debug = debug
        self.path = path
        self.logfile = None
        if path is not None:
            self.logfile = open(path, "a")

    def _ensure_logfile(self):
        """Return an open handle for ``self.path``, opening it if needed."""
        if self.logfile is None or self.logfile.closed:
            self.logfile = open(self.path, "a")
        return self.logfile

    def turn_debugger_on(self):
        """Enable logging and open the log file when a path is configured."""
        self.debug = True
        # With no path, output goes to stdout; open(None) used to raise here.
        if self.path is not None:
            self._ensure_logfile()

    def logger(self, message, state="DEBUG"):
        """Emit ``[state] message`` when debugging is enabled."""
        if not self.debug:
            return
        if self.path is None:
            print(f"[{state}] {message}")
        else:
            # Also covers a path assigned after construction.
            logfile = self._ensure_logfile()
            logfile.write(f"[{state}] {message}\n")
            logfile.flush()

    def trace_func(self, func):
        """Decorator that logs the call, its result and any exception of ``func``.

        Exceptions are logged together with their traceback and re-raised
        unchanged.
        """
        from functools import wraps

        @wraps(func)  # keep __name__/__doc__ of the traced function
        def wrapper(*args, **kwargs):
            if not self.debug:
                # Skip building log strings on every call when tracing is off.
                return func(*args, **kwargs)

            arg_types = [type(i) for i in args]
            # Log the types of the keyword *values*; iterating the dict itself
            # only yields the keys, which are always str.
            kwarg_types = {key: type(value) for key, value in kwargs.items()}
            self.logger(f"Calling {func.__name__} with args: {arg_types}, kwargs: {kwarg_types}", state="FUNC_CALL")
            start_time = time.time()
            try:
                out = func(*args, **kwargs)
            except Exception as e:
                self.logger(f"{func.__name__} Exception:{e}", state="ERROR")
                self.logger(traceback.format_exc(), state="TRACEBACK")
                raise
            name = f"{args[0].__class__.__name__}.__init__" if func.__name__ == "__init__" else func.__name__
            self.logger(f"{name} returned: {type(out)} with shape/info: {getattr(out, 'shape', 'uk')}, {getattr(out, 'dtype', 'uk')}, Time taken:{time.time()-start_time}s", state="FUNC_RETURN")
            return out
        return wrapper


# Shared instance whose ``trace_func`` decorates the package's functions.
debug_state = Debugger()
49
+
50
+
51
class Tokenization():
    """Thin wrapper around a ``tiktoken`` encoding."""

    @debug_state.trace_func
    def __init__(self, vocab="gpt2") -> None:
        """Load the ``tiktoken`` encoding called ``vocab``."""
        import tiktoken  # imported lazily so the package loads without it
        self.stuff = tiktoken
        self.vocab = vocab
        self.enc = tiktoken.get_encoding(self.vocab)

    @debug_state.trace_func
    def tokenize(self, batch: list, workers: int, max_length: int):
        """Encode texts to fixed-length id rows plus matching masks.

        Args:
            batch: List of strings.
            workers: Thread count used to encode the texts.
            max_length: Row length; longer texts are truncated and shorter
                ones are padded with id 0.

        Returns:
            ``(ids, mask)`` int32 NumPy arrays of shape
            ``(len(batch), max_length)``; ``mask`` is 1 on real tokens and 0
            on padding.
        """
        self.enc = self.stuff.get_encoding(self.vocab)
        enc = self.enc
        pad_token = 0

        def encode_and_pad(text):
            tokens = enc.encode(text, allowed_special={'<|endoftext|>'})[:max_length]
            padded = np.full(max_length, pad_token, dtype=np.int32)
            mask = np.zeros(max_length, dtype=np.int32)
            padded[:len(tokens)] = tokens
            mask[:len(tokens)] = 1  # Mark real tokens as 1
            return padded, mask

        with ThreadPoolExecutor(max_workers=workers) as executor:
            results = list(executor.map(encode_and_pad, batch))

        if not results:
            # ``zip(*[])`` cannot be unpacked into two names; return empty
            # arrays of the documented shape instead of crashing.
            return (
                np.zeros((0, max_length), dtype=np.int32),
                np.zeros((0, max_length), dtype=np.int32),
            )

        encoded_batch, mask = zip(*results)  # Split tokens and masks
        return np.array(encoded_batch), np.array(mask)

    @debug_state.trace_func
    def tokenize_(self, batch: list):
        """Encode texts without padding or truncation.

        Args:
            batch: List of strings that all encode to the same token count.

        Returns:
            ``(ids, mask)`` JAX arrays; ``mask`` is all ones.

        Raises:
            ValueError: If the texts encode to different lengths and therefore
                cannot be stacked into one array.
        """
        self.enc = self.stuff.get_encoding(self.vocab)
        enc = self.enc

        encoded_batch = []
        mask = []
        for text in batch:
            tokens = enc.encode(text, allowed_special={'<|endoftext|>'})
            encoded_batch.append(np.array(tokens, dtype=np.int32))
            mask.append(np.ones(len(tokens), dtype=np.int32))  # Mask matches token length

        if len({len(tokens) for tokens in encoded_batch}) > 1:
            raise ValueError(
                "tokenize_ cannot stack texts of different token lengths; "
                "use tokenize() to pad them to a common length"
            )

        return jnp.array(encoded_batch), jnp.array(mask)

    @debug_state.trace_func
    def decode(self, tokens):
        """Decode a sequence of token ids back to text."""
        return self.enc.decode(tokens)
102
+
103
class Flax_ds():
    """In-memory dataset that splits arrays into per-device batches.

    When ``x_eq_y`` is true the inputs double as targets and ``y`` is ignored.
    """

    @debug_state.trace_func
    def __init__(self, x_eq_y: bool) -> None:
        self.x_eq_y = x_eq_y
        self.x = None
        self.mask = None
        self.y = None
        self.batch = None

    @debug_state.trace_func
    def load_data(self, x, mask, y):
        """Store the raw arrays; ``y`` is kept only when ``x_eq_y`` is false."""
        self.x = np.array(x)
        self.mask = mask
        if not self.x_eq_y:
            self.y = np.array(y)

    @debug_state.trace_func
    def batch_it_(self, batch_size):
        """Cut the data into batches shaped ``(num_devices, -1, seq_len)``.

        Trailing samples that do not fill a whole batch are dropped. The raw
        arrays are deleted from the instance afterwards.

        Args:
            batch_size: Samples per batch across all devices.

        Returns:
            The list of batches: ``[x, mask]`` entries when ``x_eq_y`` is
            true, otherwise ``[x, mask, y]`` entries.

        Raises:
            ValueError: If ``batch_size`` is not a multiple of the device
                count, because the batch could not be split evenly.
        """
        num_devices = jax.device_count()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch_size ({batch_size}) must be divisible by the number of devices ({num_devices})"
            )

        arrays = [jnp.array(self.x), jnp.array(self.mask)]
        if not self.x_eq_y:
            arrays.append(jnp.array(self.y))

        seq_len = arrays[0].shape[1]
        n_batches = len(arrays[0]) // batch_size

        self.batch = [
            [
                arr[i * batch_size:(i + 1) * batch_size].reshape(num_devices, -1, seq_len)
                for arr in arrays
            ]
            for i in range(n_batches)
        ]

        # Drop the unbatched copies to release memory.
        del self.x, self.mask
        if not self.x_eq_y:
            del self.y
        return self.batch

    def __len__(self):
        return len(self.batch)

    def stream_it(self):
        """Yield ``(x, mask, y)`` per batch; ``y`` is ``x`` when ``x_eq_y``.

        Raises:
            IncorrectDtype: On first iteration if ``batch_it_`` has not run.
        """
        if self.batch is None:
            # The exception used to be constructed but never raised.
            raise IncorrectDtype("Bruh... You forgot to run '.batch_it' before trying to stream it.... T~T")
        target_index = 0 if self.x_eq_y else 2
        for item in self.batch:
            yield item[0], item[1], item[target_index]

    @property
    def gimme_the_data(self):
        """The list of batches built by ``batch_it_`` (``None`` before that)."""
        return self.batch
174
+
175
class Optimizer():
    """Pairs an optax optimizer with an exponential learning-rate schedule."""

    @debug_state.trace_func
    def __init__(self, optimizer, lr, lrf, batches, epochs, params):
        """Build the schedule, the optimizer and its initial state.

        Args:
            optimizer: Factory called with the schedule, e.g. ``optax.adam``.
            lr: Initial learning rate.
            lrf: Learning rate reached after ``batches * epochs`` steps.
            batches: Steps per epoch.
            epochs: Number of epochs.
            params: Parameters used to initialise the optimizer state.
        """
        # Per-step factor chosen so that lr * decay_rate ** (batches * epochs) == lrf.
        decay_rate = (lrf / lr) ** (1 / (batches * epochs))
        self.lr_schedule = optax.exponential_decay(
            init_value=lr,
            transition_steps=1,
            decay_rate=decay_rate,
            staircase=False  # Smooth decay
        )
        self.optimizer = optimizer(self.lr_schedule)
        self.state = self.optimizer.init(params)

    @debug_state.trace_func
    def load(self, path, dtype=None):
        """Restore the optimizer state from ``path`` if a saved one exists.

        Loading is best effort: when nothing can be restored a message is
        printed and the freshly initialised state is kept.

        Args:
            path: Directory expected to contain ``make_stuff_better.pkl``;
                ``None`` means there is nothing to load.
            dtype: Optional dtype passed to ``convert_tree`` for the restored
                state.
        """
        if path is None:
            print("No optimizers states found")
            return

        try:
            with open(os.path.join(path, "make_stuff_better.pkl"), "rb") as f:
                restored = flax.serialization.from_bytes(self.state, f.read())
            if dtype is not None:
                restored = convert_tree(dtype, restored)
        except FileNotFoundError:
            print("No optimizers states found")
            return
        except Exception as err:
            # Still best effort, but report the real cause instead of claiming
            # the file is missing, and let KeyboardInterrupt propagate.
            print(f"Could not load optimizer states: {err!r}")
            return

        # Assign only after everything succeeded so a failure cannot leave a
        # partially processed state behind.
        self.state = restored
        print("Using loaded optimizer states")

    @debug_state.trace_func
    def save(self, path):
        """Write the optimizer state to ``path/make_stuff_better.pkl``."""
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "make_stuff_better.pkl"), "wb") as f:
            f.write(flax.serialization.to_bytes(self.state))
204
+
@@ -0,0 +1,81 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import flax.linen as nn
4
+
5
class PosEnc(nn.Module):
    """Adds a sinusoidal position table to a ``(batch, seq, dim)`` input."""
    dim: int
    dtype: jnp.dtype = jnp.bfloat16

    @nn.compact
    def __call__(self, x):
        # Unpacking also enforces that x has exactly three axes.
        _, seq_len, _ = x.shape

        # One frequency per sin/cos column pair.
        freq_scale = jnp.log(10000.0) / self.dim
        div_term = jnp.exp(-jnp.arange(0, self.dim, 2) * freq_scale).astype(self.dtype)

        # (seq_len, dim / 2) angles.
        # NOTE(review): div_term is already in self.dtype here, so the angles
        # are computed at that precision - TODO confirm this is intended for
        # long sequences.
        angles = jnp.arange(seq_len)[:, None] * div_term

        # Even columns hold the sines, odd columns the cosines.
        table = jnp.zeros((seq_len, self.dim), dtype=self.dtype)
        table = table.at[:, 0::2].set(jnp.sin(angles)).astype(self.dtype)
        table = table.at[:, 1::2].set(jnp.cos(angles)).astype(self.dtype)

        # Broadcast the table over the batch axis.
        return x + table[None, :, :]
26
+
27
class Attention(nn.Module):
    """Multi-head self-attention using one fused QKV projection."""
    num_heads: int
    d_model: int
    dtype: jnp.dtype = jnp.bfloat16

    def setup(self):
        assert self.d_model % self.num_heads == 0, "d_model must be divisible by num_heads"
        self.depth = self.d_model // self.num_heads
        self.qkv_dense = nn.Dense(
            features=3 * self.d_model,
            kernel_init=nn.initializers.normal(stddev=0.02),
            dtype=self.dtype,
        )
        self.out_dense = nn.Dense(
            features=self.d_model,
            kernel_init=nn.initializers.normal(stddev=0.02),
            dtype=self.dtype,
        )

    def __call__(self, x, mask):
        batch_size, seq_len, _ = x.shape

        # Project once, then split into (3, batch, heads, seq, depth).
        fused = self.qkv_dense(x).reshape(batch_size, seq_len, 3, self.num_heads, self.depth)
        Q, K, V = jnp.transpose(fused, (2, 0, 3, 1, 4))

        # Scaled dot-product scores, shape (batch, heads, query, key).
        scores = jnp.einsum("bhqd,bhkd->bhqk", Q, K) / jnp.sqrt(self.depth)

        if mask is not None:
            # Insert a head axis so the mask broadcasts across heads; entries
            # that are zero/false get a large negative score.
            scores = jnp.where(mask[:, None, :], scores, -1e9)

        weights = jax.nn.softmax(scores, axis=-1)
        context = jnp.einsum("bhqk,bhkd->bhqd", weights, V)

        # Merge the heads back into (batch, seq, d_model).
        merged = jnp.transpose(context, (0, 2, 1, 3)).reshape(batch_size, seq_len, self.d_model)
        return self.out_dense(merged)
60
+
61
+
62
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention, then a GELU feed-forward."""
    num_heads: int
    attention_dim: int
    ff_dim: int
    dropout_rate: float
    dtype: jnp.dtype = jnp.bfloat16

    @nn.compact
    def __call__(self, x_inp, mask, train: bool):
        # Submodules are created in the same order as before, so their
        # auto-generated names (and saved parameter keys) do not change.

        # Attention sub-layer with residual connection.
        attended = nn.LayerNorm(dtype=self.dtype)(x_inp)
        attended = Attention(self.num_heads, self.attention_dim, dtype=self.dtype)(attended, mask)
        attended = nn.Dropout(self.dropout_rate)(attended, deterministic=not train)
        residual = attended + x_inp

        # Feed-forward sub-layer with residual connection.
        hidden = nn.LayerNorm(dtype=self.dtype)(residual)
        hidden = nn.Dense(
            self.ff_dim,
            kernel_init=nn.initializers.normal(stddev=0.02),
            dtype=self.dtype,
        )(hidden)
        hidden = nn.gelu(hidden)
        hidden = nn.Dense(
            self.attention_dim,
            kernel_init=nn.initializers.normal(stddev=0.02),
            dtype=self.dtype,
        )(hidden)
        return hidden + residual
@@ -0,0 +1,186 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import flax.linen as nn
4
+ from .layers import *
5
+ import json
6
+ import os
7
+ import flax.serialization
8
+ import numpy as np
9
+ from ._errors import *
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+ from ._utils import *
14
+ from .independent import *
15
+
16
+
17
rng = jax.random.PRNGKey(0)


class Model(nn.Module):
    """Decoder-only transformer whose output projection reuses the embedding."""
    num_heads: int
    attention_dim: int
    vocab_size: int
    num_blocks: int
    ff_dim: int
    dropout_rate: float
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.emb = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.attention_dim,
            embedding_init=nn.initializers.normal(stddev=0.02),
            dtype=self.dtype,
        )
        self.pos_enc = PosEnc(self.attention_dim)
        self.blocks = [
            Block(
                num_heads=self.num_heads,
                attention_dim=self.attention_dim,
                ff_dim=self.ff_dim,
                dropout_rate=self.dropout_rate,
                dtype=self.dtype,
            )
            for _ in range(self.num_blocks)
        ]

    def __call__(self, x, mask, training=True):
        """Return logits over the vocabulary for every position of ``x``."""
        attn_mask = self.process_mask(mask)
        hidden = self.emb(x.astype(jnp.int32))
        hidden = self.pos_enc(hidden)
        for block in self.blocks:
            hidden = block(hidden, attn_mask, training)
        # Tied weights: project with the transposed embedding matrix.
        return hidden @ self.emb.embedding.T

    def process_mask(self, mask):
        """Combine a ``(batch, seq)`` padding mask with a causal mask.

        Returns a ``(batch, seq, seq)`` array that is non-zero only where the
        key position is not after the query position and both positions are
        unmasked.
        """
        _, seq_len = mask.shape

        # Lower-triangular causal mask.
        causal = jnp.tril(jnp.ones((seq_len, seq_len)))

        padding = mask[:, None, :]  # (batch, 1, seq_len)
        # Zero the masked columns, transpose, zero the masked rows the same
        # way, then transpose back.
        combined = causal[None, :, :] * padding
        combined = jnp.transpose(combined, (0, 2, 1)) * padding
        return jnp.transpose(combined, (0, 2, 1))
57
+
58
+
59
class ModelManager():
    """Bundles a ``Model`` with its parameters, PRNG key and optimizer."""

    @debug_state.trace_func
    def __init__(self, num_heads, attention_dim, vocab_size, num_blocks, ff_dim, dropout_rate, dtype=None) -> None:
        """Create the model structure and initialise its parameters.

        Args:
            num_heads, attention_dim, vocab_size, num_blocks, ff_dim,
            dropout_rate: Forwarded to ``Model`` and recorded in ``self.data``.
            dtype: Forwarded to ``Model``; when not ``None`` the initial
                parameters are also cast to it.
        """
        self.key = jax.random.PRNGKey(0)
        self.model_struct = Model(num_heads, attention_dim, vocab_size, num_blocks, ff_dim, dropout_rate, dtype)
        # Dummy (2, 11) inputs for initialisation. training=False keeps
        # dropout inactive, so no "dropout" RNG stream has to be supplied.
        self.params = self.model_struct.init(
            self.key,
            jax.random.normal(self.key, (2, 11)),
            jnp.ones((2, 11)),
            training=False,
        )
        if dtype is not None:
            self.params = convert_tree(dtype, self.params)

        self.optimizer = None

        # Hyper-parameters written by save_model.
        self.data = {
            "num_heads": num_heads,
            "attention_dim": attention_dim,
            "vocab_size": vocab_size,
            "num_blocks": num_blocks,
            "ff_dim": ff_dim,
            "dropout_rate": dropout_rate
        }

    def __call__(self, input, mask):
        """Run the model with ``training=False`` (not jitted)."""
        return self.model_struct.apply(self.params, input, mask, rngs={"dropout": self.key_bruh}, training=False)

    def jax_call(self, input, mask):
        """Run the model with ``training=False`` through ``call_model_jit``."""
        rngs = {"dropout": self.key}
        return call_model_jit(self.model_struct, self.params, input, mask, rngs)

    @property
    def trainable_variables(self):
        """The current parameter tree."""
        return self.params

    @property
    def key_bruh(self):
        """Split ``self.key`` and return a fresh subkey (advances ``self.key``)."""
        self.key, subkey = jax.random.split(self.key)
        return subkey

    @debug_state.trace_func
    def training_setup(self, optimizer, lr, lrf, batches, epochs, state_path="", opt_state_dtype=None):
        """Create the optimizer, try to restore its state, build the grad function.

        Returns:
            The learning-rate schedule of the new optimizer.
        """
        self.optimizer = Optimizer(optimizer, lr, lrf, batches, epochs, self.params)
        self.optimizer.load(state_path, opt_state_dtype)
        self.grad_fn = jax.value_and_grad(loss_fn)
        return self.optimizer.lr_schedule

    @debug_state.trace_func
    def train_batch(self, x, mask, y):
        """Run one ``BatchTrain`` step and store the updated params/state.

        Expects ``BatchTrain`` to return ``(loss, params, opt_state)``.

        Returns:
            ``(loss, [seconds, 0, 0])``.
        """
        key = self.key_bruh
        (loss, self.params, self.optimizer.state), first_time = time_it(
            BatchTrain,
            self.params,
            self.grad_fn,
            self.model_struct,
            x,
            mask,
            y,
            key,
            self.optimizer.optimizer,
            self.optimizer.state,
        )
        return loss, [first_time, 0, 0]

    @debug_state.trace_func
    def save_model(self, name, opt_state=True):
        """Write params, hyper-parameters and optionally optimizer state to ``name``."""
        os.makedirs(name, exist_ok=True)
        with open(os.path.join(name, "good_stuff.pkl"), "wb") as f:
            f.write(flax.serialization.to_bytes(self.trainable_variables))

        with open(os.path.join(name, "understanding_good_stuff.json"), "w") as f:
            json.dump(self.data, f, indent=2)

        if (opt_state) and (self.optimizer is not None):
            with open(os.path.join(name, "make_stuff_better.pkl"), "wb") as f:
                f.write(flax.serialization.to_bytes(self.optimizer.state))

    @debug_state.trace_func
    def batch_it(self, x, mask, y, batch_size, x_eq_y=True):
        """Build and return a batched ``Flax_ds`` from the given arrays."""
        dataset = Flax_ds(x_eq_y)
        dataset.load_data(x, mask, y)
        dataset.batch_it_(batch_size=batch_size)
        return dataset

    @debug_state.trace_func
    def train(self, x, mask, y, epochs, batch_size, optimizer, lr, lrf, val_x=None, val_mask=None, val_y=None, val_step=100, updates_in=1, avg_mem=1500, state_path=None):
        pass  # I don't want to expose the training stuff... So... I hope you understand :D

    @debug_state.trace_func
    def summary(self):
        """Print the parameter count of each top-level module and the total."""
        def count_params(tree):
            # tree_leaves handles any pytree container, not only ``dict``
            # (an isinstance(dict) check counts 0 for other mapping types).
            return sum(getattr(leaf, 'size', 0) for leaf in jax.tree_util.tree_leaves(tree))

        top_level = self.trainable_variables['params']
        for name in list(top_level.keys()):
            print(f"{name} :{count_params(top_level[name]):,}")
        print("-------------------")
        print(f"Total :{count_params(top_level):,}")

    @debug_state.trace_func
    def change_precision(self, dtype):
        """Cast every parameter leaf to ``dtype`` in place."""
        self.params = jax.tree_util.tree_map(lambda x: x.astype(dtype), self.params)

    @property
    def precision(self):
        """Print the dtype shared by all parameters, or note mixed dtypes."""
        # The parameters live in self.params; this class has no self.model.
        types = [leaf.dtype for leaf in jax.tree_util.tree_leaves(self.params)]
        if len(set(types)) == 1:
            print(f"Model dtype:{types[0]}")
        else:
            print("Model contrains mixed dtypes")
163
+
164
+
165
+
166
+ ### Test stuff for now ok?
167
+ ### TODO: bring over the improved predict function from Google Colab
168
+
169
@debug_state.trace_func
def plot(losses, num_points=100, chop_off=100):
    """Plot the running mean of ``losses`` against the batch index.

    Args:
        losses: Sequence of per-batch loss values.
        num_points: Approximate number of points to draw.
        chop_off: Number of initial batches left out of the plot.

    Raises:
        ValueError: If ``chop_off`` is not smaller than ``len(losses)``.
    """
    if chop_off >= len(losses):
        raise ValueError("chop_off is greater than or equal to the length of losses")

    # Running mean over all batches seen so far.
    smoothed_losses = np.cumsum(losses) / (np.arange(len(losses)) + 1)
    batches = np.arange(len(losses))

    # Remove the first chop_off *batches* before downsampling. Slicing after
    # downsampling counted in sampled points, so the default arguments
    # discarded nearly the whole curve.
    smoothed_losses = smoothed_losses[chop_off:]
    batches = batches[chop_off:]

    interval = max(len(smoothed_losses) // max(num_points, 1), 1)
    sampled_losses = smoothed_losses[::interval]
    sampled_batches = batches[::interval]

    # Plot
    plt.figure(figsize=(10, 5))
    plt.plot(sampled_batches, sampled_losses, marker='o', linestyle='-')
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.title('Smoothed Loss over Batches')
    plt.show()
@@ -0,0 +1,7 @@
1
+ Metadata-Version: 2.1
2
+ Name: BwETAF
3
+ Version: 0.1
4
+ Summary: Module to load BwETAF models (Flax)
5
+ Author: Boring._.wicked
6
+ Classifier: Programming Language :: Python :: 3
7
+ Classifier: License :: OSI Approved :: MIT License
@@ -0,0 +1,12 @@
1
+ setup.py
2
+ BwETAF/__init__.py
3
+ BwETAF/_errors.py
4
+ BwETAF/_utils.py
5
+ BwETAF/independent.py
6
+ BwETAF/layers.py
7
+ BwETAF/model.py
8
+ BwETAF.egg-info/PKG-INFO
9
+ BwETAF.egg-info/SOURCES.txt
10
+ BwETAF.egg-info/dependency_links.txt
11
+ BwETAF.egg-info/requires.txt
12
+ BwETAF.egg-info/top_level.txt
@@ -0,0 +1,5 @@
1
+ flax
2
+ jax
3
+ huggingface_hub
4
+ optax
5
+ numpy
@@ -0,0 +1 @@
1
+ BwETAF
BwETAF-0.1/PKG-INFO ADDED
@@ -0,0 +1,7 @@
1
+ Metadata-Version: 2.1
2
+ Name: BwETAF
3
+ Version: 0.1
4
+ Summary: Module to load BwETAF models (Flax)
5
+ Author: Boring._.wicked
6
+ Classifier: Programming Language :: Python :: 3
7
+ Classifier: License :: OSI Approved :: MIT License
BwETAF-0.1/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
BwETAF-0.1/setup.py ADDED
@@ -0,0 +1,20 @@
1
+ from setuptools import setup, find_packages
2
+
3
setup(
    name="BwETAF",
    version="0.1",
    packages=find_packages(),
    install_requires=[
        "flax",
        "jax",
        "huggingface_hub",
        "optax",
        "numpy",
        # Imported at module level by BwETAF/model.py, so importing the
        # package fails without them.
        "matplotlib",
        "tqdm",
        # Imported when a Tokenization object is created.
        "tiktoken",
    ],
    description="Module to load BwETAF models (Flax)",
    author="Boring._.wicked",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
    ],
)