torch-dad 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_dad-0.1.0/PKG-INFO +14 -0
- torch_dad-0.1.0/README.md +0 -0
- torch_dad-0.1.0/pyproject.toml +28 -0
- torch_dad-0.1.0/setup.cfg +4 -0
- torch_dad-0.1.0/torch_dad/__init__.py +6 -0
- torch_dad-0.1.0/torch_dad/layers.py +64 -0
- torch_dad-0.1.0/torch_dad/models.py +33 -0
- torch_dad-0.1.0/torch_dad/trainers.py +120 -0
- torch_dad-0.1.0/torch_dad.egg-info/PKG-INFO +14 -0
- torch_dad-0.1.0/torch_dad.egg-info/SOURCES.txt +11 -0
- torch_dad-0.1.0/torch_dad.egg-info/dependency_links.txt +1 -0
- torch_dad-0.1.0/torch_dad.egg-info/requires.txt +2 -0
- torch_dad-0.1.0/torch_dad.egg-info/top_level.txt +1 -0
torch_dad-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torch-dad
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A highly accelerated, backprop-free Decoupled Analytical Dense (DAD) target propagation training engine on top of PyTorch.
|
|
5
|
+
Author: Mukundan Ramaswamy
|
|
6
|
+
License: MIT
|
|
7
|
+
Classifier: Programming Language :: Python :: 3
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Operating System :: OS Independent
|
|
10
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
11
|
+
Requires-Python: >=3.8
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
Requires-Dist: torch>=2.0.0
|
|
14
|
+
Requires-Dist: torchvision
|
|
File without changes
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61.0.0", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "torch-dad"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "A highly accelerated, backprop-free Decoupled Analytical Dense (DAD) target propagation training engine on top of PyTorch."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.8"
|
|
11
|
+
license = {text = "MIT"}
|
|
12
|
+
authors = [
|
|
13
|
+
{name = "Mukundan Ramaswamy"}
|
|
14
|
+
]
|
|
15
|
+
classifiers = [
|
|
16
|
+
"Programming Language :: Python :: 3",
|
|
17
|
+
"License :: OSI Approved :: MIT License",
|
|
18
|
+
"Operating System :: OS Independent",
|
|
19
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence"
|
|
20
|
+
]
|
|
21
|
+
dependencies = [
|
|
22
|
+
"torch>=2.0.0",
|
|
23
|
+
"torchvision"
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
[tool.setuptools.packages.find]
|
|
27
|
+
where = ["."]
|
|
28
|
+
include = ["torch_dad*"]
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
def adamw_step_fn(p, g, m, v, lr, t, weight_decay=0.01, betas=(0.9, 0.999), eps=1e-8):
    """Apply one in-place AdamW update to ``p``.

    All tensor arguments are mutated in place; nothing is returned.

    Args:
        p: Parameter tensor to update.
        g: Gradient for ``p``.
        m: First-moment (running mean of gradients) state for ``p``.
        v: Second-moment (running mean of squared gradients) state for ``p``.
        lr: Learning rate.
        t: 1-based step count used for bias correction. May be a Python
            number or a 0-dim tensor. ``t == 0`` makes the bias-correction
            terms zero and is not supported.
        weight_decay: Decoupled weight-decay coefficient.
        betas: ``(beta1, beta2)`` decay rates for the two moment estimates.
        eps: Term added to the denominator for numerical stability.
    """
    b1, b2 = betas

    # Update moment estimates in-place
    m.mul_(b1).add_(g, alpha=1.0 - b1)
    v.mul_(b2).addcmul_(g, g, value=1.0 - b2)

    bias_correction1 = 1.0 - b1 ** t
    bias_correction2 = 1.0 - b2 ** t

    step_size = lr / bias_correction1
    # ``** 0.5`` works for both Python numbers and tensors, whereas
    # ``torch.sqrt`` rejects a plain Python float.
    denom = (v.sqrt() / bias_correction2 ** 0.5).add_(eps)

    # In-place decoupled weight decay followed by the gradient descent step
    p.mul_(1.0 - lr * weight_decay)
    p.addcdiv_(m, denom, value=-step_size)
|
|
22
|
+
|
|
23
|
+
class DADLinear(nn.Module):
    """
    Decoupled Analytical Dense (DAD) target propagation layer.

    Computes ``relu(linear(x))`` and additionally carries a small local
    classifier head (``W_loc`` / ``b_loc``) plus Adam moment state for the
    main weight and bias. None of the tensors track autograd gradients; they
    are meant to be updated in place by an external training loop.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        num_classes: Output size of the local classifier head.
        device: Device on which all tensors are created (``None`` uses the
            PyTorch default).
    """

    def __init__(self, in_features, out_features, num_classes=10, device=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_classes = num_classes
        # Records the constructor argument only; it is not refreshed by ``.to()``.
        self.device = device

        # Main weight & bias (He-style scaled normal init, no autograd tracking)
        self.W = nn.Parameter(
            torch.randn(out_features, in_features, device=device) * (2.0 / in_features) ** 0.5,
            requires_grad=False,
        )
        self.bias = nn.Parameter(
            torch.zeros(out_features, device=device),
            requires_grad=False,
        )

        # Local task classifier head (updated manually, so no autograd tracking)
        self.W_loc = nn.Parameter(
            torch.randn(num_classes, out_features, device=device) * 0.02,
            requires_grad=False,
        )
        self.b_loc = nn.Parameter(
            torch.zeros(num_classes, device=device),
            requires_grad=False,
        )

        # Optimizer moments are training state rather than model weights, so
        # they are registered as buffers: they still follow ``.to()`` and are
        # saved in ``state_dict()`` under the same keys, but no longer show up
        # in ``parameters()``.
        self.register_buffer("m_W", torch.zeros(out_features, in_features, device=device))
        self.register_buffer("v_W", torch.zeros(out_features, in_features, device=device))
        self.register_buffer("m_bias", torch.zeros(out_features, device=device))
        self.register_buffer("v_bias", torch.zeros(out_features, device=device))

    def forward(self, x):
        """Return ``(relu(z), z)`` where ``z`` is the affine pre-activation of ``x``."""
        z = F.linear(x, self.W, self.bias)
        out = torch.relu(z)
        return out, z
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from .layers import DADLinear
|
|
4
|
+
|
|
5
|
+
class DADModel(nn.Module):
    """
    Base DAD Neural Network Container.

    Inherits from nn.Module and discovers any nested DADLinear layers by
    scanning its submodules. Subclasses are expected to register their
    DADLinear layers and a final output head named ``self.classifier``.
    """

    def __init__(self):
        super().__init__()

    @property
    def dad_layers(self):
        """Dynamically scans and returns all DADLinear layers in the model in registration order."""
        return [module for module in self.modules() if isinstance(module, DADLinear)]

    def forward_inference(self, x):
        """
        Runs standard inference through the model without target propagation overhead.

        The input is flattened to ``(batch, -1)``, passed through every
        DADLinear layer under ``torch.no_grad()``, and finally through
        ``self.classifier``.

        Raises:
            AttributeError: If the subclass did not define ``self.classifier``.
        """
        # Fail fast: check for the output head before doing any forward work.
        if not hasattr(self, 'classifier'):
            raise AttributeError("DADModel subclasses must define self.classifier as their final output head.")

        # reshape (rather than view) also accepts non-contiguous inputs,
        # e.g. permuted image batches.
        x = x.reshape(x.size(0), -1)
        with torch.no_grad():
            for layer in self.dad_layers:
                x, _ = layer(x)

        # The classifier call stays outside the no_grad block, as before.
        return self.classifier(x)
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from .layers import adamw_step_fn
|
|
5
|
+
|
|
6
|
+
class DADTrainer:
    """
    Decoupled Analytical Dense (DAD) Training Engine.

    Runs layer-local training steps over a model that exposes ``dad_layers``
    (optionally compiled with ``torch.compile`` on CUDA), accumulates the
    normal-equation matrices of the last hidden layer, and solves the final
    classifier head in closed form.

    Args:
        model: Model exposing ``dad_layers`` and, for ``solve_head``, a
            ``classifier`` module with ``weight`` and ``bias``.
        device: Device to train on; a ``torch.device`` or anything
            ``torch.device(...)`` accepts (for example ``"cuda"``).

    Raises:
        ValueError: If ``model.dad_layers`` is empty.
    """

    def __init__(self, model, device: torch.device):
        self.model = model
        # Normalize so that plain strings work too; the original crashed on
        # ``device.type`` when given "cpu"/"cuda".
        self.device = torch.device(device)

        # Verify model has DAD layers
        dad_layers = model.dad_layers
        if not dad_layers:
            raise ValueError("Provided model has no DADLinear layers registered!")

        self.num_classes = dad_layers[0].num_classes
        self.step_counter = 0
        self.t_tensor = torch.tensor(0.0, dtype=torch.float32, device=self.device)

        # JIT compile only on GPU to avoid compilation latency on CPU
        if self.device.type == 'cuda':
            self.compiled_step = torch.compile(self.unified_step)
        else:
            self.compiled_step = self.unified_step

        # Pre-allocate closed-form solving matrices; the extra row/column is
        # for the bias term of the augmented features.
        out_features = dad_layers[-1].W.size(0)
        self.HTH = torch.zeros(out_features + 1, out_features + 1, device=self.device)
        self.HTY = torch.zeros(out_features + 1, self.num_classes, device=self.device)

    def reset_step_counter(self):
        """Reset the step counters used for Adam bias correction.

        Only ``step_counter`` and ``t_tensor`` are zeroed; the per-layer
        moment tensors and the accumulated ``HTH`` / ``HTY`` matrices are
        left untouched.
        """
        self.step_counter = 0
        self.t_tensor.zero_()

    def unified_step(self, x, y, lr, t):
        """Run one forward pass plus layer-local updates for a single batch.

        Args:
            x: Input batch; flattened to ``(batch, -1)``.
            y: Integer class labels, as accepted by ``F.one_hot``.
            lr: Learning rate for the AdamW update of the main weights.
            t: Step count forwarded to ``adamw_step_fn`` for bias correction.

        Returns:
            The activation of the last DAD layer, computed before any of
            this step's weight updates were applied.
        """
        # The property walks the module tree, so resolve it once per step.
        layers = self.model.dad_layers

        # 1. Forward pass (reshape also accepts non-contiguous inputs)
        acts = [x.reshape(x.size(0), -1)]
        zs = []
        for layer in layers:
            out, z = layer(acts[-1])
            acts.append(out)
            zs.append(z)

        # 2. Decoupled backward pass: every layer is trained against its own
        #    local softmax cross-entropy head; nothing flows between layers.
        y_true = F.one_hot(y, num_classes=self.num_classes).float()
        inv_batch_size = 1.0 / x.size(0)

        for layer, x_prev, out, z in zip(layers, acts, acts[1:], zs):
            # Local alignment projection
            y_pred = F.linear(out, layer.W_loc, layer.b_loc)
            probs = F.softmax(y_pred, dim=-1)
            d_pred = (probs - y_true) * inv_batch_size

            # Analytical local gradients
            g_W_loc = d_pred.t() @ out
            g_b_loc = d_pred.sum(0)
            d_h = d_pred @ layer.W_loc
            d_z = d_h * (z > 0).float()  # ReLU derivative
            g_W = d_z.t() @ x_prev
            g_bias = d_z.sum(0)

            # In-place updates: AdamW for the main weights, plain SGD with a
            # fixed 1e-3 step for the local classifier head.
            adamw_step_fn(layer.W, g_W, layer.m_W, layer.v_W, lr, t)
            adamw_step_fn(layer.bias, g_bias, layer.m_bias, layer.v_bias, lr, t)
            layer.W_loc.add_(g_W_loc, alpha=-1e-3)
            layer.b_loc.add_(g_b_loc, alpha=-1e-3)

        return acts[-1]

    @torch.no_grad()
    def train_epoch(self, loader, amp_dtype, accumulate_solver=False, lr=1e-3):
        """Train the model for one pass over ``loader``.

        Args:
            loader: Iterable yielding ``(x, y)`` batches; each batch is moved
                to the trainer's device before use.
            amp_dtype: Autocast dtype; autocast is only enabled on CUDA.
            accumulate_solver: When true, also accumulate ``HTH`` / ``HTY``
                from the last-layer activations for ``solve_head``.
            lr: Learning rate passed to each training step.
        """
        self.model.train()
        use_amp = self.device.type == 'cuda'
        for x, y in loader:
            # No-op when the loader already yields tensors on the device.
            x = x.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)

            self.step_counter += 1
            self.t_tensor.add_(1.0)

            with torch.amp.autocast(self.device.type, dtype=amp_dtype, enabled=use_amp):
                out_last = self.compiled_step(x, y, lr, self.t_tensor)

            # Eager solver accumulation, kept outside the compiled step
            if accumulate_solver:
                # Accumulate in float32 even if autocast produced a lower precision.
                feats = out_last.float()
                ones = torch.ones(feats.size(0), 1, device=self.device)
                h_aug = torch.cat([feats, ones], dim=1)
                y_onehot = F.one_hot(y, num_classes=self.num_classes).float()
                self.HTH.add_(h_aug.t() @ h_aug)
                self.HTY.add_(h_aug.t() @ y_onehot)

    @torch.no_grad()
    def solve_head(self, lambda_reg=1e-3):
        """Solve the final classifier in closed form and copy it into the model.

        Solves ``(HTH + lambda_reg * I) W_aug = HTY`` and writes the result
        into ``model.classifier.weight`` / ``.bias``; the last row of
        ``W_aug`` is the bias.

        Args:
            lambda_reg: Ridge regularization strength added to the diagonal.

        Returns:
            Host-side elapsed time of the solve in seconds.
        """
        t0 = time.perf_counter()
        out_features = self.model.dad_layers[-1].W.size(0)

        reg = lambda_reg * torch.eye(out_features + 1, device=self.device)
        system = self.HTH + reg
        try:
            W_aug = torch.linalg.solve(system, self.HTY)
        except RuntimeError:
            # Singular system: fall back to the pseudo-inverse.
            W_aug = torch.linalg.pinv(system) @ self.HTY

        W = W_aug[:-1, :].t()
        b = W_aug[-1, :]
        self.model.classifier.weight.copy_(W)
        self.model.classifier.bias.copy_(b)

        return time.perf_counter() - t0
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torch-dad
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A highly accelerated, backprop-free Decoupled Analytical Dense (DAD) target propagation training engine on top of PyTorch.
|
|
5
|
+
Author: Mukundan Ramaswamy
|
|
6
|
+
License: MIT
|
|
7
|
+
Classifier: Programming Language :: Python :: 3
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Operating System :: OS Independent
|
|
10
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
11
|
+
Requires-Python: >=3.8
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
Requires-Dist: torch>=2.0.0
|
|
14
|
+
Requires-Dist: torchvision
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
torch_dad/__init__.py
|
|
4
|
+
torch_dad/layers.py
|
|
5
|
+
torch_dad/models.py
|
|
6
|
+
torch_dad/trainers.py
|
|
7
|
+
torch_dad.egg-info/PKG-INFO
|
|
8
|
+
torch_dad.egg-info/SOURCES.txt
|
|
9
|
+
torch_dad.egg-info/dependency_links.txt
|
|
10
|
+
torch_dad.egg-info/requires.txt
|
|
11
|
+
torch_dad.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
torch_dad
|