torch-dad 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: torch-dad
3
+ Version: 0.1.0
4
+ Summary: A highly accelerated, backprop-free Decoupled Analytical Dense (DAD) target propagation training engine on top of PyTorch.
5
+ Author: Mukundan Ramaswamy
6
+ License: MIT
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
11
+ Requires-Python: >=3.8
12
+ Description-Content-Type: text/markdown
13
+ Requires-Dist: torch>=2.0.0
14
+ Requires-Dist: torchvision
File without changes
@@ -0,0 +1,28 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "torch-dad"
7
+ version = "0.1.0"
8
+ description = "A highly accelerated, backprop-free Decoupled Analytical Dense (DAD) target propagation training engine on top of PyTorch."
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ license = {text = "MIT"}
12
+ authors = [
13
+ {name = "Mukundan Ramaswamy"}
14
+ ]
15
+ classifiers = [
16
+ "Programming Language :: Python :: 3",
17
+ "License :: OSI Approved :: MIT License",
18
+ "Operating System :: OS Independent",
19
+ "Topic :: Scientific/Engineering :: Artificial Intelligence"
20
+ ]
21
+ dependencies = [
22
+ "torch>=2.0.0",
23
+ "torchvision"
24
+ ]
25
+
26
+ [tool.setuptools.packages.find]
27
+ where = ["."]
28
+ include = ["torch_dad*"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,6 @@
1
+ from .layers import DADLinear
2
+ from .models import DADModel
3
+ from .trainers import DADTrainer
4
+
5
# Public API re-exported at package level (names imported above).
__all__ = [
    "DADLinear",
    "DADModel",
    "DADTrainer",
]

# Package version string.
__version__ = "0.1.0"
@@ -0,0 +1,64 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
def adamw_step_fn(p, g, m, v, lr, t, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
    """Apply one in-place AdamW update to ``p`` using gradient ``g``.

    Args:
        p: Parameter tensor; updated in place.
        g: Gradient for ``p`` (same shape as ``p``).
        m: First-moment buffer (EMA of ``g``); updated in place.
        v: Second-moment buffer (EMA of ``g * g``); updated in place.
        lr: Learning rate.
        t: 1-based step count used for bias correction. May be a Python number
            or a 0-dim tensor.
        betas: EMA coefficients ``(beta1, beta2)`` for ``m`` and ``v``.
        eps: Term added to the denominator for numerical stability.
        weight_decay: Decoupled weight-decay coefficient (scaled by ``lr``).

    Returns:
        None. ``p``, ``m`` and ``v`` are modified in place.
    """
    b1, b2 = betas

    # Update moment estimates in place.
    m.mul_(b1).add_(g, alpha=1.0 - b1)
    v.mul_(b2).addcmul_(g, g, value=1.0 - b2)

    # `** 0.5` works whether `t` is a Python number or a tensor
    # (torch.sqrt rejects a plain float).
    bias_correction1 = 1.0 - b1 ** t
    bias_correction2_sqrt = (1.0 - b2 ** t) ** 0.5

    step_size = lr / bias_correction1
    denom = (v.sqrt() / bias_correction2_sqrt).add_(eps)

    # Decoupled weight decay first, then the Adam step (same order as torch.optim.AdamW).
    p.mul_(1.0 - lr * weight_decay)
    # Multiply instead of passing `step_size` as `value=`: a tensor step size
    # would otherwise have to be converted to a Python scalar.
    p.sub_(m / denom * step_size)
22
+
23
class DADLinear(nn.Module):
    """
    Decoupled Analytical Dense (DAD) layer with its own local classifier head.

    Computes ``z = F.linear(x, W, bias)`` and returns the tuple ``(relu(z), z)``.
    Because it returns a tuple, it is not a literal drop-in for ``nn.Linear``.
    All tensors have autograd disabled; they are meant to be updated manually.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        num_classes: Output size of the layer-local classifier head.
        device: Device on which the tensors are allocated.
    """

    def __init__(self, in_features, out_features, num_classes=10, device=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_classes = num_classes
        self.device = device

        # Main weight (normal init scaled by sqrt(2 / in_features)) and bias.
        # No autograd gradients are tracked.
        self.W = nn.Parameter(
            torch.randn(out_features, in_features, device=device) * (2.0 / in_features) ** 0.5,
            requires_grad=False,
        )
        self.bias = nn.Parameter(
            torch.zeros(out_features, device=device),
            requires_grad=False,
        )

        # Local task classifier head (updated manually, not by autograd).
        self.W_loc = nn.Parameter(
            torch.randn(num_classes, out_features, device=device) * 0.02,
            requires_grad=False,
        )
        self.b_loc = nn.Parameter(
            torch.zeros(num_classes, device=device),
            requires_grad=False,
        )

        # AdamW moment estimates are optimizer state, not model parameters.
        # As buffers they are excluded from parameters() but still saved in
        # state_dict (same keys) and moved/cast by .to().
        self.register_buffer("m_W", torch.zeros_like(self.W))
        self.register_buffer("v_W", torch.zeros_like(self.W))
        self.register_buffer("m_bias", torch.zeros_like(self.bias))
        self.register_buffer("v_bias", torch.zeros_like(self.bias))

    def forward(self, x):
        """Return ``(relu(z), z)`` with ``z = F.linear(x, W, bias)``."""
        z = F.linear(x, self.W, self.bias)
        return torch.relu(z), z

    def extra_repr(self):
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"num_classes={self.num_classes}"
        )
@@ -0,0 +1,33 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .layers import DADLinear
4
+
5
class DADModel(nn.Module):
    """
    Base container for networks built from ``DADLinear`` layers.

    ``DADLinear`` submodules are discovered via ``self.modules()``. Subclasses
    must also define ``self.classifier``, the output head applied after the
    last ``DADLinear`` layer in ``forward_inference``.
    """

    def __init__(self):
        super().__init__()

    @property
    def dad_layers(self):
        """Return all ``DADLinear`` submodules, in ``self.modules()`` order.

        ``forward_inference`` applies them sequentially in this order, so it
        must match the intended data flow.
        """
        return [module for module in self.modules() if isinstance(module, DADLinear)]

    def forward_inference(self, x):
        """
        Run inference through all DAD layers and the ``classifier`` head.

        Args:
            x: Input batch; flattened to ``(batch, -1)`` before the first layer.

        Returns:
            The output of ``self.classifier``.

        Raises:
            AttributeError: If the subclass does not define ``self.classifier``.
        """
        # Fail fast, before spending compute on the hidden layers.
        if not hasattr(self, 'classifier'):
            raise AttributeError("DADModel subclasses must define self.classifier as their final output head.")

        with torch.no_grad():
            # reshape (not view) so non-contiguous inputs, e.g. permuted images, work.
            x = x.reshape(x.size(0), -1)
            for layer in self.dad_layers:
                x, _ = layer(x)
            return self.classifier(x)
@@ -0,0 +1,120 @@
1
+ import time
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from .layers import adamw_step_fn
5
+
6
class DADTrainer:
    """
    Decoupled Analytical Dense (DAD) training engine.

    Each ``DADLinear`` layer is updated only from its own layer-local softmax
    classifier head (no end-to-end backprop). The last layer's activations can
    optionally be accumulated into normal-equation statistics (``HTH``/``HTY``),
    from which ``solve_head`` fits ``model.classifier`` by ridge regression.

    Args:
        model: Module exposing ``dad_layers`` (and, for ``solve_head``, a
            ``classifier`` with ``weight``/``bias``).
        device: Device (``torch.device`` or string) the model lives on;
            batches are moved here.

    Raises:
        ValueError: If ``model`` has no DAD layers.
    """

    def __init__(self, model, device: torch.device):
        self.model = model
        # Accept strings such as "cuda" as well as torch.device objects.
        self.device = torch.device(device)

        dad_layers = model.dad_layers
        if not dad_layers:
            raise ValueError("Provided model has no DADLinear layers registered!")

        self.num_classes = dad_layers[0].num_classes
        self.step_counter = 0
        # Step count passed to adamw_step_fn, kept as a 0-dim device tensor.
        self.t_tensor = torch.tensor(0.0, dtype=torch.float32, device=self.device)

        # Compile the step only on CUDA; elsewhere the eager method is used.
        if self.device.type == 'cuda':
            self.compiled_step = torch.compile(self.unified_step)
        else:
            self.compiled_step = self.unified_step

        # Closed-form solver accumulators, sized from the last DAD layer's
        # output width (+1 for the bias column).
        out_features = dad_layers[-1].W.size(0)
        self.HTH = torch.zeros(out_features + 1, out_features + 1, device=self.device)
        self.HTY = torch.zeros(out_features + 1, self.num_classes, device=self.device)

    def reset_step_counter(self):
        """Reset the step counters used for AdamW bias correction.

        This does not zero the AdamW moment buffers stored on the layers.
        """
        self.step_counter = 0
        self.t_tensor.zero_()

    def reset_solver_stats(self):
        """Zero the ``HTH``/``HTY`` statistics accumulated for ``solve_head``."""
        self.HTH.zero_()
        self.HTY.zero_()

    def unified_step(self, x, y, lr, t, local_lr=1e-3):
        """Run one forward pass and one local (per-layer) update step.

        Args:
            x: Input batch; flattened to ``(batch, -1)``.
            y: Integer class labels for ``F.one_hot``.
            lr: AdamW learning rate for each layer's ``W`` and ``bias``.
            t: Step count forwarded to ``adamw_step_fn``.
            local_lr: SGD step size for each layer's local head.

        Returns:
            The last DAD layer's activations, computed before this step's updates.
        """
        layers = self.model.dad_layers

        # 1. Forward pass with the current (pre-update) weights.
        acts = [x.reshape(x.size(0), -1)]
        zs = []
        for layer in layers:
            out, z = layer(acts[-1])
            acts.append(out)
            zs.append(z)

        # 2. Decoupled backward pass: each layer sees only its own local head's error.
        y_true = F.one_hot(y, num_classes=self.num_classes).float()
        inv_batch_size = 1.0 / x.size(0)

        for layer, x_prev, out, z in zip(layers, acts[:-1], acts[1:], zs):
            # Gradient of batch-mean softmax cross-entropy w.r.t. the local logits.
            y_pred = F.linear(out, layer.W_loc, layer.b_loc)
            d_pred = (F.softmax(y_pred, dim=-1) - y_true) * inv_batch_size

            g_W_loc = d_pred.t() @ out
            g_b_loc = d_pred.sum(0)
            # Back through the (pre-update) local head and the ReLU mask only.
            d_z = (d_pred @ layer.W_loc) * (z > 0).float()
            g_W = d_z.t() @ x_prev
            g_bias = d_z.sum(0)

            # In-place updates: AdamW for the main weights, plain SGD for the local head.
            adamw_step_fn(layer.W, g_W, layer.m_W, layer.v_W, lr, t)
            adamw_step_fn(layer.bias, g_bias, layer.m_bias, layer.v_bias, lr, t)
            layer.W_loc.add_(g_W_loc, alpha=-local_lr)
            layer.b_loc.add_(g_b_loc, alpha=-local_lr)

        return acts[-1]

    @torch.no_grad()
    def train_epoch(self, loader, amp_dtype, accumulate_solver=False, lr=1e-3, local_lr=1e-3):
        """Train for one full pass over ``loader``.

        Args:
            loader: Iterable of ``(x, y)`` batches; each is moved to ``self.device``.
            amp_dtype: dtype for autocast (autocast is enabled only on CUDA).
            accumulate_solver: If True, add this epoch's last-layer statistics
                to ``HTH``/``HTY``. Statistics accumulate across calls; call
                ``reset_solver_stats`` to start fresh.
            lr: AdamW learning rate for the DAD layers' ``W``/``bias``.
            local_lr: SGD step size for the layer-local heads.
        """
        self.model.train()
        use_amp = self.device.type == 'cuda'
        for x, y in loader:
            x = x.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)
            self.step_counter += 1
            self.t_tensor.add_(1.0)

            with torch.amp.autocast(self.device.type, dtype=amp_dtype, enabled=use_amp):
                out_last = self.compiled_step(x, y, lr, self.t_tensor, local_lr=local_lr)

            # Solver accumulation runs eagerly, outside the compiled step, in float32.
            if accumulate_solver:
                h = out_last.float()
                ones = torch.ones(h.size(0), 1, device=self.device)
                h_aug = torch.cat([h, ones], dim=1)
                y_onehot = F.one_hot(y, num_classes=self.num_classes).float()
                self.HTH.add_(h_aug.t() @ h_aug)
                self.HTY.add_(h_aug.t() @ y_onehot)

    @torch.no_grad()
    def solve_head(self, lambda_reg=1e-3):
        """Fit ``model.classifier`` in closed form by ridge regression.

        Solves ``(HTH + lambda_reg * I) W_aug = HTY``, falling back to a
        pseudo-inverse if the direct solve raises. The first rows of ``W_aug``
        become the classifier weight (transposed); the last row becomes the bias.

        Args:
            lambda_reg: Ridge regularization strength.

        Returns:
            Elapsed wall-clock seconds.
        """
        t0 = time.perf_counter()

        n = self.HTH.size(0)
        A = self.HTH + lambda_reg * torch.eye(n, device=self.device, dtype=self.HTH.dtype)
        try:
            W_aug = torch.linalg.solve(A, self.HTY)
        except RuntimeError:
            # torch.linalg.LinAlgError (e.g. singular system) subclasses RuntimeError.
            W_aug = torch.linalg.pinv(A) @ self.HTY

        self.model.classifier.weight.copy_(W_aug[:-1, :].t())
        self.model.classifier.bias.copy_(W_aug[-1, :])

        if self.device.type == 'cuda':
            # CUDA kernels run asynchronously; wait so the timing covers the work.
            torch.cuda.synchronize(self.device)
        return time.perf_counter() - t0
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: torch-dad
3
+ Version: 0.1.0
4
+ Summary: A highly accelerated, backprop-free Decoupled Analytical Dense (DAD) target propagation training engine on top of PyTorch.
5
+ Author: Mukundan Ramaswamy
6
+ License: MIT
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
11
+ Requires-Python: >=3.8
12
+ Description-Content-Type: text/markdown
13
+ Requires-Dist: torch>=2.0.0
14
+ Requires-Dist: torchvision
@@ -0,0 +1,11 @@
1
+ README.md
2
+ pyproject.toml
3
+ torch_dad/__init__.py
4
+ torch_dad/layers.py
5
+ torch_dad/models.py
6
+ torch_dad/trainers.py
7
+ torch_dad.egg-info/PKG-INFO
8
+ torch_dad.egg-info/SOURCES.txt
9
+ torch_dad.egg-info/dependency_links.txt
10
+ torch_dad.egg-info/requires.txt
11
+ torch_dad.egg-info/top_level.txt
@@ -0,0 +1,2 @@
1
+ torch>=2.0.0
2
+ torchvision
@@ -0,0 +1 @@
1
+ torch_dad