titan-synapse 0.1.1 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,381 @@
1
+ //! Expert — Sparse Mixture of Experts Blocks
2
+ //!
3
+ //! Each expert is a specialized feed-forward network that processes tokens
4
+ //! routed to it by the Thalamus. Only top-k experts activate per token,
5
+ //! giving us sparse activation — most of the model is dormant at any time.
6
+ //!
7
+ //! Key insight: A 3B parameter model with 8 experts and top-2 routing
8
+ //! only uses ~800M parameters per token. You get the knowledge capacity
9
+ //! of 3B params with the speed of 800M.
10
+ //!
11
+ //! OBSERVABILITY: Each expert tracks its own activation statistics,
12
+ //! specialization score, and contribution magnitude. You can see
13
+ //! exactly which experts are doing the heavy lifting and which are coasting.
14
+
15
+ use anyhow::Result;
16
+ use candle_core::{Device, Tensor, DType, D};
17
+
18
/// Configuration for the expert pool.
///
/// `Default` yields a GPT-2-small-sized layout: `d_expert` is 4x `d_model`,
/// matching the conventional transformer FFN expansion factor.
#[derive(Debug, Clone)]
pub struct ExpertPoolConfig {
    /// Input/output dimension of each expert (model embedding width).
    pub d_model: usize,
    /// Expert hidden dimension (typically 4x d_model, like a transformer FFN).
    pub d_expert: usize,
    /// Number of experts in the pool.
    pub n_experts: usize,
    /// Device on which expert weights are allocated and forward passes run.
    pub device: Device,
}
30
+
31
+ impl Default for ExpertPoolConfig {
32
+ fn default() -> Self {
33
+ Self {
34
+ d_model: 768,
35
+ d_expert: 3072,
36
+ n_experts: 8,
37
+ device: Device::Cpu,
38
+ }
39
+ }
40
+ }
41
+
42
/// Introspection data for a single expert.
///
/// Produced by `Expert::stats`; the averages are per processed token,
/// accumulated since the last stats reset.
#[derive(Debug, Clone)]
pub struct ExpertStats {
    /// Display name (defaults to "expert_{i}").
    pub name: String,
    /// How many tokens this expert processed since the last reset.
    pub tokens_processed: usize,
    /// Average output magnitude. NOTE(review): the code computes the RMS of
    /// output elements per forward call (sqrt of mean square), not a true
    /// per-token L2 norm.
    pub avg_output_magnitude: f32,
    /// Average activation sparsity: fraction of hidden units near zero
    /// (|activation| < 0.01) after the SwiGLU gate.
    pub avg_activation_sparsity: f32,
    /// Specialization score: how different this expert's output is from the
    /// average. Set to 0.0 by `Expert::stats`; intended to be computed at
    /// pool level, which does not currently fill it in.
    pub specialization_score: f32,
}
56
+
57
/// Full introspection snapshot for the expert pool.
///
/// Rebuilt at the end of every `ExpertPool::forward` call from each
/// expert's cumulative running statistics.
#[derive(Debug, Clone)]
pub struct ExpertPoolIntrospection {
    /// Per-expert statistics, indexed in expert order.
    pub expert_stats: Vec<ExpertStats>,
    /// Expert indices paired with their average output magnitude,
    /// sorted descending by magnitude.
    pub top_contributors: Vec<(usize, f32)>,
    /// Tokens in the last forward pass (batch * seq_len).
    pub total_tokens: usize,
    /// Fraction of experts that processed zero tokens since the last stats
    /// reset: 1.0 - (active experts / total experts).
    pub activation_sparsity: f32,
}
69
+
70
/// A single expert — a gated feed-forward network with SwiGLU activation.
///
/// Computes `down(silu(gate(x)) * up(x))` and keeps running statistics
/// about its own activity for observability.
pub struct Expert {
    /// Up projection weights, shape (d_expert, d_model).
    w_up: Tensor,
    /// Gate projection weights, shape (d_expert, d_model); passed through
    /// SiLU before multiplying the up branch (SwiGLU).
    w_gate: Tensor,
    /// Down projection weights, shape (d_model, d_expert).
    w_down: Tensor,

    // Running stats, accumulated across forward calls until reset_stats().
    // Total tokens seen.
    tokens_processed: usize,
    // Sum over forward calls of (per-call RMS output magnitude * tokens).
    total_output_magnitude: f32,
    // Sum over forward calls of (per-call activation sparsity * tokens).
    total_activation_sparsity: f32,
}
84
+
85
+ impl Expert {
86
+ pub fn new(d_model: usize, d_expert: usize, device: &Device) -> Result<Self> {
87
+ let scale_in = (1.0 / d_model as f64).sqrt() as f32;
88
+ let scale_h = (1.0 / d_expert as f64).sqrt() as f32;
89
+
90
+ let w_up = Tensor::randn(0f32, scale_in, &[d_expert, d_model], device)?;
91
+ let w_gate = Tensor::randn(0f32, scale_in, &[d_expert, d_model], device)?;
92
+ let w_down = Tensor::randn(0f32, scale_h, &[d_model, d_expert], device)?;
93
+
94
+ Ok(Self {
95
+ w_up,
96
+ w_gate,
97
+ w_down,
98
+ tokens_processed: 0,
99
+ total_output_magnitude: 0.0,
100
+ total_activation_sparsity: 0.0,
101
+ })
102
+ }
103
+
104
+ /// Forward pass: SwiGLU FFN
105
+ /// Input: (tokens, d_model) — flat token batch
106
+ /// Output: (tokens, d_model)
107
+ pub fn forward(&mut self, x: &Tensor) -> Result<Tensor> {
108
+ // SwiGLU: down(silu(gate(x)) * up(x))
109
+ let gate = x.matmul(&self.w_gate.t()?)?;
110
+ let up = x.matmul(&self.w_up.t()?)?;
111
+ let gate_act = silu(&gate)?;
112
+ let hidden = (&gate_act * &up)?;
113
+
114
+ // Track activation sparsity (how many hidden units are near zero)
115
+ let sparsity = compute_sparsity(&hidden)?;
116
+ let n_tokens = x.dims()[0];
117
+ self.tokens_processed += n_tokens;
118
+ self.total_activation_sparsity += sparsity * n_tokens as f32;
119
+
120
+ let output = hidden.matmul(&self.w_down.t()?)?;
121
+
122
+ // Track output magnitude
123
+ let mag = output.sqr()?.mean_all()?.to_scalar::<f32>()?.sqrt();
124
+ self.total_output_magnitude += mag * n_tokens as f32;
125
+
126
+ Ok(output)
127
+ }
128
+
129
+ /// Get running statistics
130
+ pub fn stats(&self, name: &str) -> ExpertStats {
131
+ let n = self.tokens_processed.max(1) as f32;
132
+ ExpertStats {
133
+ name: name.to_string(),
134
+ tokens_processed: self.tokens_processed,
135
+ avg_output_magnitude: self.total_output_magnitude / n,
136
+ avg_activation_sparsity: self.total_activation_sparsity / n,
137
+ specialization_score: 0.0, // Computed at pool level
138
+ }
139
+ }
140
+
141
+ /// Reset statistics
142
+ pub fn reset_stats(&mut self) {
143
+ self.tokens_processed = 0;
144
+ self.total_output_magnitude = 0.0;
145
+ self.total_activation_sparsity = 0.0;
146
+ }
147
+ }
148
+
149
/// Pool of experts managed by the Thalamus router.
///
/// Routing decisions (which experts, what weights) are made upstream; this
/// type owns the experts, executes the routed forward pass, and records an
/// introspection snapshot after each pass.
pub struct ExpertPool {
    /// Pool configuration (dimensions, expert count, device).
    config: ExpertPoolConfig,
    /// The experts themselves, indexed by the ids in `expert_indices`.
    experts: Vec<Expert>,
    /// Display names, parallel to `experts` (see `set_expert_names`).
    expert_names: Vec<String>,
    /// Snapshot built by the most recent forward pass, if any.
    last_introspection: Option<ExpertPoolIntrospection>,
}
159
+
160
+ impl ExpertPool {
161
+ pub fn new(config: ExpertPoolConfig) -> Result<Self> {
162
+ let mut experts = Vec::with_capacity(config.n_experts);
163
+ let mut names = Vec::with_capacity(config.n_experts);
164
+
165
+ for i in 0..config.n_experts {
166
+ experts.push(Expert::new(config.d_model, config.d_expert, &config.device)?);
167
+ names.push(format!("expert_{i}"));
168
+ }
169
+
170
+ Ok(Self {
171
+ experts,
172
+ expert_names: names,
173
+ last_introspection: None,
174
+ config,
175
+ })
176
+ }
177
+
178
+ /// Set expert names
179
+ pub fn set_expert_names(&mut self, names: Vec<String>) {
180
+ self.expert_names = names;
181
+ }
182
+
183
+ /// Forward pass — route tokens to selected experts and combine
184
+ ///
185
+ /// x: (batch, seq_len, d_model)
186
+ /// routing_weights: (batch, seq_len, top_k)
187
+ /// expert_indices: [batch][seq][top_k] — which experts to use
188
+ ///
189
+ /// Output: (batch, seq_len, d_model)
190
+ pub fn forward(
191
+ &mut self,
192
+ x: &Tensor,
193
+ routing_weights: &Tensor,
194
+ expert_indices: &[Vec<Vec<usize>>],
195
+ ) -> Result<Tensor> {
196
+ let (batch, seq_len, d_model) = x.dims3()?;
197
+ let dev = &self.config.device;
198
+
199
+ let mut output = Tensor::zeros(&[batch, seq_len, d_model], DType::F32, dev)?;
200
+
201
+ // Process each expert's assigned tokens
202
+ // (In production, this would be batched more efficiently)
203
+ for b in 0..batch {
204
+ for t in 0..seq_len {
205
+ let x_t = x.narrow(0, b, 1)?.narrow(1, t, 1)?
206
+ .reshape(&[1, d_model])?; // (1, d_model)
207
+
208
+ let mut combined = Tensor::zeros(&[1, d_model], DType::F32, dev)?;
209
+
210
+ for (k, &expert_idx) in expert_indices[b][t].iter().enumerate() {
211
+ if expert_idx >= self.experts.len() {
212
+ continue;
213
+ }
214
+
215
+ // Get routing weight for this expert
216
+ let weight = routing_weights
217
+ .narrow(0, b, 1)?
218
+ .narrow(1, t, 1)?
219
+ .narrow(2, k, 1)?
220
+ .squeeze(0)?.squeeze(0)?.squeeze(0)?
221
+ .to_scalar::<f32>()?;
222
+
223
+ // Run through expert
224
+ let expert_out = self.experts[expert_idx].forward(&x_t)?;
225
+
226
+ // Weight and accumulate
227
+ let weight_t = Tensor::new(&[weight], dev)?;
228
+ let weighted = expert_out.broadcast_mul(&weight_t)?;
229
+ combined = (&combined + &weighted)?;
230
+ }
231
+
232
+ // Place combined output
233
+ let combined_3d = combined.unsqueeze(0)?; // (1, 1, d_model)
234
+ output = output.slice_assign(
235
+ &[b..b+1, t..t+1, 0..d_model],
236
+ &combined_3d,
237
+ )?;
238
+ }
239
+ }
240
+
241
+ // Build introspection
242
+ let expert_stats: Vec<ExpertStats> = self.experts.iter()
243
+ .enumerate()
244
+ .map(|(i, e)| e.stats(self.expert_names.get(i).map(|s| s.as_str()).unwrap_or("?")))
245
+ .collect();
246
+
247
+ let mut top_contributors: Vec<(usize, f32)> = expert_stats.iter()
248
+ .enumerate()
249
+ .map(|(i, s)| (i, s.avg_output_magnitude))
250
+ .collect();
251
+ top_contributors.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
252
+
253
+ let active_experts: usize = expert_stats.iter()
254
+ .filter(|s| s.tokens_processed > 0)
255
+ .count();
256
+ let activation_sparsity = 1.0 - (active_experts as f32 / self.config.n_experts as f32);
257
+
258
+ self.last_introspection = Some(ExpertPoolIntrospection {
259
+ expert_stats,
260
+ top_contributors,
261
+ total_tokens: batch * seq_len,
262
+ activation_sparsity,
263
+ });
264
+
265
+ Ok(output)
266
+ }
267
+
268
+ /// Get introspection data
269
+ pub fn introspect(&self) -> Option<&ExpertPoolIntrospection> {
270
+ self.last_introspection.as_ref()
271
+ }
272
+
273
+ /// Reset all expert statistics
274
+ pub fn reset_stats(&mut self) {
275
+ for expert in &mut self.experts {
276
+ expert.reset_stats();
277
+ }
278
+ }
279
+
280
+ /// Number of experts
281
+ pub fn n_experts(&self) -> usize {
282
+ self.experts.len()
283
+ }
284
+ }
285
+
286
+ /// Compute activation sparsity (fraction of values near zero)
287
+ fn compute_sparsity(x: &Tensor) -> Result<f32> {
288
+ let abs = x.abs()?;
289
+ let threshold = Tensor::new(&[0.01f32], x.device())?.broadcast_as(abs.shape())?;
290
+ let near_zero = abs.lt(&threshold)?;
291
+ let total = x.elem_count() as f32;
292
+ let sparse_count = near_zero.to_dtype(DType::F32)?.sum_all()?.to_scalar::<f32>()?;
293
+ Ok(sparse_count / total)
294
+ }
295
+
296
+ /// SiLU activation
297
+ fn silu(x: &Tensor) -> Result<Tensor> {
298
+ let sigmoid = candle_nn::ops::sigmoid(x)?;
299
+ x.mul(&sigmoid).map_err(|e| anyhow::anyhow!("{e}"))
300
+ }
301
+
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-expert pool config with the given dimensions, on CPU.
    fn tiny_pool_config(d_model: usize, d_expert: usize) -> ExpertPoolConfig {
        ExpertPoolConfig {
            d_model,
            d_expert,
            n_experts: 4,
            device: Device::Cpu,
        }
    }

    /// Shared routing fixture: 1 batch, 4 tokens, d_model=32, top-2 routing.
    fn routing_fixture() -> (Tensor, Tensor, Vec<Vec<Vec<usize>>>) {
        let x = Tensor::randn(0f32, 1.0, &[1, 4, 32], &Device::Cpu).unwrap();
        let weights = Tensor::new(
            &[0.6f32, 0.4, 0.5, 0.5, 0.7, 0.3, 0.55, 0.45],
            &Device::Cpu,
        )
        .unwrap()
        .reshape(&[1, 4, 2])
        .unwrap();
        let indices = vec![vec![
            vec![0, 1], vec![1, 2], vec![0, 3], vec![2, 3],
        ]];
        (x, weights, indices)
    }

    #[test]
    fn test_expert_creation() {
        assert!(Expert::new(64, 256, &Device::Cpu).is_ok());
    }

    #[test]
    fn test_expert_forward() {
        let mut expert = Expert::new(64, 256, &Device::Cpu).unwrap();
        let tokens = Tensor::randn(0f32, 1.0, &[4, 64], &Device::Cpu).unwrap();
        let result = expert.forward(&tokens).unwrap();
        assert_eq!(result.dims(), &[4, 64]);
    }

    #[test]
    fn test_expert_pool() {
        let pool = ExpertPool::new(tiny_pool_config(64, 256));
        assert!(pool.is_ok());
        assert_eq!(pool.unwrap().n_experts(), 4);
    }

    #[test]
    fn test_expert_pool_forward() {
        let mut pool = ExpertPool::new(tiny_pool_config(32, 128)).unwrap();
        let (x, weights, indices) = routing_fixture();
        let out = pool.forward(&x, &weights, &indices).unwrap();
        assert_eq!(out.dims(), &[1, 4, 32]);
    }

    #[test]
    fn test_expert_introspection() {
        let mut pool = ExpertPool::new(tiny_pool_config(32, 128)).unwrap();
        let (x, weights, indices) = routing_fixture();
        let _ = pool.forward(&x, &weights, &indices).unwrap();
        let snapshot = pool.introspect().unwrap();
        assert_eq!(snapshot.expert_stats.len(), 4);
        assert_eq!(snapshot.total_tokens, 4);
    }
}