@rlabs-inc/sparse 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,734 @@
1
+ // ============================================================================
2
+ // SPARSE - Metal Implementation
3
+ // Objective-C bridge between C API and Metal compute
4
+ // ============================================================================
5
+
6
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

#include "sparse.h"

#include <math.h>
#include <stdlib.h>
#include <string.h>
10
+
11
+ // ============================================================================
12
+ // INTERNAL STRUCTURES
13
+ // ============================================================================
14
+
15
// One sparse-compute context: a Metal device, its command queue, the
// library compiled from sparse.metal, and one pre-built pipeline per
// kernel.  Allocated with calloc() in sparse_init(); released in
// sparse_cleanup().
struct SparseContext {
    id<MTLDevice> device;              // system default Metal device
    id<MTLCommandQueue> commandQueue;  // queue every operation submits to
    id<MTLLibrary> library;            // compiled from sparse.metal source

    // Compiled kernel functions (one pipeline per kernel in sparse.metal).
    id<MTLComputePipelineState> scatterAddPipeline;
    id<MTLComputePipelineState> addArraysPipeline;
    id<MTLComputePipelineState> addScalarPipeline;
    id<MTLComputePipelineState> multiplyArraysPipeline;
    id<MTLComputePipelineState> multiplyScalarPipeline;
    id<MTLComputePipelineState> squarePipeline;
    id<MTLComputePipelineState> greaterEqualPipeline;
    id<MTLComputePipelineState> whereSelectPipeline;
    id<MTLComputePipelineState> whereScalarPipeline;
    id<MTLComputePipelineState> gatherPipeline;
    id<MTLComputePipelineState> gatherBoolPipeline;
    id<MTLComputePipelineState> fillFloatPipeline;
    // NOTE(review): fillZerosPipeline and sumReducePipeline are created by
    // sparse_init() but not dispatched by any function visible in this file
    // (sparse_zeros uses memset; sparse_sum reduces on the CPU).
    id<MTLComputePipelineState> fillZerosPipeline;
    id<MTLComputePipelineState> sumReducePipeline;

    char deviceName[256];  // NUL-terminated UTF-8 copy of device.name
};
38
+
39
// A GPU-resident array.  `buffer` uses MTLResourceStorageModeShared, so
// its contents are directly CPU-addressable after each (synchronous)
// operation completes.
struct SparseBuffer {
    id<MTLBuffer> buffer;   // element storage in shared CPU/GPU memory
    uint32_t count;         // number of elements (not bytes)
    SparseDataType dtype;   // SPARSE_FLOAT32 or SPARSE_UINT32
    SparseContextRef ctx;   // owning context; must outlive this buffer
};
45
+
46
+ // ============================================================================
47
+ // HELPER FUNCTIONS
48
+ // ============================================================================
49
+
50
// Compiles one compute pipeline for the named kernel in ctx->library.
// Returns nil (after logging) if the function is missing or pipeline
// creation fails.
static id<MTLComputePipelineState> createPipeline(SparseContextRef ctx, NSString* functionName) {
    NSError* error = nil;
    id<MTLFunction> function = [ctx->library newFunctionWithName:functionName];
    if (!function) {
        NSLog(@"Failed to find function: %@", functionName);
        return nil;
    }

    // Check the returned pipeline, not `error`: per Cocoa convention the
    // error out-parameter is only meaningful when the call actually fails.
    id<MTLComputePipelineState> pipeline =
        [ctx->device newComputePipelineStateWithFunction:function error:&error];
    if (!pipeline) {
        NSLog(@"Failed to create pipeline for %@: %@", functionName, error);
        return nil;
    }

    return pipeline;
}
66
+
67
// Dispatches `pipeline` over `count` threads and blocks until completion.
//
// NOTE(review): this helper creates its own encoder and binds NO buffers
// or arguments, so it can only drive kernels that take no bindings.  No
// function visible in this file calls it — every operation inlines its
// own dispatch instead.  Candidate for removal or for a redesign that
// accepts a binding callback; confirm against the rest of the project.
static void dispatchCompute(SparseContextRef ctx, id<MTLComputePipelineState> pipeline, uint32_t count) {
    id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
    id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

    [encoder setComputePipelineState:pipeline];

    // Cap the threadgroup width at 256 (and at the pipeline's own limit).
    NSUInteger threadGroupSize = pipeline.maxTotalThreadsPerThreadgroup;
    if (threadGroupSize > 256) threadGroupSize = 256;

    MTLSize gridSize = MTLSizeMake(count, 1, 1);
    MTLSize groupSize = MTLSizeMake(threadGroupSize, 1, 1);

    // dispatchThreads handles a grid that is not a multiple of the
    // threadgroup size (requires non-uniform threadgroup support).
    [encoder dispatchThreads:gridSize threadsPerThreadgroup:groupSize];
    [encoder endEncoding];

    // Synchronous: submit and wait, like every other op in this file.
    [commandBuffer commit];
    [commandBuffer waitUntilCompleted];
}
85
+
86
+ // ============================================================================
87
+ // CONTEXT MANAGEMENT
88
+ // ============================================================================
89
+
90
// Creates and fully initializes a Metal compute context, or returns NULL.
//
// The shader source is loaded from the main bundle's sparse.metal, then —
// as fallbacks — from ./native/sparse.metal and ./sparse.metal relative
// to the current working directory.  All kernels are compiled eagerly and
// verified, so every later dispatch may assume its pipeline is non-nil.
SparseContextRef sparse_init(void) {
    @autoreleasepool {
        SparseContextRef ctx = (SparseContextRef)calloc(1, sizeof(struct SparseContext));
        if (!ctx) return NULL;

        ctx->device = MTLCreateSystemDefaultDevice();
        if (!ctx->device) {
            NSLog(@"Metal is not supported on this device");
            sparse_cleanup(ctx);
            return NULL;
        }

        // Store device name.  calloc zeroed the array, so copying at most
        // sizeof-1 bytes always leaves it NUL-terminated.
        const char* name = [ctx->device.name UTF8String];
        if (name) {
            strncpy(ctx->deviceName, name, sizeof(ctx->deviceName) - 1);
        }

        ctx->commandQueue = [ctx->device newCommandQueue];
        if (!ctx->commandQueue) {
            NSLog(@"Failed to create Metal command queue");
            sparse_cleanup(ctx);
            return NULL;
        }

        // Locate the shader source (bundle first, then cwd fallbacks).
        NSString* shaderPath = [[NSBundle mainBundle] pathForResource:@"sparse" ofType:@"metal"];
        NSString* shaderSource = nil;

        if (shaderPath) {
            shaderSource = [NSString stringWithContentsOfFile:shaderPath encoding:NSUTF8StringEncoding error:nil];
        }

        if (!shaderSource) {
            NSString* currentDir = [[NSFileManager defaultManager] currentDirectoryPath];
            NSString* localPath = [currentDir stringByAppendingPathComponent:@"native/sparse.metal"];
            shaderSource = [NSString stringWithContentsOfFile:localPath encoding:NSUTF8StringEncoding error:nil];

            if (!shaderSource) {
                localPath = [currentDir stringByAppendingPathComponent:@"sparse.metal"];
                shaderSource = [NSString stringWithContentsOfFile:localPath encoding:NSUTF8StringEncoding error:nil];
            }
        }

        if (!shaderSource) {
            NSLog(@"Could not load sparse.metal shader file");
            sparse_cleanup(ctx);
            return NULL;
        }

        NSError* error = nil;
        MTLCompileOptions* options = [[MTLCompileOptions alloc] init];
        ctx->library = [ctx->device newLibraryWithSource:shaderSource options:options error:&error];

        // Check the result, not `error` — the error object is only
        // meaningful when the returned library is nil.
        if (!ctx->library) {
            NSLog(@"Failed to compile Metal shaders: %@", error);
            sparse_cleanup(ctx);
            return NULL;
        }

        // Create all compute pipelines.
        ctx->scatterAddPipeline = createPipeline(ctx, @"scatter_add");
        ctx->addArraysPipeline = createPipeline(ctx, @"add_arrays");
        ctx->addScalarPipeline = createPipeline(ctx, @"add_scalar");
        ctx->multiplyArraysPipeline = createPipeline(ctx, @"multiply_arrays");
        ctx->multiplyScalarPipeline = createPipeline(ctx, @"multiply_scalar");
        ctx->squarePipeline = createPipeline(ctx, @"square");
        ctx->greaterEqualPipeline = createPipeline(ctx, @"greater_equal");
        ctx->whereSelectPipeline = createPipeline(ctx, @"where_select");
        ctx->whereScalarPipeline = createPipeline(ctx, @"where_scalar");
        ctx->gatherPipeline = createPipeline(ctx, @"gather");
        ctx->gatherBoolPipeline = createPipeline(ctx, @"gather_bool");
        ctx->fillFloatPipeline = createPipeline(ctx, @"fill_float");
        ctx->fillZerosPipeline = createPipeline(ctx, @"fill_zeros");
        ctx->sumReducePipeline = createPipeline(ctx, @"sum_reduce");

        // Fail initialization if any kernel is missing: dispatching with a
        // nil pipeline state later would crash the encoder.
        if (!ctx->scatterAddPipeline || !ctx->addArraysPipeline ||
            !ctx->addScalarPipeline || !ctx->multiplyArraysPipeline ||
            !ctx->multiplyScalarPipeline || !ctx->squarePipeline ||
            !ctx->greaterEqualPipeline || !ctx->whereSelectPipeline ||
            !ctx->whereScalarPipeline || !ctx->gatherPipeline ||
            !ctx->gatherBoolPipeline || !ctx->fillFloatPipeline ||
            !ctx->fillZerosPipeline || !ctx->sumReducePipeline) {
            NSLog(@"One or more sparse kernels failed to compile");
            sparse_cleanup(ctx);
            return NULL;
        }

        return ctx;
    }
}
166
+
167
// Releases a context created by sparse_init().
//
// The struct is plain calloc'd memory, so free() alone does NOT release
// the __strong Objective-C fields inside it — ARC only destroys struct
// fields whose lifetime it can see, never memory handed to free().
// Explicitly nil each object field first so the Metal objects are
// released, then free the struct.  (Assumes this file is compiled with
// ARC, as the absence of retain/release elsewhere suggests — confirm.)
void sparse_cleanup(SparseContextRef ctx) {
    if (!ctx) return;

    ctx->device = nil;
    ctx->commandQueue = nil;
    ctx->library = nil;
    ctx->scatterAddPipeline = nil;
    ctx->addArraysPipeline = nil;
    ctx->addScalarPipeline = nil;
    ctx->multiplyArraysPipeline = nil;
    ctx->multiplyScalarPipeline = nil;
    ctx->squarePipeline = nil;
    ctx->greaterEqualPipeline = nil;
    ctx->whereSelectPipeline = nil;
    ctx->whereScalarPipeline = nil;
    ctx->gatherPipeline = nil;
    ctx->gatherBoolPipeline = nil;
    ctx->fillFloatPipeline = nil;
    ctx->fillZerosPipeline = nil;
    ctx->sumReducePipeline = nil;

    free(ctx);
}
172
+
173
// Blocks until all previously committed GPU work has drained, by pushing
// an empty command buffer through the queue and waiting on it.
void sparse_sync(SparseContextRef ctx) {
    if (!ctx) return;
    @autoreleasepool {
        id<MTLCommandBuffer> fence = [ctx->commandQueue commandBuffer];
        [fence commit];
        [fence waitUntilCompleted];
    }
}
180
+
181
// Human-readable name of the Metal device, or "Unknown" for a NULL context.
// The returned pointer stays valid for the lifetime of the context.
const char* sparse_device_name(SparseContextRef ctx) {
    if (!ctx) return "Unknown";
    return ctx->deviceName;
}
184
+
185
// Recommended maximum working-set size of the device in bytes (0 for NULL).
uint64_t sparse_device_memory(SparseContextRef ctx) {
    return ctx ? ctx->device.recommendedMaxWorkingSetSize : 0;
}
189
+
190
+ // ============================================================================
191
+ // BUFFER MANAGEMENT
192
+ // ============================================================================
193
+
194
// Allocates a zero-filled buffer of `count` elements, or NULL on failure.
// Element size is 4 bytes for both supported dtypes.
SparseBufferRef sparse_zeros(SparseContextRef ctx, uint32_t count, SparseDataType dtype) {
    if (!ctx || count == 0) return NULL;

    size_t elementSize = (dtype == SPARSE_FLOAT32) ? sizeof(float) : sizeof(uint32_t);
    size_t bufferSize = (size_t)count * elementSize;

    SparseBufferRef buf = (SparseBufferRef)calloc(1, sizeof(struct SparseBuffer));
    if (!buf) return NULL;

    buf->buffer = [ctx->device newBufferWithLength:bufferSize options:MTLResourceStorageModeShared];
    if (!buf->buffer) {
        // Metal allocation failed (e.g. out of memory) — don't hand back a
        // struct whose contents pointer would be nil.
        free(buf);
        return NULL;
    }
    buf->count = count;
    buf->dtype = dtype;
    buf->ctx = ctx;

    // Shared-mode contents are CPU-visible; zero them explicitly since a
    // fresh MTLBuffer's contents are not guaranteed to be zeroed.
    memset(buf->buffer.contents, 0, bufferSize);

    return buf;
}
211
+
212
// Allocates a FLOAT32 buffer of `count` elements, each set to `value`,
// using the fill_float kernel.  Returns NULL on allocation failure.
// Blocks until the GPU fill completes.
SparseBufferRef sparse_full(SparseContextRef ctx, uint32_t count, float value) {
    if (!ctx || count == 0) return NULL;

    SparseBufferRef buf = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!buf) return NULL;

    // Wrap Metal temporaries in a pool, consistent with every other op in
    // this file, so they don't accumulate in the caller's pool.
    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:ctx->fillFloatPipeline];
        [encoder setBuffer:buf->buffer offset:0 atIndex:0];
        [encoder setBytes:&value length:sizeof(float) atIndex:1];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:2];

        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return buf;
}
238
+
239
// Creates a FLOAT32 buffer initialized from `data` (copied at creation).
// Returns NULL on invalid arguments or allocation failure.
SparseBufferRef sparse_from_float(SparseContextRef ctx, const float* data, uint32_t count) {
    if (!ctx || !data || count == 0) return NULL;

    SparseBufferRef buf = (SparseBufferRef)calloc(1, sizeof(struct SparseBuffer));
    if (!buf) return NULL;

    buf->buffer = [ctx->device newBufferWithBytes:data
                                           length:(size_t)count * sizeof(float)
                                          options:MTLResourceStorageModeShared];
    if (!buf->buffer) {
        free(buf);
        return NULL;
    }
    buf->count = count;
    buf->dtype = SPARSE_FLOAT32;
    buf->ctx = ctx;

    return buf;
}
252
+
253
// Creates a UINT32 buffer initialized from `data` (copied at creation).
// Returns NULL on invalid arguments or allocation failure.
SparseBufferRef sparse_from_uint(SparseContextRef ctx, const uint32_t* data, uint32_t count) {
    if (!ctx || !data || count == 0) return NULL;

    SparseBufferRef buf = (SparseBufferRef)calloc(1, sizeof(struct SparseBuffer));
    if (!buf) return NULL;

    buf->buffer = [ctx->device newBufferWithBytes:data
                                           length:(size_t)count * sizeof(uint32_t)
                                          options:MTLResourceStorageModeShared];
    if (!buf->buffer) {
        free(buf);
        return NULL;
    }
    buf->count = count;
    buf->dtype = SPARSE_UINT32;
    buf->ctx = ctx;

    return buf;
}
266
+
267
// Copies up to `count` float elements out of the buffer into `out`.
// The copy is clamped to the buffer's own length.  No extra GPU sync is
// needed: every operation in this file waits for completion and buffers
// use shared storage.
void sparse_to_float(SparseBufferRef buf, float* out, uint32_t count) {
    if (!buf || !out) return;
    uint32_t n = buf->count;
    if (count < n) n = count;
    memcpy(out, buf->buffer.contents, (size_t)n * sizeof(float));
}
272
+
273
// Copies up to `count` uint32 elements out of the buffer into `out`,
// clamped to the buffer's own length.
void sparse_to_uint(SparseBufferRef buf, uint32_t* out, uint32_t count) {
    if (!buf || !out) return;
    uint32_t n = buf->count;
    if (count < n) n = count;
    memcpy(out, buf->buffer.contents, (size_t)n * sizeof(uint32_t));
}
278
+
279
// Number of elements in the buffer (0 for NULL).
uint32_t sparse_buffer_count(SparseBufferRef buf) {
    if (!buf) return 0;
    return buf->count;
}
282
+
283
// Element type of the buffer; SPARSE_FLOAT32 is the fallback for NULL.
SparseDataType sparse_buffer_dtype(SparseBufferRef buf) {
    if (!buf) return SPARSE_FLOAT32;
    return buf->dtype;
}
286
+
287
// Releases a buffer created by this library.
//
// free() does not release __strong fields inside calloc'd memory, so nil
// the MTLBuffer reference first; otherwise the GPU allocation leaks.
// (Assumes ARC compilation — confirm against the build flags.)
void sparse_buffer_free(SparseBufferRef buf) {
    if (!buf) return;
    buf->buffer = nil;  // ARC releases the MTLBuffer here
    free(buf);
}
292
+
293
+ // ============================================================================
294
+ // CORE OPERATIONS
295
+ // ============================================================================
296
+
297
// Dispatches the scatter_add kernel: accumulates values into `target` at
// the given indices (exact semantics defined in sparse.metal; presumably
// target[indices[i]] += values[i]).  Blocks until the GPU work completes.
//
// `count` is clamped to the lengths of the index and value buffers so the
// kernel cannot read past either one, even for an oversized argument.
void sparse_scatter_add(
    SparseContextRef ctx,
    SparseBufferRef target,
    SparseBufferRef indices,
    SparseBufferRef values,
    uint32_t count
) {
    if (!ctx || !target || !indices || !values || count == 0) return;
    if (count > indices->count) count = indices->count;
    if (count > values->count) count = values->count;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:ctx->scatterAddPipeline];
        [encoder setBuffer:target->buffer offset:0 atIndex:0];
        [encoder setBuffer:indices->buffer offset:0 atIndex:1];
        [encoder setBuffer:values->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }
}
326
+
327
// Dispatches the gather kernel into a new FLOAT32 buffer (presumably
// result[i] = source[indices[i]]; semantics defined in sparse.metal).
// Returns NULL on failure.  `count` is clamped to the number of available
// indices to avoid out-of-bounds reads.
SparseBufferRef sparse_gather(
    SparseContextRef ctx,
    SparseBufferRef source,
    SparseBufferRef indices,
    uint32_t count
) {
    if (!ctx || !source || !indices || count == 0) return NULL;
    if (count > indices->count) count = indices->count;

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;  // avoid dereferencing a failed allocation

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:ctx->gatherPipeline];
        [encoder setBuffer:source->buffer offset:0 atIndex:0];
        [encoder setBuffer:indices->buffer offset:0 atIndex:1];
        [encoder setBuffer:result->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
359
+
360
// Dispatches the gather_bool kernel into a new UINT32 buffer (presumably
// result[i] = source[indices[i]]; semantics defined in sparse.metal).
// Returns NULL on failure.  `count` is clamped to the number of available
// indices to avoid out-of-bounds reads.
SparseBufferRef sparse_gather_bool(
    SparseContextRef ctx,
    SparseBufferRef source,
    SparseBufferRef indices,
    uint32_t count
) {
    if (!ctx || !source || !indices || count == 0) return NULL;
    if (count > indices->count) count = indices->count;

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_UINT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:ctx->gatherBoolPipeline];
        [encoder setBuffer:source->buffer offset:0 atIndex:0];
        [encoder setBuffer:indices->buffer offset:0 atIndex:1];
        [encoder setBuffer:result->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
392
+
393
+ // ============================================================================
394
+ // ELEMENT-WISE OPERATIONS
395
+ // ============================================================================
396
+
397
// Element-wise addition via the add_arrays kernel.  Returns a new FLOAT32
// buffer of min(a->count, b->count) elements, or NULL on failure.
// Clamping to the shorter input prevents the kernel from reading past the
// end of `b` when the two counts disagree (previously undefined behavior).
SparseBufferRef sparse_add(SparseContextRef ctx, SparseBufferRef a, SparseBufferRef b) {
    if (!ctx || !a || !b) return NULL;
    uint32_t count = (a->count < b->count) ? a->count : b->count;
    if (count == 0) return NULL;

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:ctx->addArraysPipeline];
        [encoder setBuffer:a->buffer offset:0 atIndex:0];
        [encoder setBuffer:b->buffer offset:0 atIndex:1];
        [encoder setBuffer:result->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
425
+
426
// Adds `scalar` to every element via the add_scalar kernel.  Returns a
// new FLOAT32 buffer, or NULL on failure (including an empty input, for
// which sparse_zeros returns NULL — previously this crashed).
SparseBufferRef sparse_add_scalar(SparseContextRef ctx, SparseBufferRef a, float scalar) {
    if (!ctx || !a) return NULL;
    uint32_t count = a->count;

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:ctx->addScalarPipeline];
        [encoder setBuffer:a->buffer offset:0 atIndex:0];
        [encoder setBytes:&scalar length:sizeof(float) atIndex:1];
        [encoder setBuffer:result->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
454
+
455
// Element-wise multiplication via the multiply_arrays kernel.  Returns a
// new FLOAT32 buffer of min(a->count, b->count) elements, or NULL on
// failure.  Clamping prevents out-of-bounds reads on mismatched inputs.
SparseBufferRef sparse_multiply(SparseContextRef ctx, SparseBufferRef a, SparseBufferRef b) {
    if (!ctx || !a || !b) return NULL;
    uint32_t count = (a->count < b->count) ? a->count : b->count;
    if (count == 0) return NULL;

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:ctx->multiplyArraysPipeline];
        [encoder setBuffer:a->buffer offset:0 atIndex:0];
        [encoder setBuffer:b->buffer offset:0 atIndex:1];
        [encoder setBuffer:result->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
483
+
484
// Multiplies every element by `scalar` via the multiply_scalar kernel.
// Returns a new FLOAT32 buffer, or NULL on failure.
SparseBufferRef sparse_multiply_scalar(SparseContextRef ctx, SparseBufferRef a, float scalar) {
    if (!ctx || !a) return NULL;
    uint32_t count = a->count;

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:ctx->multiplyScalarPipeline];
        [encoder setBuffer:a->buffer offset:0 atIndex:0];
        [encoder setBytes:&scalar length:sizeof(float) atIndex:1];
        [encoder setBuffer:result->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
512
+
513
// Squares every element via the square kernel.  Returns a new FLOAT32
// buffer, or NULL on failure.
SparseBufferRef sparse_square(SparseContextRef ctx, SparseBufferRef a) {
    if (!ctx || !a) return NULL;
    uint32_t count = a->count;

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:ctx->squarePipeline];
        [encoder setBuffer:a->buffer offset:0 atIndex:0];
        [encoder setBuffer:result->buffer offset:0 atIndex:1];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:2];

        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
540
+
541
+ // ============================================================================
542
+ // CONDITIONAL OPERATIONS
543
+ // ============================================================================
544
+
545
// Element-wise comparison via the greater_equal kernel.  Returns a new
// UINT32 mask buffer (presumably 1 where a[i] >= threshold, else 0 —
// defined by sparse.metal), or NULL on failure.
SparseBufferRef sparse_greater_equal(SparseContextRef ctx, SparseBufferRef a, float threshold) {
    if (!ctx || !a) return NULL;
    uint32_t count = a->count;

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_UINT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:ctx->greaterEqualPipeline];
        [encoder setBuffer:a->buffer offset:0 atIndex:0];
        [encoder setBytes:&threshold length:sizeof(float) atIndex:1];
        [encoder setBuffer:result->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
573
+
574
// Conditional select via the where_select kernel (presumably
// result[i] = condition[i] ? if_true[i] : if_false[i]).  Returns a new
// FLOAT32 buffer of min(condition, if_true, if_false) elements, or NULL
// on failure.  Clamping prevents out-of-bounds reads when the three
// inputs have different lengths.
SparseBufferRef sparse_where(
    SparseContextRef ctx,
    SparseBufferRef condition,
    SparseBufferRef if_true,
    SparseBufferRef if_false
) {
    if (!ctx || !condition || !if_true || !if_false) return NULL;
    uint32_t count = condition->count;
    if (count > if_true->count) count = if_true->count;
    if (count > if_false->count) count = if_false->count;
    if (count == 0) return NULL;

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:ctx->whereSelectPipeline];
        [encoder setBuffer:condition->buffer offset:0 atIndex:0];
        [encoder setBuffer:if_true->buffer offset:0 atIndex:1];
        [encoder setBuffer:if_false->buffer offset:0 atIndex:2];
        [encoder setBuffer:result->buffer offset:0 atIndex:3];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:4];

        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
608
+
609
// Conditional select with scalar branches via the where_scalar kernel
// (presumably result[i] = condition[i] ? if_true : if_false).  Returns a
// new FLOAT32 buffer, or NULL on failure.
SparseBufferRef sparse_where_scalar(
    SparseContextRef ctx,
    SparseBufferRef condition,
    float if_true,
    float if_false
) {
    if (!ctx || !condition) return NULL;
    uint32_t count = condition->count;

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:ctx->whereScalarPipeline];
        [encoder setBuffer:condition->buffer offset:0 atIndex:0];
        [encoder setBytes:&if_true length:sizeof(float) atIndex:1];
        [encoder setBytes:&if_false length:sizeof(float) atIndex:2];
        [encoder setBuffer:result->buffer offset:0 atIndex:3];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:4];

        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
643
+
644
+ // ============================================================================
645
+ // REDUCTION OPERATIONS
646
+ // ============================================================================
647
+
648
// Sums all elements of a FLOAT32 buffer on the CPU.
//
// Buffers use shared storage and every GPU op here already waits for
// completion, so the contents can be read in place — the previous
// unchecked malloc + full copy was unnecessary.
// TODO: use the sum_reduce kernel for large arrays.
float sparse_sum(SparseContextRef ctx, SparseBufferRef a) {
    if (!ctx || !a) return 0.0f;

    const float* data = (const float*)a->buffer.contents;
    float sum = 0.0f;
    for (uint32_t i = 0; i < a->count; i++) {
        sum += data[i];
    }
    return sum;
}
664
+
665
// Sums all elements of a UINT32 (mask) buffer on the CPU, reading the
// shared-storage contents in place instead of via an unchecked
// malloc + copy.  Commonly used to count nonzero mask entries.
uint32_t sparse_sum_bool(SparseContextRef ctx, SparseBufferRef a) {
    if (!ctx || !a) return 0;

    const uint32_t* data = (const uint32_t*)a->buffer.contents;
    uint32_t sum = 0;
    for (uint32_t i = 0; i < a->count; i++) {
        sum += data[i];
    }
    return sum;
}
679
+
680
+ // ============================================================================
681
+ // RANDOM OPERATIONS
682
+ // ============================================================================
683
+
684
// Fills a new FLOAT32 buffer with uniform random values in [low, high].
// Values are generated on the CPU with arc4random (GPU-side RNG is not
// implemented).  Returns NULL on allocation failure.
SparseBufferRef sparse_random_uniform(
    SparseContextRef ctx,
    uint32_t count,
    float low,
    float high
) {
    if (!ctx || count == 0) return NULL;

    float* data = (float*)malloc((size_t)count * sizeof(float));
    if (!data) return NULL;

    float range = high - low;
    for (uint32_t i = 0; i < count; i++) {
        data[i] = low + ((float)arc4random() / (float)UINT32_MAX) * range;
    }

    SparseBufferRef result = sparse_from_float(ctx, data, count);
    free(data);

    return result;
}
705
+
706
// Fills a new FLOAT32 buffer with normally distributed values
// (Box-Muller transform on the CPU).  Returns NULL on allocation failure.
SparseBufferRef sparse_random_normal(
    SparseContextRef ctx,
    uint32_t count,
    float mean,
    float std
) {
    if (!ctx || count == 0) return NULL;

    float* data = (float*)malloc((size_t)count * sizeof(float));
    if (!data) return NULL;

    // Each Box-Muller step yields two independent normals from two
    // uniforms; the +1 bias keeps u1 strictly positive so logf is finite.
    for (uint32_t i = 0; i < count; i += 2) {
        float u1 = ((float)arc4random() + 1.0f) / ((float)UINT32_MAX + 1.0f);
        float u2 = ((float)arc4random() + 1.0f) / ((float)UINT32_MAX + 1.0f);

        float z0 = sqrtf(-2.0f * logf(u1)) * cosf(2.0f * M_PI * u2);
        float z1 = sqrtf(-2.0f * logf(u1)) * sinf(2.0f * M_PI * u2);

        data[i] = mean + z0 * std;
        if (i + 1 < count) {   // odd count: discard the second normal
            data[i + 1] = mean + z1 * std;
        }
    }

    SparseBufferRef result = sparse_from_float(ctx, data, count);
    free(data);

    return result;
}