froog 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
froog/ops.py CHANGED
@@ -7,9 +7,9 @@
  # |___| |___| |_||_______||_______||_______|

  import numpy as np
- from froog.tensor import Function, register
+ from typing import Tuple, Union, Optional, Any, Callable
+ from froog.tensor import Function, register, Tensor
  from froog.utils import im2col, col2im
- from froog.tensor import Tensor

  # *****************************************************
  # ____ ___ _____ __________ ____ ____ _____
@@ -23,32 +23,74 @@ from froog.tensor import Tensor

  class Add(Function): # x.add(y)
    @staticmethod # @staticmethod doesn't require an instance of Add to work, so you can do x.add(y)
-   def forward(ctx, x, y):
+   def forward(ctx: Any, x: np.ndarray, y: np.ndarray) -> np.ndarray:
+     # Check if we have GPU buffers
+     is_metal_buffer = lambda x: hasattr(x, '__pyobjc_object__') or str(type(x)).find('Buffer') >= 0
+     if is_metal_buffer(x) or is_metal_buffer(y):
+       # Import the get_buffer_data helper for Metal buffers
+       try:
+         from froog.gpu.buffer_utils import get_buffer_data
+         x_data = get_buffer_data(x)
+         y_data = get_buffer_data(y)
+         ctx.save_for_backward(x_data, y_data)
+         return x_data + y_data
+       except ImportError:
+         print("Warning: buffer_utils not available")
+         # Fall back to the regular implementation
+         ctx.save_for_backward(x, y)
+         return x + y
+
+     # Regular implementation
+     ctx.save_for_backward(x, y)
      return x + y

    @staticmethod
-   def backward(ctx, grad_output):
+   def backward(ctx: Any, grad_output: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
      return grad_output, grad_output
  register("add", Add)

  class Sub(Function): # x.sub(y)
    @staticmethod
-   def forward(ctx, x, y):
+   def forward(ctx: Any, x: np.ndarray, y: np.ndarray) -> np.ndarray:
      return x-y

    @staticmethod
-   def backward(ctx, grad_output):
+   def backward(ctx: Any, grad_output: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
      return grad_output, -grad_output
  register('sub', Sub)

  class Mul(Function): # x.mul(y)
    @staticmethod
-   def forward(ctx, x, y):
+   def forward(ctx: Any, x: np.ndarray, y: np.ndarray) -> np.ndarray:
+     # Check if we have GPU buffers
+     is_metal_buffer = lambda x: hasattr(x, '__pyobjc_object__') or str(type(x)).find('Buffer') >= 0
+     if is_metal_buffer(x) or is_metal_buffer(y):
+       # Import the buffer helpers for Metal buffers
+       try:
+         from froog.gpu.buffer_utils import get_buffer_data, buffer_mul
+         x_data = get_buffer_data(x)
+         y_data = get_buffer_data(y)
+         ctx.save_for_backward(x_data, y_data)
+         return buffer_mul(x, y)
+       except Exception as e:
+         print(f"Error in Mul.forward with buffer: {e}")
+         # Fall back to the CPU implementation if buffer handling fails
+         from froog.gpu import get_device
+         device = get_device()
+         if device:
+           x_cpu = device.download_tensor(x) if is_metal_buffer(x) else x
+           y_cpu = device.download_tensor(y) if is_metal_buffer(y) else y
+           ctx.save_for_backward(x_cpu, y_cpu)
+           result = x_cpu * y_cpu
+           return device.upload_tensor(result)
+         raise
+
+     # Standard CPU implementation
      ctx.save_for_backward(x, y)
      return x * y

    @staticmethod
-   def backward(ctx, grad_output):
+   def backward(ctx: Any, grad_output: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
      x, y = ctx.saved_tensors
      return y * grad_output, x * grad_output
  register("mul", Mul)
@@ -59,24 +101,49 @@ class Sum(Function): # x.sum()
    reduces its input tensor to a single value by summing all the elements
    """
    @staticmethod
-   def forward(ctx, input):
+   def forward(ctx: Any, input: np.ndarray) -> np.ndarray:
+     # Check if we have a GPU buffer
+     is_metal_buffer = lambda x: hasattr(x, '__pyobjc_object__') or str(type(x)).find('Buffer') >= 0
+     if is_metal_buffer(input):
+       # Use the buffer utilities
+       try:
+         from froog.gpu.buffer_utils import get_buffer_data, buffer_sum
+         input_data = get_buffer_data(input)
+         ctx.save_for_backward(input_data)
+         ctx.input_shape = input_data.shape
+         return buffer_sum(input)
+       except Exception as e:
+         print(f"Error in Sum.forward with buffer: {e}")
+         # Fall back to the CPU implementation
+         from froog.gpu import get_device
+         device = get_device()
+         if device:
+           input_cpu = device.download_tensor(input)
+           ctx.save_for_backward(input_cpu)
+           ctx.input_shape = input_cpu.shape
+           result = np.array([np.sum(input_cpu)])
+           return device.upload_tensor(result)
+         raise
+
+     # Standard CPU implementation
      ctx.save_for_backward(input)
-     return np.array([input.sum()])
+     ctx.input_shape = input.shape
+     return np.array([np.sum(input)])

    @staticmethod
-   def backward(ctx, grad_output):
+   def backward(ctx: Any, grad_output: np.ndarray) -> np.ndarray:
      (input,) = ctx.saved_tensors
      return grad_output * np.ones_like(input)
  register("sum", Sum)

  class Pow(Function): # x.pow(y)
    @staticmethod
-   def forward(ctx, x, y):
+   def forward(ctx: Any, x: np.ndarray, y: np.ndarray) -> np.ndarray:
      ctx.save_for_backward(x, y)
      return x ** y

    @staticmethod
-   def backward(ctx, grad_output):
+   def backward(ctx: Any, grad_output: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
      x, y = ctx.saved_tensors
      return y * (x**(y-1.0)) * grad_output, (x**y) * np.log(x) * grad_output # power rule: d/dx (x^y) = y*x^(y-1), and d/dy (x^y) = x^y * ln(x)
  register("pow", Pow)
@@ -93,16 +160,65 @@ register("pow", Pow)

  class Dot(Function): # x.dot(y)
    @staticmethod
-   def forward(ctx, input, weight):
+   def forward(ctx: Any, input: np.ndarray, weight: np.ndarray) -> np.ndarray:
      ctx.save_for_backward(input, weight)
-     return input.dot(weight)
+
+     # Check if we're working with GPU buffers
+     try:
+       from froog.tensor import is_buffer
+       from froog.gpu import download_tensor
+
+       # Convert any GPU buffers to CPU for the operation
+       if is_buffer(input):
+         input_cpu = download_tensor(input)
+       else:
+         input_cpu = input
+
+       if is_buffer(weight):
+         weight_cpu = download_tensor(weight)
+       else:
+         weight_cpu = weight
+
+       return input_cpu.dot(weight_cpu)
+     except Exception as e:
+       import traceback
+       print(f"Error in dot operation: {str(e)}")
+       print(f"  Self: {input}")
+       print(f"  Arg 0: {weight}")
+       print(f"  Kwargs: {{}}")
+       traceback.print_exc()
+       # Try the original method as a fallback
+       return input.dot(weight)

    @staticmethod
-   def backward(ctx, grad_output):
+   def backward(ctx: Any, grad_output: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
      input, weight = ctx.saved_tensors
-     grad_input = grad_output.dot(weight.T)
-     grad_weight = input.T.dot(grad_output)
-     return grad_input, grad_weight
+
+     # Convert GPU buffers to CPU if needed
+     try:
+       from froog.tensor import is_buffer
+       from froog.gpu import download_tensor
+
+       if is_buffer(input):
+         input_cpu = download_tensor(input)
+       else:
+         input_cpu = input
+
+       if is_buffer(weight):
+         weight_cpu = download_tensor(weight)
+       else:
+         weight_cpu = weight
+
+       if is_buffer(grad_output):
+         grad_output_cpu = download_tensor(grad_output)
+       else:
+         grad_output_cpu = grad_output
+
+       return grad_output_cpu.dot(weight_cpu.T), input_cpu.T.dot(grad_output_cpu)
+     except Exception as e:
+       print(f"Error in dot backward: {str(e)}")
+       # Fallback
+       return grad_output.dot(weight.T), input.T.dot(grad_output)
  register('dot', Dot)
  register('matmul', Dot)
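Note: whichever device the buffers live on, Dot.backward must return gradients shaped like its inputs; a quick standalone NumPy shape check of the two products it computes:

    import numpy as np

    input, weight = np.random.randn(4, 3), np.random.randn(3, 5)
    grad_output = np.random.randn(4, 5)     # same shape as input.dot(weight)
    grad_input = grad_output.dot(weight.T)  # (4,5)(5,3) -> (4,3)
    grad_weight = input.T.dot(grad_output)  # (3,4)(4,5) -> (3,5)
    assert grad_input.shape == input.shape and grad_weight.shape == weight.shape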
@@ -118,12 +234,12 @@ register('matmul', Dot)

  class ReLU(Function):
    @staticmethod
-   def forward(ctx, input):
+   def forward(ctx: Any, input: np.ndarray) -> np.ndarray:
      ctx.save_for_backward(input)
      return np.maximum(input, 0) # relu(x) = max(0,x)

    @staticmethod
-   def backward(ctx, grad_output):
+   def backward(ctx: Any, grad_output: np.ndarray) -> np.ndarray:
      input, = ctx.saved_tensors
      grad_input = grad_output * (input >= 0)
      return grad_input
@@ -131,49 +247,61 @@ register("relu", ReLU)

  class Sigmoid(Function):
    @staticmethod
-   def forward(ctx, input):
+   def forward(ctx: Any, input: np.ndarray) -> np.ndarray:
      ctx.save_for_backward(input)
      ret = 1/(1 + np.exp(-input)) # sigmoid(x) = 1 / (1 + exp(-x))
      return ret

    @staticmethod
-   def backward(ctx, grad_output):
+   def backward(ctx: Any, grad_output: np.ndarray) -> np.ndarray:
      ret, = ctx.saved_tensors
      grad_input = grad_output * (ret * (1 - ret)) # just take the derivative of sigmoid
      return grad_input
  register("sigmoid", Sigmoid)

- # class Dropout(Function):
- #   """
- #   Randomly zeroes some of the elements of the input tensor with probability p during training.
- #   The elements to zero are randomized on every forward call.
- #   During inference, dropout is disabled and the input is scaled by (1-p) to maintain the expected value.
- #   """
- #   @staticmethod
- #   def forward(ctx, input, p=0.5, training=True):
- #     if training:
- #       # Create a binary mask with probability (1-p) of being 1
- #       mask = (np.random.random(input.shape) > p).astype(np.float32)
- #       ctx.save_for_backward(mask)
- #       return input * mask
- #     else:
- #       # during inference, scale the input by (1-p)
- #       return input * (1-p)
-
- #   @staticmethod
- #   def backward(ctx, grad_output):
- #     mask, = ctx.saved_tensors
- #     return grad_output * mask
- # register("dropout", Dropout)
+ class DropoutLayer:
+   """
+   Dropout layer that randomly sets a fraction of the input units to 0 during training.
+   pytorch version: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
+   """
+   def __init__(self, p: float = 0.5) -> None:
+     self.p = p
+     self.training = True
+
+   def __call__(self, x):
+     # build a CPU-side random mask of the same shape as the tensor x
+     mask_np = (np.random.rand(*x.shape) >= self.p).astype(np.float32) / (1.0 - self.p)
+     from froog.tensor import Tensor
+     mask_t = Tensor(mask_np)
+     if getattr(x, "is_gpu", False): mask_t = mask_t.to_gpu()
+     return x.mul(mask_t)
+
+ class Dropout(Function):
+   @staticmethod
+   def forward(ctx: Any, input: np.ndarray, p: float = 0.5, training: bool = True) -> np.ndarray:
+     if not training: return input
+     # create a binary mask with probability (1-p) of being 1,
+     # scaled by 1/(1-p) to keep the expected value the same
+     ctx.training = training
+     mask = (np.random.rand(*input.shape) >= p).astype(np.float32) / (1.0 - p if p < 1.0 else 1e-9) # avoid division by zero if p is 1.0
+     ctx.save_for_backward(mask)
+     return input * mask
+
+   @staticmethod
+   def backward(ctx: Any, grad_output: np.ndarray) -> np.ndarray:
+     if not ctx.training: return grad_output
+     mask, = ctx.saved_tensors
+     return grad_output * mask
+ register("dropout", Dropout)

  class Reshape(Function):
    @staticmethod
-   def forward(ctx, x, shape):
+   def forward(ctx: Any, x: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray:
      ctx.save_for_backward(x.shape)
      return x.reshape(shape)

    @staticmethod
-   def backward(ctx, grad_output):
+   def backward(ctx: Any, grad_output: np.ndarray) -> np.ndarray:
      in_shape, = ctx.saved_tensors
      return grad_output.reshape(in_shape)
  register('reshape', Reshape)
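
Note: both new dropout paths use "inverted" scaling, dividing the mask by (1-p) at training time so the expected activation is unchanged and inference needs no rescaling. A standalone NumPy check of that invariant:

    import numpy as np

    np.random.seed(0)
    x, p = np.ones((1000, 1000)), 0.5
    mask = (np.random.rand(*x.shape) >= p).astype(np.float32) / (1.0 - p)
    print((x * mask).mean())  # ~1.0: E[x] is preserved even though half the units are zeroed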
@@ -183,11 +311,13 @@ class Pad2D(Function):
    The first element (0,0) corresponds to padding along the batch dimension, which indicates no padding on both sides (0 elements added).
    """
    @staticmethod
-   def forward(ctx, x, padding=None):
+   def forward(ctx: Any, x: np.ndarray, padding: Optional[Tuple[int, int, int, int]] = None) -> np.ndarray:
+     if padding is None:
+       padding = (0, 0, 0, 0)
      return np.pad(x, ((0,0), (0,0), (padding[0], padding[1]), (padding[2], padding[3]))) # (top, bottom, left, right)

    @staticmethod
-   def backward(ctx, grad_output):
+   def backward(ctx: Any, grad_output: np.ndarray) -> np.ndarray:
      raise Exception("write this")
  register('pad2d', Pad2D)
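Note: the padding tuple is ordered (top, bottom, left, right) and applies only to the last two axes; a standalone sketch of the resulting shape, with arbitrary sizes:

    import numpy as np

    x = np.zeros((1, 1, 4, 4))
    padding = (1, 1, 2, 2)  # top, bottom, left, right
    out = np.pad(x, ((0,0), (0,0), (padding[0], padding[1]), (padding[2], padding[3])))
    print(out.shape)  # (1, 1, 6, 8): height grows by top+bottom, width by left+right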
@@ -197,8 +327,8 @@ class LogSoftmax(Function):
    probabilities of each value are proportional to the scale of each value
    """
    @staticmethod
-   def forward(ctx, input):
-     def logsumexp(x):
+   def forward(ctx: Any, input: np.ndarray) -> np.ndarray:
+     def logsumexp(x: np.ndarray) -> np.ndarray:
        c = x.max(axis=1)
        return c + np.log(np.exp(x - c.reshape((-1, 1))).sum(axis=1)) # axis=1 refers to the columns
@@ -207,7 +337,7 @@ class LogSoftmax(Function):
      return output

    @staticmethod
-   def backward(ctx, grad_output):
+   def backward(ctx: Any, grad_output: np.ndarray) -> np.ndarray:
      (output,) = ctx.saved_tensors
      return grad_output - np.exp(output)*(grad_output.sum(axis=1).reshape((-1, 1)))
  register("logsoftmax", LogSoftmax)
@@ -224,62 +354,123 @@ register("logsoftmax", LogSoftmax)

  class Conv2D(Function): # TODO: understand group splits
    @staticmethod
-   def forward(ctx, x, w, stride=1, groups=1):
+   def forward(ctx: Any, x: np.ndarray, w: np.ndarray, stride: Union[int, Tuple[int, int]] = 1, groups: int = 1) -> np.ndarray:
      """
      https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
      WARNING: doesn't handle padding or strides yet
      Args:
-       x.shape[0] --> number of input examples (batch size)
-       cout --> number of output channels
-       x.shape[2]-(H-1) --> non-padded height of conv output, need to subtract because this is an unpadded conv
-       x.shape[3]-(W-1) --> width of output
+       x.shape[0] --> number of input examples (batch size)
+       cout --> number of output channels
+       x.shape[2]-(H-1) --> non-padded height of conv output, need to subtract because this is an unpadded conv
+       x.shape[3]-(W-1) --> width of output
      Shape:
        (a, b, c, d)(e, f, g, h) --> (a, e, c-(g-1), d-(h-1))
        in general, output x and y = [(W-K+2P)/S]+1
      """
-     if type(ctx.stride) == int: # ctx stores function params
+     ctx.stride = stride
+     ctx.groups = groups
+
+     if isinstance(ctx.stride, int): # ctx stores function params
        ctx.stride = (ctx.stride, ctx.stride)

-     cout, cin, H, W = w.shape
-
-     tw = w.reshape(cout, -1).T # slice of kernel
-     y_stride, x_stride = ctx.stride
-
-     bs,cin_,oy,ox = x.shape[0], x.shape[1], (x.shape[2]-(H-y_stride))//y_stride, (x.shape[3]-(W-x_stride))//x_stride
-     assert cin*ctx.groups == cin_ # ensures that the channel dimensions match appropriately for grouping
-     assert cout % ctx.groups == 0 # ensures that the number of output channels can be evenly divided among the groups
-     g_w_chans = cout//ctx.groups # number of output channels per group
-
-     ctx.save_for_backward(x, w)
-     ret = np.zeros((bs, cout, oy, ox), dtype=w.dtype)
-
-     for g in range(ctx.groups):
-       tw = w[g*g_w_chans:(g*g_w_chans+g_w_chans)].reshape(g_w_chans, -1).T # transformed kernel weights
-       for Y in range(oy):
-         for X in range(ox):
-           iY,iX = Y*y_stride, X*x_stride
-           tx = x[:, g*cin:(g*cin+cin), iY:iY+H, iX:iX+W].reshape(bs, -1)
-           ret[:, g*g_w_chans:(g*g_w_chans+g_w_chans), Y, X] += tx.dot(tw)
-     return ret
+     # Check if we're working with GPU buffers and convert to CPU
+     try:
+       from froog.tensor import is_buffer
+       from froog.gpu import download_tensor
+
+       # Convert input to CPU if it's a GPU buffer
+       if is_buffer(x):
+         x_cpu = download_tensor(x)
+       else:
+         x_cpu = x
+
+       # Convert weight to CPU if it's a GPU buffer
+       if is_buffer(w):
+         w_cpu = download_tensor(w)
+       else:
+         w_cpu = w
+
+       # Now use the CPU tensors for the rest of the computation
+       cout, cin, H, W = w_cpu.shape
+
+       tw = w_cpu.reshape(cout, -1).T # slice of kernel
+       y_stride, x_stride = ctx.stride
+
+       bs,cin_,oy,ox = x_cpu.shape[0], x_cpu.shape[1], (x_cpu.shape[2]-(H-y_stride))//y_stride, (x_cpu.shape[3]-(W-x_stride))//x_stride
+       assert cin*ctx.groups == cin_ # ensures that the channel dimensions match appropriately for grouping
+       assert cout % ctx.groups == 0 # ensures that the number of output channels can be evenly divided among the groups
+       g_w_chans = cout//ctx.groups # number of output channels per group
+
+       ctx.save_for_backward(x_cpu, w_cpu)
+       ret = np.zeros((bs, cout, oy, ox), dtype=w_cpu.dtype)
+
+       for g in range(ctx.groups):
+         tw = w_cpu[g*g_w_chans:(g*g_w_chans+g_w_chans)].reshape(g_w_chans, -1).T # transformed kernel weights
+         for Y in range(oy):
+           for X in range(ox):
+             iY,iX = Y*y_stride, X*x_stride
+             tx = x_cpu[:, g*cin:(g*cin+cin), iY:iY+H, iX:iX+W].reshape(bs, -1)
+             ret[:, g*g_w_chans:(g*g_w_chans+g_w_chans), Y, X] += tx.dot(tw)
+       return ret
+     except Exception as e:
+       import traceback
+       print(f"Error in conv2d operation: {str(e)}")
+       print(f"  Self: {x}")
+       print(f"  Arg 0: {w}")
+       print(f"  Kwargs: {{stride: {stride}, groups: {groups}}}")
+       traceback.print_exc()
+       raise

    @staticmethod
-   def backward(ctx, grad_output):
-     x, w = ctx.saved_tensors
-     cout, cin, H, W = w.shape
-     dx, dw = np.zeros_like(x), np.zeros_like(w)
-     y_stride, x_stride = ctx.stride
-     g_w_chans = cout//ctx.groups
-
-     for g in range(ctx.groups):
-       tw = w[g*g_w_chans:(g*g_w_chans+g_w_chans)].reshape(g_w_chans, -1)
-       for Y in range(grad_output.shape[2]):
-         for X in range(grad_output.shape[3]):
-           iY,iX = Y*y_stride, X*x_stride
-           gg = grad_output[:, g*g_w_chans:(g*g_w_chans+g_w_chans), Y, X] # current multiply element in chain rule
-           tx = x[:, g*cin:(g*cin+cin), iY:iY+H, iX:iX+W].reshape(x.shape[0], -1) # slice of tensor at current conv op
-           dw[g*g_w_chans:(g*g_w_chans+g_w_chans)] += gg.T.dot(tx).reshape((g_w_chans,cin,H,W)) # gradient with respect to input
-           dx[:, g*cin:(g*cin+cin), iY:iY+H, iX:iX+W] += gg.dot(tw).reshape(dx.shape[0], cin, H, W) # accumulate gradient with respect to weights
-     return dx, dw
+   def backward(ctx: Any, grad_output: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+     try:
+       from froog.tensor import is_buffer
+       from froog.gpu import download_tensor
+
+       # Convert grad_output to CPU if it's a GPU buffer
+       if is_buffer(grad_output):
+         grad_output_cpu = download_tensor(grad_output)
+       else:
+         grad_output_cpu = grad_output
+
+       x, w = ctx.saved_tensors
+       cout, cin, H, W = w.shape
+       dx, dw = np.zeros_like(x), np.zeros_like(w)
+       y_stride, x_stride = ctx.stride
+       g_w_chans = cout//ctx.groups
+
+       for g in range(ctx.groups):
+         tw = w[g*g_w_chans:(g*g_w_chans+g_w_chans)].reshape(g_w_chans, -1)
+         for Y in range(grad_output_cpu.shape[2]):
+           for X in range(grad_output_cpu.shape[3]):
+             iY,iX = Y*y_stride, X*x_stride
+             gg = grad_output_cpu[:, g*g_w_chans:(g*g_w_chans+g_w_chans), Y, X] # current multiply element in chain rule
+             tx = x[:, g*cin:(g*cin+cin), iY:iY+H, iX:iX+W].reshape(x.shape[0], -1) # slice of tensor at current conv op
+             dw[g*g_w_chans:(g*g_w_chans+g_w_chans)] += gg.T.dot(tx).reshape((g_w_chans,cin,H,W)) # gradient with respect to weights
+             dx[:, g*cin:(g*cin+cin), iY:iY+H, iX:iX+W] += gg.dot(tw).reshape(dx.shape[0], cin, H, W) # accumulate gradient with respect to input
+       return dx, dw
+     except Exception as e:
+       import traceback
+       print(f"Error in conv2d backward: {str(e)}")
+       print(f"  Grad Output: {grad_output}")
+       traceback.print_exc()
+       # Fallback to the original implementation
+       x, w = ctx.saved_tensors
+       cout, cin, H, W = w.shape
+       dx, dw = np.zeros_like(x), np.zeros_like(w)
+       y_stride, x_stride = ctx.stride
+       g_w_chans = cout//ctx.groups
+
+       for g in range(ctx.groups):
+         tw = w[g*g_w_chans:(g*g_w_chans+g_w_chans)].reshape(g_w_chans, -1)
+         for Y in range(grad_output.shape[2]):
+           for X in range(grad_output.shape[3]):
+             iY,iX = Y*y_stride, X*x_stride
+             gg = grad_output[:, g*g_w_chans:(g*g_w_chans+g_w_chans), Y, X] # current multiply element in chain rule
+             tx = x[:, g*cin:(g*cin+cin), iY:iY+H, iX:iX+W].reshape(x.shape[0], -1) # slice of tensor at current conv op
+             dw[g*g_w_chans:(g*g_w_chans+g_w_chans)] += gg.T.dot(tx).reshape((g_w_chans,cin,H,W)) # gradient with respect to weights
+             dx[:, g*cin:(g*cin+cin), iY:iY+H, iX:iX+W] += gg.dot(tw).reshape(dx.shape[0], cin, H, W) # accumulate gradient with respect to input
+       return dx, dw
  register('conv2d', Conv2D)
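Note: the docstring's rule, output = [(W-K+2P)/S]+1, agrees with the unpadded, stride-1 sizes the forward computes; a standalone arithmetic check with hypothetical sizes:

    bs, cin, H_in, W_in = 2, 3, 32, 32   # input batch
    cout, K, P, S = 8, 5, 0, 1           # 5x5 kernel, no padding, stride 1
    oy = (H_in - K + 2*P)//S + 1         # 28, same as x.shape[2]-(K-1)
    ox = (W_in - K + 2*P)//S + 1         # 28, same as x.shape[3]-(K-1)
    print((bs, cout, oy, ox))            # (2, 8, 28, 28)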
@@ -290,7 +481,7 @@ class im2ColConv(Function):
    """

    @staticmethod
-   def forward(ctx, x, w):
+   def forward(ctx: Any, x: np.ndarray, w: np.ndarray) -> np.ndarray:
      cout, cin, k_h, k_x = w.shape
      bs, oy, ox = x.shape[0], x.shape[2]-(k_h-1), x.shape[3]-(k_x-1)
      tw = w.reshape(cout, -1).T # each filter flattened into a row
@@ -300,7 +491,7 @@ class im2ColConv(Function):
      return np.moveaxis(ret, [0,1,2,3], [0,2,3,1]) # reorders the axes (batch size, number of channels, height, width)

    @staticmethod
-   def backward(ctx, grad_output):
+   def backward(ctx: Any, grad_output: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
      bs,_,oy,ox = grad_output.shape
      tx, w = ctx.saved_tensors # transformed input, filter weights
      cout,cin,H,W = w.shape
322
513
  #
323
514
  # **************** pooling ops ***************
324
515
 
325
- def stack_for_pool(x, pool_y, pool_x):
516
+ def stack_for_pool(x: np.ndarray, pool_y: int, pool_x: int) -> np.ndarray:
326
517
  my, mx = (x.shape[2]//pool_y)*pool_y, (x.shape[3]//pool_x)*pool_x # ensures input tensor can be evenly divided into 2x2 blocks for max pooling
327
518
  stack = []
328
519
  cropped_x = x[:, :, :my, :mx] # crop input so 2x2 max pool can be taken
@@ -332,47 +523,49 @@ def stack_for_pool(x, pool_y, pool_x):
332
523
  return np.concatenate(stack, axis=0) # put all into one row
333
524
 
334
525
 
335
- def unstack_for_pool(fxn, s, py, px):
526
+ def unstack_for_pool(fxn: Callable[[int], np.ndarray], s: Tuple[int, ...], py: int, px: int) -> np.ndarray:
336
527
  max_y, max_x = (s[2]//py)*py, (s[3]//px)*px # get shape that allows (pool_size_y,pool_size_x) max pool
528
+ ret = None
337
529
  for Y in range(py):
338
530
  for X in range(px):
339
531
  level_w_new_grad = fxn(Y*px+X)
340
532
  if X == 0 and Y == 0: # pool of zero size
341
533
  ret = np.zeros(s, dtype=level_w_new_grad.dtype)
342
- ret[:, :, Y:max_y:py, X:max_x:px] = level_w_new_grad
343
- return ret
534
+ if ret is not None:
535
+ ret[:, :, Y:max_y:py, X:max_x:px] = level_w_new_grad
536
+ return ret if ret is not None else np.zeros(s)
344
537
 
345
538
 
346
539
  class MaxPool2D(Function):
347
540
  @staticmethod
348
- def forward(ctx, x, kernel_size=(2,2)):
541
+ def forward(ctx: Any, x: np.ndarray, kernel_size: Tuple[int, int] = (2,2)) -> np.ndarray:
542
+ ctx.kernel_size = kernel_size
349
543
  stack = stack_for_pool(x, *kernel_size)
350
544
  idx_of_max = np.argmax(stack, axis=0)
351
545
  ctx.save_for_backward(idx_of_max, x.shape)
352
546
  return np.max(stack, axis=0)
353
547
 
354
548
  @staticmethod
355
- def backward(ctx, grad_output):
549
+ def backward(ctx: Any, grad_output: np.ndarray) -> np.ndarray:
356
550
  """
357
551
  Distributes the gradient from the output of the max pooling layer to its inputs
358
552
  The purpose of (idxs == idx) is to generate a boolean mask indicating the locations of the maximum values in each 2x2 block of the original input
359
553
  The expression (Y*2+X) is a way to iterate through the four possible positions within the kernel block: e.g. (0,0), (0,1), (1,0), and (1,1), which get mapped to the indices 0, 1, 2, and 3
360
554
  """
361
- idxs, s = ctx.saved_tensors
362
- return unstack_for_pool(lambda idx: grad_output * (idxs == idx),
363
- s,
364
- *ctx.kernel_size)
555
+ idxs, s = ctx.saved_tensors
556
+ return unstack_for_pool(lambda idx: grad_output * (idxs == idx), s, *ctx.kernel_size)
365
557
  register('max_pool2d', MaxPool2D)
366
558
 
367
559
  class AvgPool2D(Function):
368
560
  @staticmethod
369
- def forward(ctx, x, kernel_size=(2,2)):
561
+ def forward(ctx: Any, x: np.ndarray, kernel_size: Tuple[int, int] = (2,2)) -> np.ndarray:
562
+ ctx.kernel_size = kernel_size
370
563
  stack = stack_for_pool(x, *kernel_size)
371
564
  ctx.save_for_backward(x.shape)
372
565
  return np.mean(stack, axis=0)
373
566
 
374
567
  @staticmethod
375
- def backward(ctx, grad_output):
568
+ def backward(ctx: Any, grad_output: np.ndarray) -> np.ndarray:
376
569
  s, = ctx.saved_tensors
377
570
  py, px = ctx.kernel_size # kernel_size passed from forward context
378
571
  my, mx = (s[2]//py)*py, (s[3]//px)*px
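
Note: stack_for_pool builds one strided view per kernel position so the pool reduces over axis 0; in miniature, a 2x2 max pool over a single 4x4 map is the elementwise max of four such views (standalone sketch):

    import numpy as np

    x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
    views = [x[:, :, Y::2, X::2] for Y in range(2) for X in range(2)]  # the four kernel positions
    print(np.max(np.stack(views), axis=0)[0, 0])  # [[ 5.  7.] [13. 15.]]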
@@ -392,12 +585,12 @@ register('avg_pool2d', AvgPool2D)
  #
  # ************* nn ops ************

- def Linear(*x):
+ def Linear(*x: int) -> np.ndarray:
    # random Glorot initialization
    ret = np.random.uniform(-1., 1., size=x)/np.sqrt(np.prod(x))
    return ret.astype(np.float32)

- def swish(x):
+ def swish(x: Tensor) -> Tensor:
    return x.mul(x.sigmoid())

  class BatchNorm2D:
@@ -416,7 +609,7 @@ class BatchNorm2D:
    self.running_mean has shape [num_channels].
    self.running_mean.reshape(shape=[1, -1, 1, 1]) reshapes it to [1, num_channels, 1, 1]
    """
-   def __init__(self, sz, eps=0.001):
+   def __init__(self, sz: int, eps: float = 0.001) -> None:
      self.eps = eps
      self.weight = Tensor.zeros(sz)
      self.bias = Tensor.zeros(sz)
@@ -426,7 +619,7 @@ class BatchNorm2D:
      self.running_var = Tensor.zeros(sz)
      self.num_batches_tracked = Tensor.zeros(1)

-   def __call__(self, x):
+   def __call__(self, x: Tensor) -> Tensor:
      x = x.sub(self.running_mean.reshape(shape=[1, -1, 1, 1]))
      x = x.mul(self.weight.reshape(shape=[1, -1, 1, 1]))
      x = x.div(self.running_var.add(Tensor([self.eps], gpu=x.gpu)).reshape(shape=[1, -1, 1, 1]).sqrt())
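
Note: the visible part of __call__ computes (x - running_mean) * weight / sqrt(running_var + eps), i.e. inference-time batch norm from running statistics; a standalone NumPy sketch of the same arithmetic (it also adds a bias term, assuming the method continues past where this diff view truncates):

    import numpy as np

    def batchnorm2d_infer(x, mean, var, weight, bias, eps=0.001):
      shape = (1, -1, 1, 1)  # broadcast per-channel stats over (N, C, H, W)
      return (x - mean.reshape(shape)) * weight.reshape(shape) / np.sqrt(var.reshape(shape) + eps) + bias.reshape(shape)

    x = np.random.randn(2, 3, 4, 4).astype(np.float32)
    out = batchnorm2d_infer(x, x.mean((0, 2, 3)), x.var((0, 2, 3)), np.ones(3), np.zeros(3))
    print(out.mean((0, 2, 3)))  # ~0 per channel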