@sparkleideas/plugins 3.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. package/README.md +401 -0
  2. package/__tests__/collection-manager.test.ts +332 -0
  3. package/__tests__/dependency-graph.test.ts +434 -0
  4. package/__tests__/enhanced-plugin-registry.test.ts +488 -0
  5. package/__tests__/plugin-registry.test.ts +368 -0
  6. package/__tests__/ruvector-bridge.test.ts +2429 -0
  7. package/__tests__/ruvector-integration.test.ts +1602 -0
  8. package/__tests__/ruvector-migrations.test.ts +1099 -0
  9. package/__tests__/ruvector-quantization.test.ts +846 -0
  10. package/__tests__/ruvector-streaming.test.ts +1088 -0
  11. package/__tests__/sdk.test.ts +325 -0
  12. package/__tests__/security.test.ts +348 -0
  13. package/__tests__/utils/ruvector-test-utils.ts +860 -0
  14. package/examples/plugin-creator/index.ts +636 -0
  15. package/examples/plugin-creator/plugin-creator.test.ts +312 -0
  16. package/examples/ruvector/README.md +288 -0
  17. package/examples/ruvector/attention-patterns.ts +394 -0
  18. package/examples/ruvector/basic-usage.ts +288 -0
  19. package/examples/ruvector/docker-compose.yml +75 -0
  20. package/examples/ruvector/gnn-analysis.ts +501 -0
  21. package/examples/ruvector/hyperbolic-hierarchies.ts +557 -0
  22. package/examples/ruvector/init-db.sql +119 -0
  23. package/examples/ruvector/quantization.ts +680 -0
  24. package/examples/ruvector/self-learning.ts +447 -0
  25. package/examples/ruvector/semantic-search.ts +576 -0
  26. package/examples/ruvector/streaming-large-data.ts +507 -0
  27. package/examples/ruvector/transactions.ts +594 -0
  28. package/examples/ruvector-plugins/hook-pattern-library.ts +486 -0
  29. package/examples/ruvector-plugins/index.ts +79 -0
  30. package/examples/ruvector-plugins/intent-router.ts +354 -0
  31. package/examples/ruvector-plugins/mcp-tool-optimizer.ts +424 -0
  32. package/examples/ruvector-plugins/reasoning-bank.ts +657 -0
  33. package/examples/ruvector-plugins/ruvector-plugins.test.ts +518 -0
  34. package/examples/ruvector-plugins/semantic-code-search.ts +498 -0
  35. package/examples/ruvector-plugins/shared/index.ts +20 -0
  36. package/examples/ruvector-plugins/shared/vector-utils.ts +257 -0
  37. package/examples/ruvector-plugins/sona-learning.ts +445 -0
  38. package/package.json +97 -0
  39. package/src/collections/collection-manager.ts +661 -0
  40. package/src/collections/index.ts +56 -0
  41. package/src/collections/official/index.ts +1040 -0
  42. package/src/core/base-plugin.ts +416 -0
  43. package/src/core/plugin-interface.ts +215 -0
  44. package/src/hooks/index.ts +685 -0
  45. package/src/index.ts +378 -0
  46. package/src/integrations/agentic-flow.ts +743 -0
  47. package/src/integrations/index.ts +88 -0
  48. package/src/integrations/ruvector/ARCHITECTURE.md +1245 -0
  49. package/src/integrations/ruvector/attention-advanced.ts +1040 -0
  50. package/src/integrations/ruvector/attention-executor.ts +782 -0
  51. package/src/integrations/ruvector/attention-mechanisms.ts +757 -0
  52. package/src/integrations/ruvector/attention.ts +1063 -0
  53. package/src/integrations/ruvector/gnn.ts +3050 -0
  54. package/src/integrations/ruvector/hyperbolic.ts +1948 -0
  55. package/src/integrations/ruvector/index.ts +394 -0
  56. package/src/integrations/ruvector/migrations/001_create_extension.sql +135 -0
  57. package/src/integrations/ruvector/migrations/002_create_vector_tables.sql +259 -0
  58. package/src/integrations/ruvector/migrations/003_create_indices.sql +328 -0
  59. package/src/integrations/ruvector/migrations/004_create_functions.sql +598 -0
  60. package/src/integrations/ruvector/migrations/005_create_attention_functions.sql +654 -0
  61. package/src/integrations/ruvector/migrations/006_create_gnn_functions.sql +728 -0
  62. package/src/integrations/ruvector/migrations/007_create_hyperbolic_functions.sql +762 -0
  63. package/src/integrations/ruvector/migrations/index.ts +35 -0
  64. package/src/integrations/ruvector/migrations/migrations.ts +647 -0
  65. package/src/integrations/ruvector/quantization.ts +2036 -0
  66. package/src/integrations/ruvector/ruvector-bridge.ts +2000 -0
  67. package/src/integrations/ruvector/self-learning.ts +2376 -0
  68. package/src/integrations/ruvector/streaming.ts +1737 -0
  69. package/src/integrations/ruvector/types.ts +1945 -0
  70. package/src/providers/index.ts +643 -0
  71. package/src/registry/dependency-graph.ts +568 -0
  72. package/src/registry/enhanced-plugin-registry.ts +994 -0
  73. package/src/registry/plugin-registry.ts +604 -0
  74. package/src/sdk/index.ts +563 -0
  75. package/src/security/index.ts +594 -0
  76. package/src/types/index.ts +446 -0
  77. package/src/workers/index.ts +700 -0
  78. package/tmp.json +0 -0
  79. package/tsconfig.json +25 -0
  80. package/vitest.config.ts +23 -0
@@ -0,0 +1,3050 @@
1
+ /**
2
+ * RuVector PostgreSQL Bridge - Graph Neural Network (GNN) Layers Module
3
+ *
4
+ * Comprehensive GNN support for RuVector PostgreSQL vector database integration.
5
+ * Implements GCN, GAT, GraphSAGE, GIN, MPNN, EdgeConv, and more.
6
+ *
7
+ * @module @sparkleideas/plugins/integrations/ruvector/gnn
8
+ * @version 1.0.0
9
+ */
10
+
11
+ import type {
12
+ GNNLayerType,
13
+ GNNLayer,
14
+ GraphData,
15
+ GNNOutput,
16
+ GNNAggregation,
17
+ GNNStats,
18
+ ActivationFunction,
19
+ } from './types.js';
20
+
21
+ // ============================================================================
22
+ // Constants and Configuration
23
+ // ============================================================================
24
+
25
+ /**
26
+ * Default configuration values for GNN layers.
27
+ */
28
+ export const GNN_DEFAULTS = {
29
+ dropout: 0.0,
30
+ addSelfLoops: true,
31
+ normalize: true,
32
+ useBias: true,
33
+ activation: 'relu' as ActivationFunction,
34
+ aggregation: 'mean' as GNNAggregation,
35
+ numHeads: 1,
36
+ negativeSlope: 0.2, // For LeakyReLU in GAT
37
+ eps: 0.0, // For GIN
38
+ sampleSize: 10, // For GraphSAGE
39
+ k: 20, // For EdgeConv k-NN
40
+ } as const;
41
+
42
+ /**
43
+ * SQL function mapping for GNN operations.
44
+ */
45
+ export const GNN_SQL_FUNCTIONS = {
46
+ gcn: 'ruvector.gcn_layer',
47
+ gat: 'ruvector.gat_layer',
48
+ gat_v2: 'ruvector.gat_v2_layer',
49
+ sage: 'ruvector.sage_layer',
50
+ gin: 'ruvector.gin_layer',
51
+ mpnn: 'ruvector.mpnn_layer',
52
+ edge_conv: 'ruvector.edge_conv_layer',
53
+ point_conv: 'ruvector.point_conv_layer',
54
+ transformer: 'ruvector.graph_transformer_layer',
55
+ pna: 'ruvector.pna_layer',
56
+ film: 'ruvector.film_layer',
57
+ rgcn: 'ruvector.rgcn_layer',
58
+ hgt: 'ruvector.hgt_layer',
59
+ han: 'ruvector.han_layer',
60
+ metapath: 'ruvector.metapath_layer',
61
+ } as const;
62
+
63
+ // ============================================================================
64
+ // Core Interfaces
65
+ // ============================================================================
66
+
67
+ /**
68
+ * Node identifier type.
69
+ */
70
+ export type NodeId = string | number;
71
+
72
+ /**
73
+ * Node features representation.
74
+ */
75
+ export interface NodeFeatures {
76
+ /** Node IDs */
77
+ readonly ids: NodeId[];
78
+ /** Feature vectors [num_nodes, feature_dim] */
79
+ readonly features: number[][];
80
+ /** Optional node types for heterogeneous graphs */
81
+ readonly types?: string[];
82
+ /** Optional node labels */
83
+ readonly labels?: number[];
84
+ }
85
+
86
+ /**
87
+ * Edge features representation.
88
+ */
89
+ export interface EdgeFeatures {
90
+ /** Source node IDs */
91
+ readonly sources: NodeId[];
92
+ /** Target node IDs */
93
+ readonly targets: NodeId[];
94
+ /** Edge feature vectors [num_edges, edge_dim] (optional) */
95
+ readonly features?: number[][];
96
+ /** Edge weights (optional) */
97
+ readonly weights?: number[];
98
+ /** Edge types for heterogeneous graphs (optional) */
99
+ readonly types?: string[];
100
+ }
101
+
102
+ /**
103
+ * Message representation for message passing.
104
+ */
105
+ export interface Message {
106
+ /** Source node ID */
107
+ readonly source: NodeId;
108
+ /** Target node ID */
109
+ readonly target: NodeId;
110
+ /** Message vector */
111
+ readonly vector: number[];
112
+ /** Edge features (if applicable) */
113
+ readonly edgeFeatures?: number[];
114
+ /** Message weight */
115
+ readonly weight?: number;
116
+ }
117
+
118
+ /**
119
+ * Aggregation method type with extended options.
120
+ */
121
+ export type AggregationMethod =
122
+ | GNNAggregation
123
+ | 'concat'
124
+ | 'weighted_mean'
125
+ | 'multi_head';
126
+
127
+ /**
128
+ * Path representation for graph traversal.
129
+ */
130
+ export interface Path {
131
+ /** Ordered list of node IDs */
132
+ readonly nodes: NodeId[];
133
+ /** Total path weight/distance */
134
+ readonly weight: number;
135
+ /** Edge types along the path (for heterogeneous graphs) */
136
+ readonly edgeTypes?: string[];
137
+ }
138
+
139
+ /**
140
+ * Community detection result.
141
+ */
142
+ export interface Community {
143
+ /** Community identifier */
144
+ readonly id: number;
145
+ /** Member node IDs */
146
+ readonly members: NodeId[];
147
+ /** Community centroid (average features) */
148
+ readonly centroid?: number[];
149
+ /** Modularity score */
150
+ readonly modularity?: number;
151
+ /** Internal edge density */
152
+ readonly density?: number;
153
+ }
154
+
155
+ /**
156
+ * PageRank computation options.
157
+ */
158
+ export interface PageRankOptions {
159
+ /** Damping factor (default: 0.85) */
160
+ readonly damping?: number;
161
+ /** Maximum iterations (default: 100) */
162
+ readonly maxIterations?: number;
163
+ /** Convergence tolerance (default: 1e-6) */
164
+ readonly tolerance?: number;
165
+ /** Personalization vector (teleport probabilities) */
166
+ readonly personalization?: Map<NodeId, number>;
167
+ /** Whether to use weighted edges */
168
+ readonly weighted?: boolean;
169
+ }
170
+
171
+ /**
172
+ * Community detection options.
173
+ */
174
+ export interface CommunityOptions {
175
+ /** Detection algorithm */
176
+ readonly algorithm: 'louvain' | 'label_propagation' | 'girvan_newman' | 'spectral';
177
+ /** Resolution parameter (for Louvain) */
178
+ readonly resolution?: number;
179
+ /** Maximum iterations */
180
+ readonly maxIterations?: number;
181
+ /** Minimum community size */
182
+ readonly minSize?: number;
183
+ /** Random seed for reproducibility */
184
+ readonly seed?: number;
185
+ }
186
+
187
+ /**
188
+ * GNN layer configuration with validation.
189
+ */
190
+ export interface GNNLayerConfig extends GNNLayer {
191
+ /** Layer name/identifier */
192
+ readonly name?: string;
193
+ /** Whether to cache intermediate results */
194
+ readonly cache?: boolean;
195
+ /** Quantization bits for memory efficiency */
196
+ readonly quantizeBits?: 8 | 16 | 32;
197
+ }
198
+
199
+ /**
200
+ * Factory function type for creating GNN layers.
201
+ */
202
+ export type GNNLayerFactory = (config: GNNLayerConfig) => IGNNLayer;
203
+
204
+ /**
205
+ * Interface for GNN layer implementations.
206
+ */
207
+ export interface IGNNLayer {
208
+ /** Layer type */
209
+ readonly type: GNNLayerType;
210
+ /** Layer configuration */
211
+ readonly config: GNNLayerConfig;
212
+
213
+ /**
214
+ * Forward pass through the GNN layer.
215
+ * @param graph - Input graph data
216
+ * @returns Promise resolving to GNN output
217
+ */
218
+ forward(graph: GraphData): Promise<GNNOutput>;
219
+
220
+ /**
221
+ * Message passing step.
222
+ * @param nodes - Node features
223
+ * @param edges - Edge features
224
+ * @returns Promise resolving to updated node features
225
+ */
226
+ messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures>;
227
+
228
+ /**
229
+ * Aggregate messages using the specified method.
230
+ * @param messages - Array of messages to aggregate
231
+ * @param method - Aggregation method
232
+ * @returns Promise resolving to aggregated vector
233
+ */
234
+ aggregate(messages: Message[], method: AggregationMethod): Promise<number[]>;
235
+
236
+ /**
237
+ * Reset layer state (if stateful).
238
+ */
239
+ reset(): void;
240
+
241
+ /**
242
+ * Generate SQL for this layer.
243
+ * @param tableName - Target table name
244
+ * @param options - SQL generation options
245
+ * @returns SQL string
246
+ */
247
+ toSQL(tableName: string, options?: SQLGenerationOptions): string;
248
+ }
249
+
250
+ /**
251
+ * SQL generation options.
252
+ */
253
+ export interface SQLGenerationOptions {
254
+ /** Schema name */
255
+ readonly schema?: string;
256
+ /** Node features column */
257
+ readonly nodeColumn?: string;
258
+ /** Edge table name */
259
+ readonly edgeTable?: string;
260
+ /** Whether to use prepared statements */
261
+ readonly prepared?: boolean;
262
+ /** Parameter prefix for prepared statements */
263
+ readonly paramPrefix?: string;
264
+ }
265
+
266
+ // ============================================================================
267
+ // GNN Layer Registry
268
+ // ============================================================================
269
+
270
+ /**
271
+ * Registry for managing GNN layer types and factories.
272
+ *
273
+ * @example
274
+ * ```typescript
275
+ * const registry = new GNNLayerRegistry();
276
+ * registry.registerLayer('custom_gnn', CustomGNNFactory);
277
+ * const layer = registry.createLayer('gcn', { inputDim: 64, outputDim: 32 });
278
+ * ```
279
+ */
280
+ export class GNNLayerRegistry {
281
+ private readonly factories: Map<GNNLayerType | string, GNNLayerFactory> = new Map();
282
+ private readonly defaultConfigs: Map<GNNLayerType | string, Partial<GNNLayerConfig>> = new Map();
283
+
284
+ constructor() {
285
+ // Register built-in layer factories
286
+ this.registerBuiltinLayers();
287
+ }
288
+
289
+ /**
290
+ * Register a GNN layer factory.
291
+ * @param type - Layer type identifier
292
+ * @param factory - Factory function
293
+ * @param defaultConfig - Optional default configuration
294
+ */
295
+ registerLayer(
296
+ type: GNNLayerType | string,
297
+ factory: GNNLayerFactory,
298
+ defaultConfig?: Partial<GNNLayerConfig>
299
+ ): void {
300
+ this.factories.set(type, factory);
301
+ if (defaultConfig) {
302
+ this.defaultConfigs.set(type, defaultConfig);
303
+ }
304
+ }
305
+
306
+ /**
307
+ * Unregister a GNN layer factory.
308
+ * @param type - Layer type to remove
309
+ * @returns Whether the layer was removed
310
+ */
311
+ unregisterLayer(type: GNNLayerType | string): boolean {
312
+ this.defaultConfigs.delete(type);
313
+ return this.factories.delete(type);
314
+ }
315
+
316
+ /**
317
+ * Create a GNN layer instance.
318
+ * @param type - Layer type
319
+ * @param config - Layer configuration
320
+ * @returns IGNNLayer instance
321
+ * @throws Error if layer type is not registered
322
+ */
323
+ createLayer(type: GNNLayerType, config: Partial<GNNLayerConfig>): IGNNLayer {
324
+ const factory = this.factories.get(type);
325
+ if (!factory) {
326
+ throw new Error(`Unknown GNN layer type: ${type}. Available types: ${this.getLayerTypes().join(', ')}`);
327
+ }
328
+
329
+ const defaultConfig = this.defaultConfigs.get(type) ?? {};
330
+ const fullConfig: GNNLayerConfig = {
331
+ type,
332
+ inputDim: config.inputDim ?? 64,
333
+ outputDim: config.outputDim ?? 64,
334
+ dropout: config.dropout ?? defaultConfig.dropout ?? GNN_DEFAULTS.dropout,
335
+ aggregation: config.aggregation ?? defaultConfig.aggregation ?? GNN_DEFAULTS.aggregation,
336
+ addSelfLoops: config.addSelfLoops ?? defaultConfig.addSelfLoops ?? GNN_DEFAULTS.addSelfLoops,
337
+ normalize: config.normalize ?? defaultConfig.normalize ?? GNN_DEFAULTS.normalize,
338
+ useBias: config.useBias ?? defaultConfig.useBias ?? GNN_DEFAULTS.useBias,
339
+ activation: config.activation ?? defaultConfig.activation ?? GNN_DEFAULTS.activation,
340
+ ...config,
341
+ };
342
+
343
+ return factory(fullConfig);
344
+ }
345
+
346
+ /**
347
+ * Check if a layer type is registered.
348
+ * @param type - Layer type to check
349
+ * @returns Whether the layer is registered
350
+ */
351
+ hasLayer(type: GNNLayerType | string): boolean {
352
+ return this.factories.has(type);
353
+ }
354
+
355
+ /**
356
+ * Get all registered layer types.
357
+ * @returns Array of layer type identifiers
358
+ */
359
+ getLayerTypes(): string[] {
360
+ return Array.from(this.factories.keys());
361
+ }
362
+
363
+ /**
364
+ * Get default configuration for a layer type.
365
+ * @param type - Layer type
366
+ * @returns Default configuration or undefined
367
+ */
368
+ getDefaultConfig(type: GNNLayerType | string): Partial<GNNLayerConfig> | undefined {
369
+ return this.defaultConfigs.get(type);
370
+ }
371
+
372
+ /**
373
+ * Register all built-in GNN layer factories.
374
+ */
375
+ private registerBuiltinLayers(): void {
376
+ // GCN - Graph Convolutional Network
377
+ this.registerLayer('gcn', (config) => new GCNLayer(config), {
378
+ normalize: true,
379
+ addSelfLoops: true,
380
+ });
381
+
382
+ // GAT - Graph Attention Network
383
+ this.registerLayer('gat', (config) => new GATLayer(config), {
384
+ numHeads: 8,
385
+ params: { negativeSlope: 0.2, concat: true },
386
+ });
387
+
388
+ // GAT v2 - Improved Graph Attention
389
+ this.registerLayer('gat_v2', (config) => new GATv2Layer(config), {
390
+ numHeads: 8,
391
+ params: { negativeSlope: 0.2, concat: true },
392
+ });
393
+
394
+ // GraphSAGE - Sampling and Aggregation
395
+ this.registerLayer('sage', (config) => new GraphSAGELayer(config), {
396
+ aggregation: 'mean',
397
+ params: { sampleSize: 10, samplingStrategy: 'uniform' },
398
+ });
399
+
400
+ // GIN - Graph Isomorphism Network
401
+ this.registerLayer('gin', (config) => new GINLayer(config), {
402
+ params: { eps: 0, trainEps: false },
403
+ });
404
+
405
+ // MPNN - Message Passing Neural Network
406
+ this.registerLayer('mpnn', (config) => new MPNNLayer(config), {
407
+ aggregation: 'sum',
408
+ });
409
+
410
+ // EdgeConv - Dynamic Edge Convolution
411
+ this.registerLayer('edge_conv', (config) => new EdgeConvLayer(config), {
412
+ params: { k: 20, dynamic: true },
413
+ });
414
+
415
+ // Point Convolution
416
+ this.registerLayer('point_conv', (config) => new PointConvLayer(config), {
417
+ params: { k: 16 },
418
+ });
419
+
420
+ // Graph Transformer
421
+ this.registerLayer('transformer', (config) => new GraphTransformerLayer(config), {
422
+ numHeads: 8,
423
+ params: { numLayers: 1 },
424
+ });
425
+
426
+ // PNA - Principal Neighbourhood Aggregation
427
+ this.registerLayer('pna', (config) => new PNALayer(config), {
428
+ params: {
429
+ aggregators: ['mean', 'sum', 'max', 'min'],
430
+ scalers: ['identity', 'amplification', 'attenuation'],
431
+ },
432
+ });
433
+
434
+ // FiLM - Feature-wise Linear Modulation
435
+ this.registerLayer('film', (config) => new FiLMLayer(config), {});
436
+
437
+ // RGCN - Relational Graph Convolutional Network
438
+ this.registerLayer('rgcn', (config) => new RGCNLayer(config), {
439
+ params: { numRelations: 1 },
440
+ });
441
+
442
+ // HGT - Heterogeneous Graph Transformer
443
+ this.registerLayer('hgt', (config) => new HGTLayer(config), {
444
+ numHeads: 8,
445
+ });
446
+
447
+ // HAN - Heterogeneous Attention Network
448
+ this.registerLayer('han', (config) => new HANLayer(config), {
449
+ numHeads: 8,
450
+ });
451
+
452
+ // MetaPath aggregation
453
+ this.registerLayer('metapath', (config) => new MetaPathLayer(config), {
454
+ params: { metapaths: [] },
455
+ });
456
+ }
457
+ }
458
+
459
+ // ============================================================================
460
+ // Base GNN Layer Implementation
461
+ // ============================================================================
462
+
463
+ /**
464
+ * Abstract base class for GNN layer implementations.
465
+ */
466
+ export abstract class BaseGNNLayer implements IGNNLayer {
467
+ readonly type: GNNLayerType;
468
+ readonly config: GNNLayerConfig;
469
+
470
+ constructor(config: GNNLayerConfig) {
471
+ this.type = config.type;
472
+ this.config = config;
473
+ this.validateConfig();
474
+ }
475
+
476
+ /**
477
+ * Validate layer configuration.
478
+ * @throws Error if configuration is invalid
479
+ */
480
+ protected validateConfig(): void {
481
+ if (this.config.inputDim <= 0) {
482
+ throw new Error(`Invalid inputDim: ${this.config.inputDim}. Must be positive.`);
483
+ }
484
+ if (this.config.outputDim <= 0) {
485
+ throw new Error(`Invalid outputDim: ${this.config.outputDim}. Must be positive.`);
486
+ }
487
+ if (this.config.dropout !== undefined && (this.config.dropout < 0 || this.config.dropout > 1)) {
488
+ throw new Error(`Invalid dropout: ${this.config.dropout}. Must be between 0 and 1.`);
489
+ }
490
+ if (this.config.numHeads !== undefined && this.config.numHeads <= 0) {
491
+ throw new Error(`Invalid numHeads: ${this.config.numHeads}. Must be positive.`);
492
+ }
493
+ }
494
+
495
+ abstract forward(graph: GraphData): Promise<GNNOutput>;
496
+ abstract messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures>;
497
+
498
+ /**
499
+ * Aggregate messages using the specified method.
500
+ */
501
+ async aggregate(messages: Message[], method: AggregationMethod): Promise<number[]> {
502
+ if (messages.length === 0) {
503
+ return new Array(this.config.outputDim).fill(0);
504
+ }
505
+
506
+ const vectors = messages.map((m) => m.vector);
507
+ const weights = messages.map((m) => m.weight ?? 1);
508
+
509
+ switch (method) {
510
+ case 'sum':
511
+ return this.aggregateSum(vectors);
512
+ case 'mean':
513
+ return this.aggregateMean(vectors);
514
+ case 'max':
515
+ return this.aggregateMax(vectors);
516
+ case 'min':
517
+ return this.aggregateMin(vectors);
518
+ case 'attention':
519
+ return this.aggregateAttention(vectors, weights);
520
+ case 'weighted_mean':
521
+ return this.aggregateWeightedMean(vectors, weights);
522
+ case 'softmax':
523
+ return this.aggregateSoftmax(vectors);
524
+ case 'power_mean':
525
+ return this.aggregatePowerMean(vectors, 2);
526
+ case 'std':
527
+ return this.aggregateStd(vectors);
528
+ case 'var':
529
+ return this.aggregateVar(vectors);
530
+ case 'concat':
531
+ return this.aggregateConcat(vectors);
532
+ case 'lstm':
533
+ return this.aggregateLSTM(vectors);
534
+ case 'multi_head':
535
+ return this.aggregateMultiHead(vectors);
536
+ default:
537
+ return this.aggregateMean(vectors);
538
+ }
539
+ }
540
+
541
+ /**
542
+ * Reset layer state.
543
+ */
544
+ reset(): void {
545
+ // Override in stateful layers
546
+ }
547
+
548
+ /**
549
+ * Generate SQL for this layer.
550
+ */
551
+ toSQL(tableName: string, options: SQLGenerationOptions = {}): string {
552
+ const schema = options.schema ?? 'public';
553
+ const nodeColumn = options.nodeColumn ?? 'embedding';
554
+ const edgeTable = options.edgeTable ?? `${tableName}_edges`;
555
+ const sqlFunction = GNN_SQL_FUNCTIONS[this.type] ?? 'ruvector.gnn_layer';
556
+
557
+ const configJson = JSON.stringify({
558
+ type: this.type,
559
+ input_dim: this.config.inputDim,
560
+ output_dim: this.config.outputDim,
561
+ num_heads: this.config.numHeads,
562
+ dropout: this.config.dropout,
563
+ aggregation: this.config.aggregation,
564
+ add_self_loops: this.config.addSelfLoops,
565
+ normalize: this.config.normalize,
566
+ use_bias: this.config.useBias,
567
+ activation: this.config.activation,
568
+ params: this.config.params,
569
+ });
570
+
571
+ if (options.prepared) {
572
+ const prefix = options.paramPrefix ?? '$';
573
+ return `
574
+ SELECT ${sqlFunction}(
575
+ (SELECT array_agg(${nodeColumn}) FROM "${schema}"."${tableName}"),
576
+ (SELECT array_agg(ARRAY[source_id, target_id]) FROM "${schema}"."${edgeTable}"),
577
+ ${prefix}1::jsonb
578
+ );`.trim();
579
+ }
580
+
581
+ return `
582
+ SELECT ${sqlFunction}(
583
+ (SELECT array_agg(${nodeColumn}) FROM "${schema}"."${tableName}"),
584
+ (SELECT array_agg(ARRAY[source_id, target_id]) FROM "${schema}"."${edgeTable}"),
585
+ '${configJson}'::jsonb
586
+ );`.trim();
587
+ }
588
+
589
+ // Aggregation implementations
590
+ protected aggregateSum(vectors: number[][]): number[] {
591
+ const dim = vectors[0]?.length ?? 0;
592
+ const result = new Array(dim).fill(0);
593
+ for (const vec of vectors) {
594
+ for (let i = 0; i < dim; i++) {
595
+ result[i] += vec[i] ?? 0;
596
+ }
597
+ }
598
+ return result;
599
+ }
600
+
601
+ protected aggregateMean(vectors: number[][]): number[] {
602
+ const sum = this.aggregateSum(vectors);
603
+ return sum.map((v) => v / vectors.length);
604
+ }
605
+
606
+ protected aggregateMax(vectors: number[][]): number[] {
607
+ const dim = vectors[0]?.length ?? 0;
608
+ const result = new Array(dim).fill(-Infinity);
609
+ for (const vec of vectors) {
610
+ for (let i = 0; i < dim; i++) {
611
+ result[i] = Math.max(result[i], vec[i] ?? -Infinity);
612
+ }
613
+ }
614
+ return result;
615
+ }
616
+
617
+ protected aggregateMin(vectors: number[][]): number[] {
618
+ const dim = vectors[0]?.length ?? 0;
619
+ const result = new Array(dim).fill(Infinity);
620
+ for (const vec of vectors) {
621
+ for (let i = 0; i < dim; i++) {
622
+ result[i] = Math.min(result[i], vec[i] ?? Infinity);
623
+ }
624
+ }
625
+ return result;
626
+ }
627
+
628
+ protected aggregateWeightedMean(vectors: number[][], weights: number[]): number[] {
629
+ const dim = vectors[0]?.length ?? 0;
630
+ const result = new Array(dim).fill(0);
631
+ let totalWeight = 0;
632
+
633
+ for (let j = 0; j < vectors.length; j++) {
634
+ const w = weights[j] ?? 1;
635
+ totalWeight += w;
636
+ for (let i = 0; i < dim; i++) {
637
+ result[i] += (vectors[j]?.[i] ?? 0) * w;
638
+ }
639
+ }
640
+
641
+ return result.map((v) => (totalWeight > 0 ? v / totalWeight : 0));
642
+ }
643
+
644
+ protected aggregateAttention(vectors: number[][], weights: number[]): number[] {
645
+ // Softmax over weights then weighted mean
646
+ const maxWeight = Math.max(...weights);
647
+ const expWeights = weights.map((w) => Math.exp(w - maxWeight));
648
+ const sumExp = expWeights.reduce((a, b) => a + b, 0);
649
+ const attentionWeights = expWeights.map((w) => w / sumExp);
650
+ return this.aggregateWeightedMean(vectors, attentionWeights);
651
+ }
652
+
653
+ protected aggregateSoftmax(vectors: number[][]): number[] {
654
+ // Softmax aggregation across vectors
655
+ const dim = vectors[0]?.length ?? 0;
656
+ const result = new Array(dim).fill(0);
657
+
658
+ for (let i = 0; i < dim; i++) {
659
+ const values = vectors.map((v) => v[i] ?? 0);
660
+ const maxVal = Math.max(...values);
661
+ const expValues = values.map((v) => Math.exp(v - maxVal));
662
+ const sumExp = expValues.reduce((a, b) => a + b, 0);
663
+ result[i] = expValues.reduce((sum, exp, j) => sum + (exp / sumExp) * values[j], 0);
664
+ }
665
+
666
+ return result;
667
+ }
668
+
669
+ protected aggregatePowerMean(vectors: number[][], p: number): number[] {
670
+ const dim = vectors[0]?.length ?? 0;
671
+ const result = new Array(dim).fill(0);
672
+
673
+ for (let i = 0; i < dim; i++) {
674
+ let sum = 0;
675
+ for (const vec of vectors) {
676
+ sum += Math.pow(Math.abs(vec[i] ?? 0), p);
677
+ }
678
+ result[i] = Math.pow(sum / vectors.length, 1 / p);
679
+ }
680
+
681
+ return result;
682
+ }
683
+
684
+ protected aggregateStd(vectors: number[][]): number[] {
685
+ const mean = this.aggregateMean(vectors);
686
+ const dim = mean.length;
687
+ const variance = new Array(dim).fill(0);
688
+
689
+ for (const vec of vectors) {
690
+ for (let i = 0; i < dim; i++) {
691
+ variance[i] += Math.pow((vec[i] ?? 0) - mean[i], 2);
692
+ }
693
+ }
694
+
695
+ return variance.map((v) => Math.sqrt(v / vectors.length));
696
+ }
697
+
698
+ protected aggregateVar(vectors: number[][]): number[] {
699
+ const mean = this.aggregateMean(vectors);
700
+ const dim = mean.length;
701
+ const variance = new Array(dim).fill(0);
702
+
703
+ for (const vec of vectors) {
704
+ for (let i = 0; i < dim; i++) {
705
+ variance[i] += Math.pow((vec[i] ?? 0) - mean[i], 2);
706
+ }
707
+ }
708
+
709
+ return variance.map((v) => v / vectors.length);
710
+ }
711
+
712
+ protected aggregateConcat(vectors: number[][]): number[] {
713
+ return vectors.flat();
714
+ }
715
+
716
+ protected aggregateLSTM(vectors: number[][]): number[] {
717
+ // Simplified LSTM-style aggregation (sequential processing)
718
+ let hidden = new Array(this.config.outputDim).fill(0);
719
+ for (const vec of vectors) {
720
+ hidden = this.lstmCell(vec, hidden);
721
+ }
722
+ return hidden;
723
+ }
724
+
725
+ protected aggregateMultiHead(vectors: number[][]): number[] {
726
+ // Split into heads, aggregate each, then combine
727
+ const numHeads = this.config.numHeads ?? 1;
728
+ const headDim = Math.floor((vectors[0]?.length ?? 0) / numHeads);
729
+ const results: number[][] = [];
730
+
731
+ for (let h = 0; h < numHeads; h++) {
732
+ const headVectors = vectors.map((v) =>
733
+ v.slice(h * headDim, (h + 1) * headDim)
734
+ );
735
+ results.push(this.aggregateMean(headVectors));
736
+ }
737
+
738
+ return results.flat();
739
+ }
740
+
741
+ private lstmCell(input: number[], hidden: number[]): number[] {
742
+ // Simplified LSTM update (no learned parameters)
743
+ const dim = hidden.length;
744
+ const inputDim = input.length;
745
+ const result = new Array(dim).fill(0);
746
+
747
+ for (let i = 0; i < dim; i++) {
748
+ const inputVal = input[i % inputDim] ?? 0;
749
+ const hiddenVal = hidden[i] ?? 0;
750
+ // Simple gated update
751
+ const gate = 1 / (1 + Math.exp(-(inputVal + hiddenVal)));
752
+ result[i] = gate * inputVal + (1 - gate) * hiddenVal;
753
+ }
754
+
755
+ return result;
756
+ }
757
+
758
+ /**
759
+ * Apply activation function.
760
+ */
761
+ protected applyActivation(x: number): number {
762
+ switch (this.config.activation) {
763
+ case 'relu':
764
+ return Math.max(0, x);
765
+ case 'gelu':
766
+ return 0.5 * x * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (x + 0.044715 * Math.pow(x, 3))));
767
+ case 'silu':
768
+ case 'swish':
769
+ return x / (1 + Math.exp(-x));
770
+ case 'leaky_relu':
771
+ return x >= 0 ? x : 0.01 * x;
772
+ case 'elu':
773
+ return x >= 0 ? x : Math.exp(x) - 1;
774
+ case 'selu':
775
+ const alpha = 1.6732632423543772;
776
+ const scale = 1.0507009873554805;
777
+ return scale * (x >= 0 ? x : alpha * (Math.exp(x) - 1));
778
+ case 'tanh':
779
+ return Math.tanh(x);
780
+ case 'sigmoid':
781
+ return 1 / (1 + Math.exp(-x));
782
+ case 'softmax':
783
+ case 'none':
784
+ default:
785
+ return x;
786
+ }
787
+ }
788
+
789
+ /**
790
+ * Apply dropout (during training).
791
+ */
792
+ protected applyDropout(vector: number[], training: boolean = false): number[] {
793
+ if (!training || !this.config.dropout || this.config.dropout === 0) {
794
+ return vector;
795
+ }
796
+
797
+ const scale = 1 / (1 - this.config.dropout);
798
+ return vector.map((v) => (Math.random() > this.config.dropout! ? v * scale : 0));
799
+ }
800
+
801
+ /**
802
+ * Normalize vector (L2 normalization).
803
+ */
804
+ protected normalizeVector(vector: number[]): number[] {
805
+ const norm = Math.sqrt(vector.reduce((sum, v) => sum + v * v, 0));
806
+ return norm > 0 ? vector.map((v) => v / norm) : vector;
807
+ }
808
+
809
+ /**
810
+ * Create statistics for GNN computation.
811
+ */
812
+ protected createStats(
813
+ startTime: number,
814
+ numNodes: number,
815
+ numEdges: number,
816
+ numIterations: number = 1
817
+ ): GNNStats {
818
+ return {
819
+ forwardTimeMs: Date.now() - startTime,
820
+ numNodes,
821
+ numEdges,
822
+ memoryBytes: numNodes * this.config.outputDim * 4 + numEdges * 8,
823
+ numIterations,
824
+ };
825
+ }
826
+ }
827
+
828
+ // ============================================================================
829
+ // GCN Layer Implementation
830
+ // ============================================================================
831
+
832
+ /**
833
+ * Graph Convolutional Network (GCN) layer.
834
+ *
835
+ * Implements spectral graph convolution with first-order approximation.
836
+ * Reference: Kipf & Welling, "Semi-Supervised Classification with Graph Convolutional Networks" (2017)
837
+ */
838
export class GCNLayer extends BaseGNNLayer {
  /**
   * Run one graph convolution pass over the whole graph.
   *
   * @param graph - Node feature matrix plus COO edge index and optional edge weights.
   * @returns Per-node embeddings, a mean-pooled graph embedding, and timing stats.
   */
  async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures, edgeIndex, edgeWeights } = graph;
    const numNodes = nodeFeatures.length;
    const numEdges = edgeIndex[0].length;

    // Build adjacency with self-loops
    const adj = this.buildAdjacency(numNodes, edgeIndex, edgeWeights);

    // Normalize adjacency (D^-0.5 * A * D^-0.5)
    const normAdj = this.config.normalize ? this.symmetricNormalize(adj, numNodes) : adj;

    // Message passing: H' = sigma(A_norm * H * W)
    const outputFeatures = this.convolve(nodeFeatures, normAdj);

    return {
      nodeEmbeddings: outputFeatures,
      graphEmbedding: this.poolGraph(outputFeatures),
      stats: this.createStats(startTime, numNodes, numEdges),
    };
  }

  /**
   * One message-passing step over id-keyed nodes/edges. Node ids are mapped to
   * dense indices via indexOf before running the same convolution as forward().
   * NOTE(review): an edge endpoint missing from nodes.ids maps to -1 and is then
   * silently dropped by the bounds check in buildAdjacency — confirm intended.
   * indexOf inside map is O(V*E); fine for small graphs.
   */
  async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
    const numNodes = nodes.ids.length;
    const edgeIndex: [number[], number[]] = [
      edges.sources.map((s) => nodes.ids.indexOf(s)),
      edges.targets.map((t) => nodes.ids.indexOf(t)),
    ];

    const adj = this.buildAdjacency(numNodes, edgeIndex, edges.weights);
    const normAdj = this.config.normalize ? this.symmetricNormalize(adj, numNodes) : adj;
    const outputFeatures = this.convolve(nodes.features, normAdj);

    return {
      ids: nodes.ids,
      features: outputFeatures,
      types: nodes.types,
      labels: nodes.labels,
    };
  }

  /**
   * Build a sparse adjacency structure (node -> neighbor -> weight).
   * Edges are treated as undirected: both (src,tgt) and (tgt,src) are stored.
   * Out-of-range endpoints are skipped; missing weights default to 1.
   */
  private buildAdjacency(
    numNodes: number,
    edgeIndex: [number[], number[]],
    weights?: number[]
  ): Map<number, Map<number, number>> {
    const adj = new Map<number, Map<number, number>>();

    // Initialize with self-loops if configured
    for (let i = 0; i < numNodes; i++) {
      adj.set(i, new Map());
      if (this.config.addSelfLoops) {
        adj.get(i)!.set(i, 1);
      }
    }

    // Add edges
    const [sources, targets] = edgeIndex;
    for (let i = 0; i < sources.length; i++) {
      const src = sources[i];
      const tgt = targets[i];
      const weight = weights?.[i] ?? 1;
      if (src >= 0 && src < numNodes && tgt >= 0 && tgt < numNodes) {
        adj.get(src)!.set(tgt, weight);
        // Undirected: add reverse edge
        adj.get(tgt)!.set(src, weight);
      }
    }

    return adj;
  }

  /**
   * Symmetric normalization: each weight A[i][j] becomes
   * A[i][j] / sqrt(deg(i) * deg(j)); the 1e-10 term guards against
   * division by zero for isolated nodes.
   */
  private symmetricNormalize(
    adj: Map<number, Map<number, number>>,
    numNodes: number
  ): Map<number, Map<number, number>> {
    // Compute degree
    const degree = new Array(numNodes).fill(0);
    for (let i = 0; i < numNodes; i++) {
      for (const weight of adj.get(i)!.values()) {
        degree[i] += weight;
      }
    }

    // D^-0.5 * A * D^-0.5
    const normAdj = new Map<number, Map<number, number>>();
    for (let i = 0; i < numNodes; i++) {
      normAdj.set(i, new Map());
      for (const [j, weight] of adj.get(i)!.entries()) {
        const normWeight = weight / Math.sqrt(degree[i] * degree[j] + 1e-10);
        normAdj.get(i)!.set(j, normWeight);
      }
    }

    return normAdj;
  }

  /**
   * Per-node convolution: weighted sum of neighbor features, linear projection
   * to outputDim, activation, then dropout (no-op unless training — see
   * applyDropout's default).
   */
  private convolve(features: number[][], adj: Map<number, Map<number, number>>): number[][] {
    const numNodes = features.length;
    const inputDim = this.config.inputDim;
    const outputDim = this.config.outputDim;
    const output: number[][] = [];

    for (let i = 0; i < numNodes; i++) {
      const aggregated = new Array(inputDim).fill(0);

      // Aggregate neighbor features
      for (const [j, weight] of adj.get(i)!.entries()) {
        const neighborFeatures = features[j] ?? new Array(inputDim).fill(0);
        for (let k = 0; k < inputDim; k++) {
          aggregated[k] += weight * (neighborFeatures[k] ?? 0);
        }
      }

      // Project to output dimension
      const projected = this.projectFeatures(aggregated, inputDim, outputDim);

      // Apply activation
      const activated = projected.map((x) => this.applyActivation(x));

      // Apply dropout
      output.push(this.applyDropout(activated));
    }

    return output;
  }

  /**
   * Deterministic stand-in for a learned linear layer: weights are derived
   * from a sine of the flat weight-matrix position, so results are reproducible.
   */
  private projectFeatures(input: number[], inputDim: number, outputDim: number): number[] {
    // Simple linear projection (in practice, this would use learned weights)
    const output = new Array(outputDim).fill(0);
    for (let i = 0; i < outputDim; i++) {
      for (let j = 0; j < inputDim; j++) {
        // Use a deterministic pseudo-weight based on position
        const weight = Math.sin((i * inputDim + j) * 0.1) * 0.5;
        output[i] += input[j] * weight;
      }
      if (this.config.useBias) {
        output[i] += 0.01; // Small bias term
      }
    }
    return output;
  }

  /** Graph readout: mean pooling over node embeddings ([] for an empty graph). */
  private poolGraph(features: number[][]): number[] {
    if (features.length === 0) return [];
    return this.aggregateMean(features);
  }
}
987
+
988
+ // ============================================================================
989
+ // GAT Layer Implementation
990
+ // ============================================================================
991
+
992
+ /**
993
+ * Graph Attention Network (GAT) layer.
994
+ *
995
+ * Implements attention-based message passing.
996
+ * Reference: Veličković et al., "Graph Attention Networks" (2018)
997
+ */
998
export class GATLayer extends BaseGNNLayer {
  /**
   * Multi-head attention forward pass: for every head, score each neighbor,
   * softmax the scores, and compute an attention-weighted sum of projected
   * neighbor features; heads are then concatenated or averaged.
   *
   * NOTE(review): headDim = floor(outputDim / numHeads), so when numHeads does
   * not divide outputDim the concatenated output is narrower than outputDim.
   * NOTE(review): softmax uses raw Math.exp without max-subtraction; very large
   * attention scores could overflow to Infinity — confirm score range is bounded.
   */
  async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures, edgeIndex } = graph;
    const numNodes = nodeFeatures.length;
    const numEdges = edgeIndex[0].length;
    const numHeads = this.config.numHeads ?? 1;
    const negativeSlope = this.config.params?.negativeSlope ?? 0.2;

    // Compute attention for each head
    const headOutputs: number[][][] = [];

    for (let h = 0; h < numHeads; h++) {
      const headDim = Math.floor(this.config.outputDim / numHeads);
      const headFeatures: number[][] = [];

      for (let i = 0; i < numNodes; i++) {
        const neighbors = this.getNeighbors(i, edgeIndex, numNodes);
        const messages: { feature: number[]; attention: number }[] = [];

        // Compute attention for each neighbor
        for (const j of neighbors) {
          const attention = this.computeAttention(
            nodeFeatures[i],
            nodeFeatures[j],
            h,
            negativeSlope
          );
          messages.push({
            feature: this.projectHead(nodeFeatures[j], h, headDim),
            attention,
          });
        }

        // Softmax attention weights
        const attentionSum = messages.reduce(
          (sum, m) => sum + Math.exp(m.attention),
          0
        );
        const normalizedMessages = messages.map((m) => ({
          feature: m.feature,
          weight: Math.exp(m.attention) / (attentionSum + 1e-10),
        }));

        // Aggregate with attention weights
        const aggregated = new Array(headDim).fill(0);
        for (const m of normalizedMessages) {
          for (let k = 0; k < headDim; k++) {
            aggregated[k] += m.weight * (m.feature[k] ?? 0);
          }
        }

        headFeatures.push(aggregated);
      }

      headOutputs.push(headFeatures);
    }

    // Combine heads (concat or average)
    const concat = this.config.params?.concat ?? true;
    const outputFeatures = this.combineHeads(headOutputs, concat);

    // Apply activation and dropout
    const finalFeatures = outputFeatures.map((f) =>
      this.applyDropout(f.map((x) => this.applyActivation(x)))
    );

    return {
      nodeEmbeddings: finalFeatures,
      graphEmbedding: this.aggregateMean(finalFeatures),
      attentionWeights: this.extractAttentionWeights(headOutputs),
      stats: this.createStats(startTime, numNodes, numEdges),
    };
  }

  /**
   * Adapter: converts id-keyed nodes/edges to a dense GraphData and delegates
   * to forward(), preserving ids/types/labels on the way out.
   */
  async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
    const graph: GraphData = {
      nodeFeatures: nodes.features,
      edgeIndex: [
        edges.sources.map((s) => nodes.ids.indexOf(s)),
        edges.targets.map((t) => nodes.ids.indexOf(t)),
      ],
    };

    const output = await this.forward(graph);

    return {
      ids: nodes.ids,
      features: output.nodeEmbeddings,
      types: nodes.types,
      labels: nodes.labels,
    };
  }

  /**
   * Collect the (undirected) neighbor set of a node, optionally including the
   * node itself as a self-loop.
   */
  private getNeighbors(
    nodeIdx: number,
    edgeIndex: [number[], number[]],
    numNodes: number
  ): number[] {
    const neighbors = new Set<number>();

    // Add self-loop
    if (this.config.addSelfLoops) {
      neighbors.add(nodeIdx);
    }

    // Find neighbors from edges
    const [sources, targets] = edgeIndex;
    for (let i = 0; i < sources.length; i++) {
      if (sources[i] === nodeIdx && targets[i] < numNodes) {
        neighbors.add(targets[i]);
      }
      if (targets[i] === nodeIdx && sources[i] < numNodes) {
        neighbors.add(sources[i]);
      }
    }

    return Array.from(neighbors);
  }

  /**
   * GATv1-style attention score: a deterministic pseudo-weighted sum over the
   * element-wise combined endpoint features, followed by LeakyReLU.
   * Protected so GATv2Layer can override with its dynamic-attention variant.
   */
  protected computeAttention(
    nodeI: number[],
    nodeJ: number[],
    head: number,
    negativeSlope: number
  ): number {
    // Compute attention score using concatenation of features
    let score = 0;
    const dim = nodeI.length;

    for (let k = 0; k < dim; k++) {
      // Simple attention mechanism (in practice, uses learned attention weights)
      const combined = (nodeI[k] ?? 0) + (nodeJ[k] ?? 0);
      score += combined * Math.sin((head * dim + k) * 0.1);
    }

    // LeakyReLU
    return score >= 0 ? score : negativeSlope * score;
  }

  /**
   * Per-head linear projection with deterministic cosine pseudo-weights keyed
   * by head index so each head sees a different projection.
   */
  private projectHead(features: number[], head: number, headDim: number): number[] {
    const output = new Array(headDim).fill(0);
    const inputDim = features.length;

    for (let i = 0; i < headDim; i++) {
      for (let j = 0; j < inputDim; j++) {
        const weight = Math.cos((head * headDim * inputDim + i * inputDim + j) * 0.05);
        output[i] += (features[j] ?? 0) * weight;
      }
    }

    return output;
  }

  /**
   * Merge per-head node features: concatenation (default) or element-wise
   * average across heads.
   */
  private combineHeads(heads: number[][][], concat: boolean): number[][] {
    const numNodes = heads[0]?.length ?? 0;
    const result: number[][] = [];

    for (let i = 0; i < numNodes; i++) {
      if (concat) {
        // Concatenate all head outputs
        result.push(heads.flatMap((h) => h[i] ?? []));
      } else {
        // Average head outputs
        const headDim = heads[0]?.[0]?.length ?? 0;
        const averaged = new Array(headDim).fill(0);
        for (const head of heads) {
          for (let j = 0; j < headDim; j++) {
            averaged[j] += (head[i]?.[j] ?? 0) / heads.length;
          }
        }
        result.push(averaged);
      }
    }

    return result;
  }

  /**
   * Summarize head outputs as one scalar per node per head (mean of the node's
   * head features). NOTE(review): an empty per-node vector yields 0/0 = NaN.
   */
  private extractAttentionWeights(heads: number[][][]): number[][] {
    // Return simplified attention representation
    return heads.map((h) => h.map((node) => node.reduce((a, b) => a + b, 0) / node.length));
  }
}
1181
+
1182
+ // ============================================================================
1183
+ // GAT v2 Layer Implementation
1184
+ // ============================================================================
1185
+
1186
+ /**
1187
+ * Graph Attention Network v2 layer.
1188
+ *
1189
+ * Improved attention mechanism with dynamic attention.
1190
+ * Reference: Brody et al., "How Attentive are Graph Attention Networks?" (2022)
1191
+ */
1192
+ export class GATv2Layer extends GATLayer {
1193
+ protected override computeAttention(
1194
+ nodeI: number[],
1195
+ nodeJ: number[],
1196
+ head: number,
1197
+ negativeSlope: number
1198
+ ): number {
1199
+ // GAT v2: Apply attention AFTER concatenation and transformation
1200
+ const dim = nodeI.length;
1201
+ const combined = new Array(dim).fill(0);
1202
+
1203
+ // First, transform and combine
1204
+ for (let k = 0; k < dim; k++) {
1205
+ combined[k] = (nodeI[k] ?? 0) + (nodeJ[k] ?? 0);
1206
+ }
1207
+
1208
+ // Apply LeakyReLU
1209
+ for (let k = 0; k < dim; k++) {
1210
+ combined[k] = combined[k] >= 0 ? combined[k] : negativeSlope * combined[k];
1211
+ }
1212
+
1213
+ // Then compute attention
1214
+ let score = 0;
1215
+ for (let k = 0; k < dim; k++) {
1216
+ score += combined[k] * Math.sin((head * dim + k) * 0.1);
1217
+ }
1218
+
1219
+ return score;
1220
+ }
1221
+ }
1222
+
1223
+ // ============================================================================
1224
+ // GraphSAGE Layer Implementation
1225
+ // ============================================================================
1226
+
1227
+ /**
1228
+ * GraphSAGE (Sample and Aggregate) layer.
1229
+ *
1230
+ * Implements inductive representation learning with neighbor sampling.
1231
+ * Reference: Hamilton et al., "Inductive Representation Learning on Large Graphs" (2017)
1232
+ */
1233
export class GraphSAGELayer extends BaseGNNLayer {
  /**
   * SAGE forward pass: for each node, randomly sample up to `sampleSize`
   * neighbors, aggregate their features (mean by default), concatenate with the
   * node's own features, project, optionally L2-normalize, activate, dropout.
   *
   * NOTE(review): neighbor sampling uses Math.random, so results are
   * non-deterministic whenever a node has more than sampleSize neighbors.
   */
  async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures, edgeIndex } = graph;
    const numNodes = nodeFeatures.length;
    const numEdges = edgeIndex[0].length;
    const sampleSize = this.config.params?.sampleSize ?? 10;

    const outputFeatures: number[][] = [];

    for (let i = 0; i < numNodes; i++) {
      // Sample neighbors
      const allNeighbors = this.getNeighbors(i, edgeIndex, numNodes);
      const sampledNeighbors = this.sampleNeighbors(allNeighbors, sampleSize);

      // Aggregate neighbor features
      const neighborFeatures = sampledNeighbors.map((j) => nodeFeatures[j] ?? []);
      const aggregated = await this.aggregate(
        neighborFeatures.map((f) => ({ source: i, target: i, vector: f })),
        this.config.aggregation ?? 'mean'
      );

      // Concatenate with self features and project
      const selfFeatures = nodeFeatures[i] ?? [];
      const combined = [...selfFeatures, ...aggregated];
      const projected = this.projectFeatures(combined, combined.length, this.config.outputDim);

      // Normalize, activate, and apply dropout
      const normalized = this.config.normalize ? this.normalizeVector(projected) : projected;
      const activated = normalized.map((x) => this.applyActivation(x));
      outputFeatures.push(this.applyDropout(activated));
    }

    return {
      nodeEmbeddings: outputFeatures,
      graphEmbedding: this.aggregateMean(outputFeatures),
      stats: this.createStats(startTime, numNodes, numEdges),
    };
  }

  /**
   * Adapter: converts id-keyed nodes/edges to a dense GraphData and delegates
   * to forward(), preserving ids/types/labels on the way out.
   */
  async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
    const graph: GraphData = {
      nodeFeatures: nodes.features,
      edgeIndex: [
        edges.sources.map((s) => nodes.ids.indexOf(s)),
        edges.targets.map((t) => nodes.ids.indexOf(t)),
      ],
    };

    const output = await this.forward(graph);

    return {
      ids: nodes.ids,
      features: output.nodeEmbeddings,
      types: nodes.types,
      labels: nodes.labels,
    };
  }

  /** Undirected neighbor set of a node (no self-loop handling here). */
  private getNeighbors(
    nodeIdx: number,
    edgeIndex: [number[], number[]],
    numNodes: number
  ): number[] {
    const neighbors = new Set<number>();
    const [sources, targets] = edgeIndex;

    for (let i = 0; i < sources.length; i++) {
      if (sources[i] === nodeIdx && targets[i] < numNodes) {
        neighbors.add(targets[i]);
      }
      if (targets[i] === nodeIdx && sources[i] < numNodes) {
        neighbors.add(sources[i]);
      }
    }

    return Array.from(neighbors);
  }

  /**
   * Uniform sampling without replacement of up to k neighbors; returns the
   * full list unchanged when it already fits.
   */
  private sampleNeighbors(neighbors: number[], k: number): number[] {
    if (neighbors.length <= k) {
      return neighbors;
    }

    // Random sampling
    const sampled: number[] = [];
    const available = [...neighbors];

    for (let i = 0; i < k && available.length > 0; i++) {
      const idx = Math.floor(Math.random() * available.length);
      sampled.push(available[idx]);
      available.splice(idx, 1);
    }

    return sampled;
  }

  /**
   * Deterministic pseudo-linear projection; weights carry a Glorot-style
   * sqrt(2 / (fan_in + fan_out)) scale factor.
   */
  private projectFeatures(input: number[], inputDim: number, outputDim: number): number[] {
    const output = new Array(outputDim).fill(0);
    for (let i = 0; i < outputDim; i++) {
      for (let j = 0; j < inputDim; j++) {
        const weight = Math.sin((i * inputDim + j) * 0.1) * Math.sqrt(2 / (inputDim + outputDim));
        output[i] += (input[j] ?? 0) * weight;
      }
      if (this.config.useBias) {
        output[i] += 0.01;
      }
    }
    return output;
  }
}
1344
+
1345
+ // ============================================================================
1346
+ // GIN Layer Implementation
1347
+ // ============================================================================
1348
+
1349
+ /**
1350
+ * Graph Isomorphism Network (GIN) layer.
1351
+ *
1352
+ * Maximally powerful GNN for graph classification.
1353
+ * Reference: Xu et al., "How Powerful are Graph Neural Networks?" (2019)
1354
+ */
1355
export class GINLayer extends BaseGNNLayer {
  /**
   * GIN forward pass implementing h_v' = MLP((1 + eps) * h_v + sum_u h_u)
   * with a fixed 2-layer pseudo-weight MLP, followed by dropout.
   * Graph readout uses SUM pooling (injective, per the GIN paper).
   */
  async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures, edgeIndex } = graph;
    const numNodes = nodeFeatures.length;
    const numEdges = edgeIndex[0].length;
    const eps = this.config.params?.eps ?? 0;

    const outputFeatures: number[][] = [];

    for (let i = 0; i < numNodes; i++) {
      const neighbors = this.getNeighbors(i, edgeIndex, numNodes);

      // Sum neighbor features
      const neighborSum = new Array(this.config.inputDim).fill(0);
      for (const j of neighbors) {
        const neighborFeatures = nodeFeatures[j] ?? [];
        for (let k = 0; k < this.config.inputDim; k++) {
          neighborSum[k] += neighborFeatures[k] ?? 0;
        }
      }

      // GIN update: h_v = MLP((1 + eps) * h_v + sum(h_u))
      const selfFeatures = nodeFeatures[i] ?? [];
      const combined = new Array(this.config.inputDim).fill(0);
      for (let k = 0; k < this.config.inputDim; k++) {
        combined[k] = (1 + eps) * (selfFeatures[k] ?? 0) + neighborSum[k];
      }

      // MLP (2-layer)
      const hidden = this.mlpLayer1(combined);
      const output = this.mlpLayer2(hidden);
      outputFeatures.push(this.applyDropout(output));
    }

    return {
      nodeEmbeddings: outputFeatures,
      graphEmbedding: this.aggregateSum(outputFeatures), // Sum pooling for graph classification
      stats: this.createStats(startTime, numNodes, numEdges),
    };
  }

  /**
   * Adapter: converts id-keyed nodes/edges to a dense GraphData and delegates
   * to forward(), preserving ids/types/labels on the way out.
   */
  async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
    const graph: GraphData = {
      nodeFeatures: nodes.features,
      edgeIndex: [
        edges.sources.map((s) => nodes.ids.indexOf(s)),
        edges.targets.map((t) => nodes.ids.indexOf(t)),
      ],
    };

    const output = await this.forward(graph);

    return {
      ids: nodes.ids,
      features: output.nodeEmbeddings,
      types: nodes.types,
      labels: nodes.labels,
    };
  }

  /** Undirected neighbor set of a node (no self-loop handling here). */
  private getNeighbors(
    nodeIdx: number,
    edgeIndex: [number[], number[]],
    numNodes: number
  ): number[] {
    const neighbors = new Set<number>();
    const [sources, targets] = edgeIndex;

    for (let i = 0; i < sources.length; i++) {
      if (sources[i] === nodeIdx && targets[i] < numNodes) {
        neighbors.add(targets[i]);
      }
      if (targets[i] === nodeIdx && sources[i] < numNodes) {
        neighbors.add(sources[i]);
      }
    }

    return Array.from(neighbors);
  }

  /**
   * First MLP layer: sine pseudo-weights, projects to hiddenDim (defaults to
   * inputDim) and applies the configured activation per output element.
   */
  private mlpLayer1(input: number[]): number[] {
    const hiddenDim = this.config.hiddenDim ?? this.config.inputDim;
    const output = new Array(hiddenDim).fill(0);

    for (let i = 0; i < hiddenDim; i++) {
      for (let j = 0; j < input.length; j++) {
        const weight = Math.sin((i * input.length + j) * 0.1) * 0.5;
        output[i] += (input[j] ?? 0) * weight;
      }
      output[i] = this.applyActivation(output[i]);
    }

    return output;
  }

  /**
   * Second MLP layer: cosine pseudo-weights, projects to outputDim with NO
   * activation (the final non-linearity is intentionally omitted here).
   */
  private mlpLayer2(input: number[]): number[] {
    const output = new Array(this.config.outputDim).fill(0);

    for (let i = 0; i < this.config.outputDim; i++) {
      for (let j = 0; j < input.length; j++) {
        const weight = Math.cos((i * input.length + j) * 0.1) * 0.5;
        output[i] += (input[j] ?? 0) * weight;
      }
    }

    return output;
  }
}
1464
+
1465
+ // ============================================================================
1466
+ // MPNN Layer Implementation
1467
+ // ============================================================================
1468
+
1469
+ /**
1470
+ * Message Passing Neural Network (MPNN) layer.
1471
+ *
1472
+ * General framework for GNN with customizable message and update functions.
1473
+ * Reference: Gilmer et al., "Neural Message Passing for Quantum Chemistry" (2017)
1474
+ */
1475
export class MPNNLayer extends BaseGNNLayer {
  /**
   * Generic message-passing forward pass: for `numLayers` rounds, every node
   * collects messages from incident edges (both directions), aggregates them
   * (sum by default), and updates its state with a GRU-like gate.
   *
   * NOTE(review): a self-loop edge (sources[e] === targets[e] === i) satisfies
   * both direction checks below and contributes TWO identical messages —
   * confirm whether that double-count is intended.
   */
  async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures, edgeIndex, edgeFeatures } = graph;
    const numNodes = nodeFeatures.length;
    const numEdges = edgeIndex[0].length;

    let currentFeatures = nodeFeatures.map((f) => [...f]);

    // Multiple rounds of message passing
    const numIterations = this.config.params?.numLayers ?? 1;

    for (let t = 0; t < numIterations; t++) {
      const newFeatures: number[][] = [];

      for (let i = 0; i < numNodes; i++) {
        // Collect messages from neighbors
        const messages: Message[] = [];
        const [sources, targets] = edgeIndex;

        // Incoming direction: edges whose target is this node.
        for (let e = 0; e < sources.length; e++) {
          if (targets[e] === i) {
            const j = sources[e];
            const edgeFeat = edgeFeatures?.[e];
            const message = this.messageFunction(
              currentFeatures[j] ?? [],
              currentFeatures[i] ?? [],
              edgeFeat
            );
            messages.push({
              source: j,
              target: i,
              vector: message,
              edgeFeatures: edgeFeat,
            });
          }
          // Outgoing direction: treat the edge as undirected and message back.
          if (sources[e] === i) {
            const j = targets[e];
            const edgeFeat = edgeFeatures?.[e];
            const message = this.messageFunction(
              currentFeatures[j] ?? [],
              currentFeatures[i] ?? [],
              edgeFeat
            );
            messages.push({
              source: j,
              target: i,
              vector: message,
              edgeFeatures: edgeFeat,
            });
          }
        }

        // Aggregate messages
        const aggregated = await this.aggregate(messages, this.config.aggregation ?? 'sum');

        // Update node features
        const updated = this.updateFunction(currentFeatures[i] ?? [], aggregated);
        newFeatures.push(this.applyDropout(updated));
      }

      currentFeatures = newFeatures;
    }

    return {
      nodeEmbeddings: currentFeatures,
      graphEmbedding: this.aggregateMean(currentFeatures),
      stats: this.createStats(startTime, numNodes, numEdges, numIterations),
    };
  }

  /**
   * Adapter: converts id-keyed nodes/edges (including edge features) to a
   * dense GraphData and delegates to forward().
   */
  async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
    const graph: GraphData = {
      nodeFeatures: nodes.features,
      edgeIndex: [
        edges.sources.map((s) => nodes.ids.indexOf(s)),
        edges.targets.map((t) => nodes.ids.indexOf(t)),
      ],
      edgeFeatures: edges.features,
    };

    const output = await this.forward(graph);

    return {
      ids: nodes.ids,
      features: output.nodeEmbeddings,
      types: nodes.types,
      labels: nodes.labels,
    };
  }

  /**
   * Message function M(h_src, h_tgt, e): fixed convex-ish mix of source (0.5),
   * target (0.3), and edge (0.2) features, element-wise over inputDim.
   */
  private messageFunction(
    sourceFeatures: number[],
    targetFeatures: number[],
    edgeFeatures?: number[]
  ): number[] {
    const dim = this.config.inputDim;
    const message = new Array(dim).fill(0);

    for (let i = 0; i < dim; i++) {
      message[i] = (sourceFeatures[i] ?? 0) * 0.5 + (targetFeatures[i] ?? 0) * 0.3;
      if (edgeFeatures && edgeFeatures[i] !== undefined) {
        message[i] += edgeFeatures[i] * 0.2;
      }
    }

    return message;
  }

  /**
   * Update function U(h, m): sigmoid gate over (h + m) interpolates between
   * the old state and the aggregated message, then applies the configured
   * activation. Indices wrap via modulo when dims differ from outputDim.
   */
  private updateFunction(nodeFeatures: number[], aggregated: number[]): number[] {
    const output = new Array(this.config.outputDim).fill(0);

    // GRU-like update
    for (let i = 0; i < this.config.outputDim; i++) {
      const nodeVal = nodeFeatures[i % nodeFeatures.length] ?? 0;
      const aggVal = aggregated[i % aggregated.length] ?? 0;
      const gate = 1 / (1 + Math.exp(-(nodeVal + aggVal)));
      output[i] = this.applyActivation(gate * aggVal + (1 - gate) * nodeVal);
    }

    return output;
  }
}
1598
+
1599
+ // ============================================================================
1600
+ // EdgeConv Layer Implementation
1601
+ // ============================================================================
1602
+
1603
+ /**
1604
+ * EdgeConv layer for dynamic graph convolution.
1605
+ *
1606
+ * Uses k-NN graph construction and edge features.
1607
+ * Reference: Wang et al., "Dynamic Graph CNN for Learning on Point Clouds" (2019)
1608
+ */
1609
export class EdgeConvLayer extends BaseGNNLayer {
  /**
   * EdgeConv forward pass: (re)builds a k-NN graph in feature space when
   * `dynamic` is set, forms per-edge features (x_j - x_i) || x_i, max-pools
   * them per node, and runs the result through a small MLP.
   *
   * NOTE(review): k-NN construction is O(n^2) in node count — fine for small
   * point sets, confirm acceptable for large graphs.
   */
  async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures } = graph;
    const numNodes = nodeFeatures.length;
    const k = this.config.params?.k ?? 20;
    const dynamic = this.config.params?.dynamic ?? true;

    // Build k-NN graph
    const knnGraph = dynamic
      ? this.buildKNNGraph(nodeFeatures, k)
      : graph.edgeIndex;

    const outputFeatures: number[][] = [];

    for (let i = 0; i < numNodes; i++) {
      const neighbors = this.getKNNNeighbors(i, knnGraph);
      const selfFeatures = nodeFeatures[i] ?? [];

      // Edge features: (x_j - x_i) || x_i
      const edgeFeatures: number[][] = [];
      for (const j of neighbors) {
        const neighborFeatures = nodeFeatures[j] ?? [];
        const diff = selfFeatures.map((v, idx) => (neighborFeatures[idx] ?? 0) - v);
        edgeFeatures.push([...diff, ...selfFeatures]);
      }

      // Max pooling over edge features
      const pooled = this.maxPoolEdges(edgeFeatures);

      // MLP on pooled features
      const output = this.edgeMLP(pooled);
      outputFeatures.push(this.applyDropout(output));
    }

    return {
      nodeEmbeddings: outputFeatures,
      graphEmbedding: this.aggregateMean(outputFeatures),
      // Edge count reported as numNodes * k, matching the dynamic k-NN graph.
      stats: this.createStats(startTime, numNodes, numNodes * k),
    };
  }

  /**
   * Adapter: converts id-keyed nodes/edges to a dense GraphData and delegates
   * to forward() (the provided edges are only used when dynamic === false).
   */
  async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
    const graph: GraphData = {
      nodeFeatures: nodes.features,
      edgeIndex: [
        edges.sources.map((s) => nodes.ids.indexOf(s)),
        edges.targets.map((t) => nodes.ids.indexOf(t)),
      ],
    };

    const output = await this.forward(graph);

    return {
      ids: nodes.ids,
      features: output.nodeEmbeddings,
      types: nodes.types,
      labels: nodes.labels,
    };
  }

  /**
   * Brute-force k-nearest-neighbor graph in Euclidean feature space; emits a
   * directed edge from each node to each of its k closest neighbors.
   */
  private buildKNNGraph(features: number[][], k: number): [number[], number[]] {
    const sources: number[] = [];
    const targets: number[] = [];

    for (let i = 0; i < features.length; i++) {
      const distances: { idx: number; dist: number }[] = [];

      for (let j = 0; j < features.length; j++) {
        if (i !== j) {
          const dist = this.euclideanDistance(features[i], features[j]);
          distances.push({ idx: j, dist });
        }
      }

      distances.sort((a, b) => a.dist - b.dist);
      const neighbors = distances.slice(0, k);

      for (const neighbor of neighbors) {
        sources.push(i);
        targets.push(neighbor.idx);
      }
    }

    return [sources, targets];
  }

  /** Euclidean distance over a's length; missing entries treated as 0. */
  private euclideanDistance(a: number[], b: number[]): number {
    let sum = 0;
    for (let i = 0; i < a.length; i++) {
      const diff = (a[i] ?? 0) - (b[i] ?? 0);
      sum += diff * diff;
    }
    return Math.sqrt(sum);
  }

  /** Outgoing-edge targets of a node (k-NN edges are directed source -> target). */
  private getKNNNeighbors(nodeIdx: number, edgeIndex: [number[], number[]]): number[] {
    const neighbors: number[] = [];
    const [sources, targets] = edgeIndex;

    for (let i = 0; i < sources.length; i++) {
      if (sources[i] === nodeIdx) {
        neighbors.push(targets[i]);
      }
    }

    return neighbors;
  }

  /**
   * Element-wise max over the node's edge features; a node with no edges gets
   * a zero vector of width 2 * inputDim (the (diff || self) concatenation width).
   */
  private maxPoolEdges(edgeFeatures: number[][]): number[] {
    if (edgeFeatures.length === 0) {
      return new Array(this.config.inputDim * 2).fill(0);
    }
    return this.aggregateMax(edgeFeatures);
  }

  /**
   * Single MLP layer with deterministic sine pseudo-weights; applies the
   * configured activation per output element.
   */
  private edgeMLP(input: number[]): number[] {
    const output = new Array(this.config.outputDim).fill(0);

    for (let i = 0; i < this.config.outputDim; i++) {
      for (let j = 0; j < input.length; j++) {
        const weight = Math.sin((i * input.length + j) * 0.08) * 0.4;
        output[i] += (input[j] ?? 0) * weight;
      }
      output[i] = this.applyActivation(output[i]);
    }

    return output;
  }
}
1739
+
1740
+ // ============================================================================
1741
+ // Additional GNN Layer Implementations (Stubs)
1742
+ // ============================================================================
1743
+
1744
+ /**
1745
+ * Point Convolution layer for point cloud data.
1746
+ */
1747
export class PointConvLayer extends EdgeConvLayer {
  // Intentionally empty: currently inherits EdgeConv behavior unchanged
  // (dynamic k-NN graph construction + edge-feature MLP). Point-cloud-specific
  // operations are not yet implemented here.
}
1750
+
1751
+ /**
1752
+ * Graph Transformer layer.
1753
+ */
1754
+ export class GraphTransformerLayer extends GATLayer {
1755
+ override async forward(graph: GraphData): Promise<GNNOutput> {
1756
+ // Add positional encoding and full attention
1757
+ const result = await super.forward(graph);
1758
+
1759
+ // Apply transformer-specific operations (layer norm, residual)
1760
+ const normalizedEmbeddings = result.nodeEmbeddings.map((f) =>
1761
+ this.layerNorm(f)
1762
+ );
1763
+
1764
+ return {
1765
+ ...result,
1766
+ nodeEmbeddings: normalizedEmbeddings,
1767
+ };
1768
+ }
1769
+
1770
+ private layerNorm(features: number[]): number[] {
1771
+ const mean = features.reduce((a, b) => a + b, 0) / features.length;
1772
+ const variance =
1773
+ features.reduce((sum, x) => sum + (x - mean) ** 2, 0) / features.length;
1774
+ const std = Math.sqrt(variance + 1e-6);
1775
+ return features.map((x) => (x - mean) / std);
1776
+ }
1777
+ }
1778
+
1779
+ /**
1780
+ * Principal Neighbourhood Aggregation (PNA) layer.
1781
+ */
1782
+ export class PNALayer extends BaseGNNLayer {
1783
+ async forward(graph: GraphData): Promise<GNNOutput> {
1784
+ const startTime = Date.now();
1785
+ const { nodeFeatures, edgeIndex } = graph;
1786
+ const numNodes = nodeFeatures.length;
1787
+ const numEdges = edgeIndex[0].length;
1788
+
1789
+ const aggregators = this.config.params?.aggregators ?? ['mean', 'sum', 'max', 'min'];
1790
+ const scalers = this.config.params?.scalers ?? ['identity', 'amplification', 'attenuation'];
1791
+
1792
+ const outputFeatures: number[][] = [];
1793
+
1794
+ for (let i = 0; i < numNodes; i++) {
1795
+ const neighbors = this.getNeighbors(i, edgeIndex, numNodes);
1796
+ const neighborFeatures = neighbors.map((j) => nodeFeatures[j] ?? []);
1797
+ const degree = neighbors.length || 1;
1798
+
1799
+ // Apply multiple aggregators
1800
+ const aggregatedResults: number[][] = [];
1801
+ for (const agg of aggregators) {
1802
+ const messages = neighborFeatures.map((f) => ({
1803
+ source: 0,
1804
+ target: i,
1805
+ vector: f,
1806
+ }));
1807
+ const result = await this.aggregate(messages, agg as AggregationMethod);
1808
+ aggregatedResults.push(result);
1809
+ }
1810
+
1811
+ // Apply scalers
1812
+ const scaledResults: number[][] = [];
1813
+ for (const aggregated of aggregatedResults) {
1814
+ for (const scaler of scalers) {
1815
+ scaledResults.push(this.applyScaler(aggregated, scaler, degree));
1816
+ }
1817
+ }
1818
+
1819
+ // Concatenate and project
1820
+ const combined = scaledResults.flat();
1821
+ const projected = this.projectFeatures(combined);
1822
+ outputFeatures.push(this.applyDropout(projected.map((x) => this.applyActivation(x))));
1823
+ }
1824
+
1825
+ return {
1826
+ nodeEmbeddings: outputFeatures,
1827
+ graphEmbedding: this.aggregateMean(outputFeatures),
1828
+ stats: this.createStats(startTime, numNodes, numEdges),
1829
+ };
1830
+ }
1831
+
1832
+ async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
1833
+ const graph: GraphData = {
1834
+ nodeFeatures: nodes.features,
1835
+ edgeIndex: [
1836
+ edges.sources.map((s) => nodes.ids.indexOf(s)),
1837
+ edges.targets.map((t) => nodes.ids.indexOf(t)),
1838
+ ],
1839
+ };
1840
+
1841
+ const output = await this.forward(graph);
1842
+
1843
+ return {
1844
+ ids: nodes.ids,
1845
+ features: output.nodeEmbeddings,
1846
+ types: nodes.types,
1847
+ labels: nodes.labels,
1848
+ };
1849
+ }
1850
+
1851
+ private getNeighbors(
1852
+ nodeIdx: number,
1853
+ edgeIndex: [number[], number[]],
1854
+ numNodes: number
1855
+ ): number[] {
1856
+ const neighbors = new Set<number>();
1857
+ const [sources, targets] = edgeIndex;
1858
+
1859
+ for (let i = 0; i < sources.length; i++) {
1860
+ if (sources[i] === nodeIdx && targets[i] < numNodes) {
1861
+ neighbors.add(targets[i]);
1862
+ }
1863
+ if (targets[i] === nodeIdx && sources[i] < numNodes) {
1864
+ neighbors.add(sources[i]);
1865
+ }
1866
+ }
1867
+
1868
+ return Array.from(neighbors);
1869
+ }
1870
+
1871
+ private applyScaler(
1872
+ features: number[],
1873
+ scaler: string,
1874
+ degree: number
1875
+ ): number[] {
1876
+ switch (scaler) {
1877
+ case 'amplification':
1878
+ return features.map((x) => x * Math.log(degree + 1));
1879
+ case 'attenuation':
1880
+ return features.map((x) => x / Math.log(degree + 1));
1881
+ case 'identity':
1882
+ default:
1883
+ return features;
1884
+ }
1885
+ }
1886
+
1887
+ private projectFeatures(input: number[]): number[] {
1888
+ const output = new Array(this.config.outputDim).fill(0);
1889
+ for (let i = 0; i < this.config.outputDim; i++) {
1890
+ for (let j = 0; j < Math.min(input.length, 100); j++) {
1891
+ const weight = Math.sin((i * 100 + j) * 0.1) * 0.3;
1892
+ output[i] += (input[j] ?? 0) * weight;
1893
+ }
1894
+ }
1895
+ return output;
1896
+ }
1897
+ }
1898
+
1899
+ /**
1900
+ * FiLM (Feature-wise Linear Modulation) layer.
1901
+ */
1902
+ export class FiLMLayer extends BaseGNNLayer {
1903
+ async forward(graph: GraphData): Promise<GNNOutput> {
1904
+ const startTime = Date.now();
1905
+ const { nodeFeatures, edgeIndex, edgeFeatures } = graph;
1906
+ const numNodes = nodeFeatures.length;
1907
+ const numEdges = edgeIndex[0].length;
1908
+
1909
+ const outputFeatures: number[][] = [];
1910
+
1911
+ for (let i = 0; i < numNodes; i++) {
1912
+ const selfFeatures = nodeFeatures[i] ?? [];
1913
+
1914
+ // Compute modulation parameters from edge features
1915
+ const { gamma, beta } = this.computeModulation(edgeFeatures ?? []);
1916
+
1917
+ // Apply FiLM: gamma * x + beta
1918
+ const modulated = selfFeatures.map((x, idx) =>
1919
+ (gamma[idx % gamma.length] ?? 1) * x + (beta[idx % beta.length] ?? 0)
1920
+ );
1921
+
1922
+ outputFeatures.push(this.applyDropout(modulated.map((x) => this.applyActivation(x))));
1923
+ }
1924
+
1925
+ return {
1926
+ nodeEmbeddings: outputFeatures,
1927
+ graphEmbedding: this.aggregateMean(outputFeatures),
1928
+ stats: this.createStats(startTime, numNodes, numEdges),
1929
+ };
1930
+ }
1931
+
1932
+ async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
1933
+ const graph: GraphData = {
1934
+ nodeFeatures: nodes.features,
1935
+ edgeIndex: [
1936
+ edges.sources.map((s) => nodes.ids.indexOf(s)),
1937
+ edges.targets.map((t) => nodes.ids.indexOf(t)),
1938
+ ],
1939
+ edgeFeatures: edges.features,
1940
+ };
1941
+
1942
+ const output = await this.forward(graph);
1943
+
1944
+ return {
1945
+ ids: nodes.ids,
1946
+ features: output.nodeEmbeddings,
1947
+ types: nodes.types,
1948
+ labels: nodes.labels,
1949
+ };
1950
+ }
1951
+
1952
+ private computeModulation(edgeFeatures: number[][]): { gamma: number[]; beta: number[] } {
1953
+ const dim = this.config.outputDim;
1954
+ const gamma = new Array(dim).fill(1);
1955
+ const beta = new Array(dim).fill(0);
1956
+
1957
+ if (edgeFeatures.length > 0) {
1958
+ const meanEdge = this.aggregateMean(edgeFeatures);
1959
+ for (let i = 0; i < dim; i++) {
1960
+ gamma[i] = 1 + 0.1 * (meanEdge[i % meanEdge.length] ?? 0);
1961
+ beta[i] = 0.1 * (meanEdge[(i + dim / 2) % meanEdge.length] ?? 0);
1962
+ }
1963
+ }
1964
+
1965
+ return { gamma, beta };
1966
+ }
1967
+ }
1968
+
1969
+ /**
1970
+ * Relational Graph Convolutional Network (RGCN) layer.
1971
+ */
1972
+ export class RGCNLayer extends GCNLayer {
1973
+ override async forward(graph: GraphData): Promise<GNNOutput> {
1974
+ const startTime = Date.now();
1975
+ const { nodeFeatures, edgeIndex, edgeTypes } = graph;
1976
+ const numNodes = nodeFeatures.length;
1977
+ const numEdges = edgeIndex[0].length;
1978
+ const numRelations = this.config.params?.numRelations ?? 1;
1979
+
1980
+ // Process each relation type separately
1981
+ const relationOutputs: number[][][] = [];
1982
+
1983
+ for (let r = 0; r < numRelations; r++) {
1984
+ // Filter edges by relation type
1985
+ const relationEdges = this.filterEdgesByType(edgeIndex, edgeTypes ?? [], r);
1986
+
1987
+ // Apply GCN for this relation
1988
+ const relationGraph: GraphData = {
1989
+ nodeFeatures,
1990
+ edgeIndex: relationEdges,
1991
+ };
1992
+
1993
+ const result = await super.forward(relationGraph);
1994
+ relationOutputs.push(result.nodeEmbeddings);
1995
+ }
1996
+
1997
+ // Combine relation outputs
1998
+ const outputFeatures = this.combineRelationOutputs(relationOutputs);
1999
+
2000
+ return {
2001
+ nodeEmbeddings: outputFeatures,
2002
+ graphEmbedding: this.aggregateMean(outputFeatures),
2003
+ stats: this.createStats(startTime, numNodes, numEdges),
2004
+ };
2005
+ }
2006
+
2007
+ private filterEdgesByType(
2008
+ edgeIndex: [number[], number[]],
2009
+ edgeTypes: number[],
2010
+ targetType: number
2011
+ ): [number[], number[]] {
2012
+ const sources: number[] = [];
2013
+ const targets: number[] = [];
2014
+ const [srcArr, tgtArr] = edgeIndex;
2015
+
2016
+ for (let i = 0; i < srcArr.length; i++) {
2017
+ if (edgeTypes[i] === targetType || edgeTypes.length === 0) {
2018
+ sources.push(srcArr[i]);
2019
+ targets.push(tgtArr[i]);
2020
+ }
2021
+ }
2022
+
2023
+ return [sources, targets];
2024
+ }
2025
+
2026
+ private combineRelationOutputs(outputs: number[][][]): number[][] {
2027
+ if (outputs.length === 0) return [];
2028
+ if (outputs.length === 1) return outputs[0];
2029
+
2030
+ const numNodes = outputs[0].length;
2031
+ const result: number[][] = [];
2032
+
2033
+ for (let i = 0; i < numNodes; i++) {
2034
+ const nodeOutputs = outputs.map((o) => o[i] ?? []);
2035
+ result.push(this.aggregateMean(nodeOutputs));
2036
+ }
2037
+
2038
+ return result;
2039
+ }
2040
+ }
2041
+
2042
+ /**
2043
+ * Heterogeneous Graph Transformer (HGT) layer.
2044
+ */
2045
+ export class HGTLayer extends GATLayer {
2046
+ override async forward(graph: GraphData): Promise<GNNOutput> {
2047
+ // HGT uses type-specific transformations
2048
+ const { nodeFeatures, nodeTypes } = graph;
2049
+
2050
+ // Transform features based on node types
2051
+ const transformedFeatures = nodeFeatures.map((f, i) => {
2052
+ const nodeType = nodeTypes?.[i] ?? 0;
2053
+ return this.typeSpecificTransform(f, nodeType);
2054
+ });
2055
+
2056
+ const transformedGraph: GraphData = {
2057
+ ...graph,
2058
+ nodeFeatures: transformedFeatures,
2059
+ };
2060
+
2061
+ return super.forward(transformedGraph);
2062
+ }
2063
+
2064
+ private typeSpecificTransform(features: number[], nodeType: number): number[] {
2065
+ // Apply type-specific transformation
2066
+ return features.map((x, i) => {
2067
+ const weight = Math.sin((nodeType * this.config.inputDim + i) * 0.1);
2068
+ return x * (1 + 0.1 * weight);
2069
+ });
2070
+ }
2071
+ }
2072
+
2073
+ /**
2074
+ * Heterogeneous Attention Network (HAN) layer.
2075
+ */
2076
+ export class HANLayer extends GATLayer {
2077
+ override async forward(graph: GraphData): Promise<GNNOutput> {
2078
+ const metapaths = this.config.params?.metapaths ?? [];
2079
+
2080
+ if (metapaths.length === 0) {
2081
+ return super.forward(graph);
2082
+ }
2083
+
2084
+ // Process each metapath
2085
+ const metapathOutputs: number[][][] = [];
2086
+
2087
+ for (const metapath of metapaths) {
2088
+ const metapathGraph = this.extractMetapathSubgraph(graph, metapath);
2089
+ const result = await super.forward(metapathGraph);
2090
+ metapathOutputs.push(result.nodeEmbeddings);
2091
+ }
2092
+
2093
+ // Attention over metapaths
2094
+ const outputFeatures = this.attentionOverMetapaths(metapathOutputs);
2095
+
2096
+ return {
2097
+ nodeEmbeddings: outputFeatures,
2098
+ graphEmbedding: this.aggregateMean(outputFeatures),
2099
+ stats: {
2100
+ forwardTimeMs: 0,
2101
+ numNodes: graph.nodeFeatures.length,
2102
+ numEdges: graph.edgeIndex[0].length,
2103
+ memoryBytes: 0,
2104
+ numIterations: metapaths.length,
2105
+ },
2106
+ };
2107
+ }
2108
+
2109
+ private extractMetapathSubgraph(graph: GraphData, _metapath: string[]): GraphData {
2110
+ // Simplified: return original graph
2111
+ // In practice, would filter edges based on metapath
2112
+ // The _metapath parameter would be used to filter edge types
2113
+ return graph;
2114
+ }
2115
+
2116
+ private attentionOverMetapaths(outputs: number[][][]): number[][] {
2117
+ if (outputs.length === 0) return [];
2118
+ if (outputs.length === 1) return outputs[0];
2119
+
2120
+ const numNodes = outputs[0].length;
2121
+ const result: number[][] = [];
2122
+
2123
+ // Compute attention weights for metapaths
2124
+ const metapathWeights = outputs.map((o) => {
2125
+ const importance = o.reduce(
2126
+ (sum, node) => sum + node.reduce((s, v) => s + Math.abs(v), 0),
2127
+ 0
2128
+ );
2129
+ return importance;
2130
+ });
2131
+
2132
+ const maxWeight = Math.max(...metapathWeights);
2133
+ const expWeights = metapathWeights.map((w) => Math.exp((w - maxWeight) / 10));
2134
+ const sumExp = expWeights.reduce((a, b) => a + b, 0);
2135
+ const normalizedWeights = expWeights.map((w) => w / sumExp);
2136
+
2137
+ for (let i = 0; i < numNodes; i++) {
2138
+ const dim = outputs[0][i]?.length ?? 0;
2139
+ const combined = new Array(dim).fill(0);
2140
+
2141
+ for (let m = 0; m < outputs.length; m++) {
2142
+ const nodeFeatures = outputs[m][i] ?? [];
2143
+ for (let j = 0; j < dim; j++) {
2144
+ combined[j] += normalizedWeights[m] * (nodeFeatures[j] ?? 0);
2145
+ }
2146
+ }
2147
+
2148
+ result.push(combined);
2149
+ }
2150
+
2151
+ return result;
2152
+ }
2153
+ }
2154
+
2155
/**
 * MetaPath-based aggregation layer.
 *
 * NOTE(review): currently a stub — inherits HANLayer's metapath attention
 * unchanged; no additional metapath-specific behavior is implemented yet.
 */
export class MetaPathLayer extends HANLayer {
  // Stub: identical to HANLayer until metapath-specific logic is added.
}
2161
+
2162
+ // ============================================================================
2163
+ // Graph Operations
2164
+ // ============================================================================
2165
+
2166
+ /**
2167
+ * Graph operations for advanced graph analytics.
2168
+ *
2169
+ * @example
2170
+ * ```typescript
2171
+ * const ops = new GraphOperations();
2172
+ * const neighbors = await ops.kHopNeighbors('node1', 2);
2173
+ * const path = await ops.shortestPath('source', 'target');
2174
+ * const ranks = await ops.pageRank({ damping: 0.85 });
2175
+ * const communities = await ops.communityDetection({ algorithm: 'louvain' });
2176
+ * ```
2177
+ */
2178
+ export class GraphOperations {
2179
+ private adjacencyList: Map<NodeId, Set<NodeId>> = new Map();
2180
+ private weights: Map<string, number> = new Map();
2181
+ private nodeFeatures: Map<NodeId, number[]> = new Map();
2182
+
2183
+ /**
2184
+ * Load graph data.
2185
+ */
2186
+ loadGraph(graph: GraphData): void {
2187
+ this.adjacencyList.clear();
2188
+ this.weights.clear();
2189
+ this.nodeFeatures.clear();
2190
+
2191
+ const { nodeFeatures, edgeIndex, edgeWeights } = graph;
2192
+ const [sources, targets] = edgeIndex;
2193
+
2194
+ // Initialize nodes
2195
+ for (let i = 0; i < nodeFeatures.length; i++) {
2196
+ this.adjacencyList.set(i, new Set());
2197
+ this.nodeFeatures.set(i, nodeFeatures[i]);
2198
+ }
2199
+
2200
+ // Add edges
2201
+ for (let i = 0; i < sources.length; i++) {
2202
+ const src = sources[i];
2203
+ const tgt = targets[i];
2204
+ const weight = edgeWeights?.[i] ?? 1;
2205
+
2206
+ this.adjacencyList.get(src)?.add(tgt);
2207
+ this.adjacencyList.get(tgt)?.add(src);
2208
+ this.weights.set(`${src}-${tgt}`, weight);
2209
+ this.weights.set(`${tgt}-${src}`, weight);
2210
+ }
2211
+ }
2212
+
2213
+ /**
2214
+ * Find k-hop neighbors of a node.
2215
+ */
2216
+ async kHopNeighbors(nodeId: NodeId, k: number): Promise<NodeId[]> {
2217
+ const visited = new Set<NodeId>();
2218
+ const queue: { node: NodeId; depth: number }[] = [{ node: nodeId, depth: 0 }];
2219
+ const result: NodeId[] = [];
2220
+
2221
+ while (queue.length > 0) {
2222
+ const { node, depth } = queue.shift()!;
2223
+
2224
+ if (visited.has(node)) continue;
2225
+ visited.add(node);
2226
+
2227
+ if (depth > 0) {
2228
+ result.push(node);
2229
+ }
2230
+
2231
+ if (depth < k) {
2232
+ const neighbors = this.adjacencyList.get(node) ?? new Set();
2233
+ for (const neighbor of neighbors) {
2234
+ if (!visited.has(neighbor)) {
2235
+ queue.push({ node: neighbor, depth: depth + 1 });
2236
+ }
2237
+ }
2238
+ }
2239
+ }
2240
+
2241
+ return result;
2242
+ }
2243
+
2244
  /**
   * Find shortest path between two nodes using Dijkstra's algorithm.
   *
   * Uses a linear scan over the unvisited set to pick the next node
   * (O(V^2) overall — fine for the in-memory graph sizes this class holds).
   * Returns `{ nodes: [], weight: Infinity }` when `target` is unreachable
   * from `source`. Missing edge weights default to 1.
   */
  async shortestPath(source: NodeId, target: NodeId): Promise<Path> {
    const distances = new Map<NodeId, number>();
    const previous = new Map<NodeId, NodeId | null>();
    const unvisited = new Set<NodeId>(this.adjacencyList.keys());

    // All distances start at Infinity except the source.
    for (const node of this.adjacencyList.keys()) {
      distances.set(node, Infinity);
      previous.set(node, null);
    }
    distances.set(source, 0);

    while (unvisited.size > 0) {
      // Find minimum distance node
      let current: NodeId | null = null;
      let minDist = Infinity;

      for (const node of unvisited) {
        const dist = distances.get(node) ?? Infinity;
        if (dist < minDist) {
          minDist = dist;
          current = node;
        }
      }

      // No reachable node left, or the target was settled: stop early.
      if (current === null || current === target) break;

      unvisited.delete(current);

      // Update neighbors
      const neighbors = this.adjacencyList.get(current) ?? new Set();
      for (const neighbor of neighbors) {
        if (!unvisited.has(neighbor)) continue;

        const edgeWeight = this.weights.get(`${current}-${neighbor}`) ?? 1;
        const alt = (distances.get(current) ?? Infinity) + edgeWeight;

        // Relax the edge if a shorter route was found.
        if (alt < (distances.get(neighbor) ?? Infinity)) {
          distances.set(neighbor, alt);
          previous.set(neighbor, current);
        }
      }
    }

    // Reconstruct path by walking the predecessor chain backwards.
    const nodes: NodeId[] = [];
    let current: NodeId | null = target;

    while (current !== null) {
      nodes.unshift(current);
      current = previous.get(current) ?? null;
    }

    // Chain did not reach the source => target is unreachable.
    if (nodes[0] !== source) {
      return { nodes: [], weight: Infinity };
    }

    return {
      nodes,
      weight: distances.get(target) ?? Infinity,
    };
  }
2308
+
2309
  /**
   * Compute PageRank scores for all nodes via power iteration.
   *
   * Since the adjacency list is undirected, a node's in-neighbors equal its
   * out-neighbors; each neighbor's rank is divided by that neighbor's
   * degree. Supports optional edge weighting and a personalization vector
   * (used both for the initial ranks and the teleport term). Stops early
   * once the L1 change between iterations drops below `tolerance`.
   */
  async pageRank(options: PageRankOptions = {}): Promise<Map<NodeId, number>> {
    const damping = options.damping ?? 0.85;
    const maxIterations = options.maxIterations ?? 100;
    const tolerance = options.tolerance ?? 1e-6;

    const nodes = Array.from(this.adjacencyList.keys());
    const n = nodes.length;

    if (n === 0) return new Map();

    // Initialize ranks (personalization overrides the uniform start).
    let ranks = new Map<NodeId, number>();
    const initialRank = 1 / n;

    for (const node of nodes) {
      ranks.set(node, options.personalization?.get(node) ?? initialRank);
    }

    // Power iteration
    for (let iter = 0; iter < maxIterations; iter++) {
      const newRanks = new Map<NodeId, number>();
      let diff = 0;

      for (const node of nodes) {
        let sum = 0;
        const neighbors = this.adjacencyList.get(node) ?? new Set();

        for (const neighbor of neighbors) {
          // Degree of the neighbor: over how many nodes it spreads its rank.
          const neighborOutDegree = this.adjacencyList.get(neighbor)?.size ?? 1;
          const neighborRank = ranks.get(neighbor) ?? 0;

          if (options.weighted) {
            const weight = this.weights.get(`${neighbor}-${node}`) ?? 1;
            sum += (neighborRank * weight) / neighborOutDegree;
          } else {
            sum += neighborRank / neighborOutDegree;
          }
        }

        // Teleport term: personalized jump probability, uniform by default.
        const teleport = options.personalization?.get(node) ?? 1 / n;
        const newRank = (1 - damping) * teleport + damping * sum;
        newRanks.set(node, newRank);

        diff += Math.abs(newRank - (ranks.get(node) ?? 0));
      }

      ranks = newRanks;

      // Converged: total rank movement below tolerance.
      if (diff < tolerance) break;
    }

    return ranks;
  }
2365
+
2366
+ /**
2367
+ * Detect communities in the graph.
2368
+ */
2369
+ async communityDetection(options: CommunityOptions): Promise<Community[]> {
2370
+ switch (options.algorithm) {
2371
+ case 'louvain':
2372
+ return this.louvainCommunityDetection(options);
2373
+ case 'label_propagation':
2374
+ return this.labelPropagationCommunityDetection(options);
2375
+ case 'girvan_newman':
2376
+ return this.girvanNewmanCommunityDetection(options);
2377
+ case 'spectral':
2378
+ return this.spectralCommunityDetection(options);
2379
+ default:
2380
+ return this.louvainCommunityDetection(options);
2381
+ }
2382
+ }
2383
+
2384
  /**
   * Louvain community detection (local-moving phase only).
   *
   * Starts with every node in its own community and greedily moves each node
   * to the neighboring community with the best modularity gain, repeating
   * until no move improves or `maxIterations` is reached. Communities
   * smaller than `options.minSize` are dropped from the result.
   *
   * NOTE(review): only phase 1 of Louvain is implemented — there is no graph
   * coarsening/aggregation phase, and the gain uses the simplified heuristic
   * in `modularityGain`.
   */
  private async louvainCommunityDetection(options: CommunityOptions): Promise<Community[]> {
    const nodes = Array.from(this.adjacencyList.keys());
    const resolution = options.resolution ?? 1.0;
    const maxIterations = options.maxIterations ?? 100;

    // Initialize: each node is its own community
    const community = new Map<NodeId, number>();
    let nextCommunityId = 0;

    for (const node of nodes) {
      community.set(node, nextCommunityId++);
    }

    // Compute total edge weight
    let totalWeight = 0;
    for (const weight of this.weights.values()) {
      totalWeight += weight;
    }
    totalWeight /= 2; // Undirected edges counted twice

    // Phase 1: Local moving
    for (let iter = 0; iter < maxIterations; iter++) {
      let improved = false;

      for (const node of nodes) {
        const currentCommunity = community.get(node)!;
        const neighbors = this.adjacencyList.get(node) ?? new Set();

        // Find neighbor communities
        const neighborCommunities = new Set<number>();
        for (const neighbor of neighbors) {
          neighborCommunities.add(community.get(neighbor)!);
        }

        // Find best community (only moves with strictly positive gain win)
        let bestCommunity = currentCommunity;
        let bestModularityGain = 0;

        for (const targetCommunity of neighborCommunities) {
          if (targetCommunity === currentCommunity) continue;

          const gain = this.modularityGain(
            node,
            currentCommunity,
            targetCommunity,
            community,
            resolution,
            totalWeight
          );

          if (gain > bestModularityGain) {
            bestModularityGain = gain;
            bestCommunity = targetCommunity;
          }
        }

        if (bestCommunity !== currentCommunity) {
          community.set(node, bestCommunity);
          improved = true;
        }
      }

      // Stable assignment: no node moved in this sweep.
      if (!improved) break;
    }

    // Build communities
    const communityMembers = new Map<number, NodeId[]>();
    for (const [node, commId] of community.entries()) {
      if (!communityMembers.has(commId)) {
        communityMembers.set(commId, []);
      }
      communityMembers.get(commId)!.push(node);
    }

    // Filter by minimum size
    const minSize = options.minSize ?? 1;
    const communities: Community[] = [];

    for (const [id, members] of communityMembers.entries()) {
      if (members.length >= minSize) {
        communities.push({
          id,
          members,
          centroid: this.computeCentroid(members),
          modularity: this.computeModularity(members, community, totalWeight),
          density: this.computeDensity(members),
        });
      }
    }

    return communities;
  }
2476
+
2477
  /**
   * Approximate modularity gain of moving `node` from `fromCommunity` to
   * `toCommunity`.
   *
   * NOTE(review): this is a simplified heuristic, not the exact Louvain
   * delta-Q (which also uses community total degrees). It compares the
   * node's edge weight into the target vs. the current community, minus a
   * resolution-scaled degree penalty. Confirm against the intended Louvain
   * variant before relying on exact modularity values.
   */
  private modularityGain(
    node: NodeId,
    fromCommunity: number,
    toCommunity: number,
    community: Map<NodeId, number>,
    resolution: number,
    totalWeight: number
  ): number {
    const neighbors = this.adjacencyList.get(node) ?? new Set();
    let linksToCommunity = 0;
    let linksFromCommunity = 0;

    // Sum edge weights from `node` into each of the two communities.
    for (const neighbor of neighbors) {
      const neighborCommunity = community.get(neighbor)!;
      const weight = this.weights.get(`${node}-${neighbor}`) ?? 1;

      if (neighborCommunity === toCommunity) {
        linksToCommunity += weight;
      }
      if (neighborCommunity === fromCommunity) {
        linksFromCommunity += weight;
      }
    }

    // Unweighted degree of the node (neighbor count, not weight sum).
    const nodeDegree = neighbors.size;

    return (
      (linksToCommunity - linksFromCommunity) / totalWeight -
      (resolution * nodeDegree * (linksToCommunity - linksFromCommunity)) /
        (2 * totalWeight * totalWeight)
    );
  }
2509
+
2510
+ private computeCentroid(members: NodeId[]): number[] | undefined {
2511
+ if (members.length === 0) return undefined;
2512
+
2513
+ const features = members.map((m) => this.nodeFeatures.get(m) ?? []);
2514
+ if (features[0]?.length === 0) return undefined;
2515
+
2516
+ const dim = features[0].length;
2517
+ const centroid = new Array(dim).fill(0);
2518
+
2519
+ for (const f of features) {
2520
+ for (let i = 0; i < dim; i++) {
2521
+ centroid[i] += (f[i] ?? 0) / members.length;
2522
+ }
2523
+ }
2524
+
2525
+ return centroid;
2526
+ }
2527
+
2528
+ private computeModularity(
2529
+ members: NodeId[],
2530
+ _community: Map<NodeId, number>,
2531
+ totalWeight: number
2532
+ ): number {
2533
+ // Note: _community is passed for potential future use in computing inter-community edges
2534
+ let internalEdges = 0;
2535
+ let totalDegree = 0;
2536
+
2537
+ for (const node of members) {
2538
+ const neighbors = this.adjacencyList.get(node) ?? new Set();
2539
+ totalDegree += neighbors.size;
2540
+
2541
+ for (const neighbor of neighbors) {
2542
+ if (members.includes(neighbor)) {
2543
+ internalEdges += this.weights.get(`${node}-${neighbor}`) ?? 1;
2544
+ }
2545
+ }
2546
+ }
2547
+
2548
+ internalEdges /= 2; // Counted twice
2549
+ const expected = (totalDegree * totalDegree) / (4 * totalWeight);
2550
+
2551
+ return (internalEdges - expected) / totalWeight;
2552
+ }
2553
+
2554
+ private computeDensity(members: NodeId[]): number {
2555
+ if (members.length <= 1) return 1;
2556
+
2557
+ let edges = 0;
2558
+ for (const node of members) {
2559
+ const neighbors = this.adjacencyList.get(node) ?? new Set();
2560
+ for (const neighbor of neighbors) {
2561
+ if (members.includes(neighbor)) {
2562
+ edges++;
2563
+ }
2564
+ }
2565
+ }
2566
+
2567
+ edges /= 2;
2568
+ const maxEdges = (members.length * (members.length - 1)) / 2;
2569
+
2570
+ return edges / maxEdges;
2571
+ }
2572
+
2573
+ private async labelPropagationCommunityDetection(options: CommunityOptions): Promise<Community[]> {
2574
+ const nodes = Array.from(this.adjacencyList.keys());
2575
+ const maxIterations = options.maxIterations ?? 100;
2576
+
2577
+ // Initialize labels
2578
+ const labels = new Map<NodeId, number>();
2579
+ let nextLabel = 0;
2580
+ for (const node of nodes) {
2581
+ labels.set(node, nextLabel++);
2582
+ }
2583
+
2584
+ // Iterate
2585
+ for (let iter = 0; iter < maxIterations; iter++) {
2586
+ let changed = false;
2587
+
2588
+ // Shuffle nodes
2589
+ const shuffled = [...nodes].sort(() => Math.random() - 0.5);
2590
+
2591
+ for (const node of shuffled) {
2592
+ const neighbors = this.adjacencyList.get(node) ?? new Set();
2593
+ if (neighbors.size === 0) continue;
2594
+
2595
+ // Count neighbor labels
2596
+ const labelCounts = new Map<number, number>();
2597
+ for (const neighbor of neighbors) {
2598
+ const label = labels.get(neighbor)!;
2599
+ labelCounts.set(label, (labelCounts.get(label) ?? 0) + 1);
2600
+ }
2601
+
2602
+ // Find most common label
2603
+ let maxCount = 0;
2604
+ let bestLabel = labels.get(node)!;
2605
+ for (const [label, count] of labelCounts.entries()) {
2606
+ if (count > maxCount) {
2607
+ maxCount = count;
2608
+ bestLabel = label;
2609
+ }
2610
+ }
2611
+
2612
+ if (bestLabel !== labels.get(node)) {
2613
+ labels.set(node, bestLabel);
2614
+ changed = true;
2615
+ }
2616
+ }
2617
+
2618
+ if (!changed) break;
2619
+ }
2620
+
2621
+ // Build communities
2622
+ const communityMembers = new Map<number, NodeId[]>();
2623
+ for (const [node, label] of labels.entries()) {
2624
+ if (!communityMembers.has(label)) {
2625
+ communityMembers.set(label, []);
2626
+ }
2627
+ communityMembers.get(label)!.push(node);
2628
+ }
2629
+
2630
+ return Array.from(communityMembers.entries()).map(([id, members]) => ({
2631
+ id,
2632
+ members,
2633
+ centroid: this.computeCentroid(members),
2634
+ density: this.computeDensity(members),
2635
+ }));
2636
+ }
2637
+
2638
  /**
   * Girvan-Newman community detection.
   *
   * NOTE(review): simplified placeholder — real Girvan-Newman iteratively
   * removes the highest edge-betweenness edges; this currently delegates to
   * label propagation instead.
   */
  private async girvanNewmanCommunityDetection(options: CommunityOptions): Promise<Community[]> {
    // Simplified Girvan-Newman (edge betweenness)
    // In practice, this would iteratively remove high-betweenness edges
    return this.labelPropagationCommunityDetection(options);
  }
2643
+
2644
  /**
   * Spectral community detection.
   *
   * NOTE(review): simplified placeholder — true spectral clustering uses an
   * eigendecomposition of the graph Laplacian; this currently delegates to
   * label propagation instead.
   */
  private async spectralCommunityDetection(options: CommunityOptions): Promise<Community[]> {
    // Simplified spectral clustering
    // In practice, would use eigendecomposition of Laplacian
    return this.labelPropagationCommunityDetection(options);
  }
2649
+
2650
+ /**
2651
+ * Generate SQL for k-hop neighbors query.
2652
+ */
2653
+ kHopNeighborsSQL(nodeId: string, k: number, tableName: string, options: SQLGenerationOptions = {}): string {
2654
+ const schema = options.schema ?? 'public';
2655
+ const edgeTable = options.edgeTable ?? `${tableName}_edges`;
2656
+
2657
+ return `
2658
+ WITH RECURSIVE k_hop AS (
2659
+ SELECT source_id AS node_id, 1 AS depth
2660
+ FROM "${schema}"."${edgeTable}"
2661
+ WHERE target_id = '${nodeId}'
2662
+ UNION
2663
+ SELECT target_id AS node_id, 1 AS depth
2664
+ FROM "${schema}"."${edgeTable}"
2665
+ WHERE source_id = '${nodeId}'
2666
+ UNION ALL
2667
+ SELECT e.target_id AS node_id, kh.depth + 1
2668
+ FROM k_hop kh
2669
+ JOIN "${schema}"."${edgeTable}" e ON kh.node_id = e.source_id
2670
+ WHERE kh.depth < ${k}
2671
+ UNION ALL
2672
+ SELECT e.source_id AS node_id, kh.depth + 1
2673
+ FROM k_hop kh
2674
+ JOIN "${schema}"."${edgeTable}" e ON kh.node_id = e.target_id
2675
+ WHERE kh.depth < ${k}
2676
+ )
2677
+ SELECT DISTINCT node_id FROM k_hop WHERE node_id != '${nodeId}';`.trim();
2678
+ }
2679
+
2680
+ /**
2681
+ * Generate SQL for shortest path query.
2682
+ */
2683
+ shortestPathSQL(source: string, target: string, tableName: string, options: SQLGenerationOptions = {}): string {
2684
+ const schema = options.schema ?? 'public';
2685
+ const edgeTable = options.edgeTable ?? `${tableName}_edges`;
2686
+
2687
+ return `
2688
+ WITH RECURSIVE path AS (
2689
+ SELECT
2690
+ source_id,
2691
+ target_id,
2692
+ ARRAY[source_id, target_id] AS path,
2693
+ weight AS total_weight,
2694
+ 1 AS depth
2695
+ FROM "${schema}"."${edgeTable}"
2696
+ WHERE source_id = '${source}'
2697
+ UNION ALL
2698
+ SELECT
2699
+ p.source_id,
2700
+ e.target_id,
2701
+ p.path || e.target_id,
2702
+ p.total_weight + e.weight,
2703
+ p.depth + 1
2704
+ FROM path p
2705
+ JOIN "${schema}"."${edgeTable}" e ON p.target_id = e.source_id
2706
+ WHERE NOT e.target_id = ANY(p.path)
2707
+ AND p.depth < 10
2708
+ )
2709
+ SELECT path, total_weight
2710
+ FROM path
2711
+ WHERE target_id = '${target}'
2712
+ ORDER BY total_weight
2713
+ LIMIT 1;`.trim();
2714
+ }
2715
+
2716
+ /**
2717
+ * Generate SQL for PageRank computation.
2718
+ */
2719
+ pageRankSQL(tableName: string, options: PageRankOptions & SQLGenerationOptions = {}): string {
2720
+ const schema = options.schema ?? 'public';
2721
+ const edgeTable = options.edgeTable ?? `${tableName}_edges`;
2722
+ const damping = options.damping ?? 0.85;
2723
+ const maxIterations = options.maxIterations ?? 100;
2724
+
2725
+ return `
2726
+ SELECT ruvector.page_rank(
2727
+ (SELECT array_agg(ARRAY[source_id::text, target_id::text]) FROM "${schema}"."${edgeTable}"),
2728
+ ${damping},
2729
+ ${maxIterations}
2730
+ );`.trim();
2731
+ }
2732
+
2733
+ /**
2734
+ * Generate SQL for community detection.
2735
+ */
2736
+ communityDetectionSQL(tableName: string, options: CommunityOptions & SQLGenerationOptions): string {
2737
+ const schema = options.schema ?? 'public';
2738
+ const edgeTable = options.edgeTable ?? `${tableName}_edges`;
2739
+ const algorithm = options.algorithm ?? 'louvain';
2740
+ const resolution = options.resolution ?? 1.0;
2741
+
2742
+ return `
2743
+ SELECT ruvector.community_detection(
2744
+ (SELECT array_agg(ARRAY[source_id::text, target_id::text]) FROM "${schema}"."${edgeTable}"),
2745
+ '${algorithm}',
2746
+ ${resolution}
2747
+ );`.trim();
2748
+ }
2749
+ }
2750
+
2751
+ // ============================================================================
2752
+ // SQL Generator for GNN Operations
2753
+ // ============================================================================
2754
+
2755
+ /**
2756
+ * SQL generator for GNN operations in PostgreSQL with RuVector.
2757
+ */
2758
+ export class GNNSQLGenerator {
2759
  /**
   * Generate SQL for a single GNN layer forward pass.
   *
   * Pure delegation: each layer knows how to render itself via `toSQL`, so
   * this exists only as a uniform entry point on the generator.
   */
  static layerForwardSQL(
    layer: IGNNLayer,
    tableName: string,
    options: SQLGenerationOptions = {}
  ): string {
    return layer.toSQL(tableName, options);
  }
2769
+
2770
+ /**
2771
+ * Generate SQL for batch GNN operations.
2772
+ */
2773
+ static batchGNNSQL(
2774
+ layers: IGNNLayer[],
2775
+ tableName: string,
2776
+ options: SQLGenerationOptions = {}
2777
+ ): string {
2778
+ const schema = options.schema ?? 'public';
2779
+ const nodeColumn = options.nodeColumn ?? 'embedding';
2780
+ const edgeTable = options.edgeTable ?? `${tableName}_edges`;
2781
+
2782
+ const layerConfigs = layers.map((l) => ({
2783
+ type: l.type,
2784
+ input_dim: l.config.inputDim,
2785
+ output_dim: l.config.outputDim,
2786
+ num_heads: l.config.numHeads,
2787
+ dropout: l.config.dropout,
2788
+ aggregation: l.config.aggregation,
2789
+ params: l.config.params,
2790
+ }));
2791
+
2792
+ return `
2793
+ SELECT ruvector.batch_gnn_forward(
2794
+ (SELECT array_agg(${nodeColumn}) FROM "${schema}"."${tableName}"),
2795
+ (SELECT array_agg(ARRAY[source_id, target_id]) FROM "${schema}"."${edgeTable}"),
2796
+ '${JSON.stringify(layerConfigs)}'::jsonb
2797
+ );`.trim();
2798
+ }
2799
+
2800
+ /**
2801
+ * Generate SQL for caching computed embeddings.
2802
+ */
2803
+ static cacheEmbeddingsSQL(
2804
+ tableName: string,
2805
+ cacheTable: string,
2806
+ options: SQLGenerationOptions = {}
2807
+ ): string {
2808
+ const schema = options.schema ?? 'public';
2809
+
2810
+ return `
2811
+ INSERT INTO "${schema}"."${cacheTable}" (node_id, embedding, computed_at)
2812
+ SELECT
2813
+ id,
2814
+ ${options.nodeColumn ?? 'embedding'},
2815
+ NOW()
2816
+ FROM "${schema}"."${tableName}"
2817
+ ON CONFLICT (node_id)
2818
+ DO UPDATE SET
2819
+ embedding = EXCLUDED.embedding,
2820
+ computed_at = NOW();`.trim();
2821
+ }
2822
+
2823
+ /**
2824
+ * Generate SQL for creating GNN cache table.
2825
+ */
2826
+ static createCacheTableSQL(
2827
+ cacheTable: string,
2828
+ dimension: number,
2829
+ options: SQLGenerationOptions = {}
2830
+ ): string {
2831
+ const schema = options.schema ?? 'public';
2832
+
2833
+ return `
2834
+ CREATE TABLE IF NOT EXISTS "${schema}"."${cacheTable}" (
2835
+ node_id TEXT PRIMARY KEY,
2836
+ embedding vector(${dimension}) NOT NULL,
2837
+ computed_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
2838
+ layer_config JSONB,
2839
+ version INTEGER DEFAULT 1
2840
+ );
2841
+
2842
+ CREATE INDEX IF NOT EXISTS "${cacheTable}_computed_at_idx"
2843
+ ON "${schema}"."${cacheTable}" (computed_at);`.trim();
2844
+ }
2845
+
2846
+ /**
2847
+ * Generate SQL for message passing operation.
2848
+ */
2849
+ static messagePassingSQL(
2850
+ tableName: string,
2851
+ aggregation: GNNAggregation,
2852
+ options: SQLGenerationOptions = {}
2853
+ ): string {
2854
+ const schema = options.schema ?? 'public';
2855
+ const nodeColumn = options.nodeColumn ?? 'embedding';
2856
+ const edgeTable = options.edgeTable ?? `${tableName}_edges`;
2857
+
2858
+ const aggFunctionMap: Record<GNNAggregation, string> = {
2859
+ mean: 'avg',
2860
+ sum: 'sum',
2861
+ max: 'max',
2862
+ min: 'min',
2863
+ attention: 'attention_avg',
2864
+ lstm: 'lstm_agg',
2865
+ softmax: 'softmax_avg',
2866
+ power_mean: 'power_mean',
2867
+ std: 'std',
2868
+ var: 'var',
2869
+ };
2870
+ const aggFunction = aggFunctionMap[aggregation] ?? 'avg';
2871
+
2872
+ return `
2873
+ SELECT
2874
+ n.id,
2875
+ ruvector.vector_${aggFunction}(array_agg(neighbor.${nodeColumn})) AS aggregated_embedding
2876
+ FROM "${schema}"."${tableName}" n
2877
+ LEFT JOIN "${schema}"."${edgeTable}" e ON n.id = e.target_id
2878
+ LEFT JOIN "${schema}"."${tableName}" neighbor ON e.source_id = neighbor.id
2879
+ GROUP BY n.id;`.trim();
2880
+ }
2881
+
2882
+ /**
2883
+ * Generate SQL for graph pooling.
2884
+ */
2885
+ static graphPoolingSQL(
2886
+ tableName: string,
2887
+ poolingMethod: 'mean' | 'sum' | 'max' | 'attention',
2888
+ options: SQLGenerationOptions = {}
2889
+ ): string {
2890
+ const schema = options.schema ?? 'public';
2891
+ const nodeColumn = options.nodeColumn ?? 'embedding';
2892
+
2893
+ const poolFunction = {
2894
+ mean: 'vector_avg',
2895
+ sum: 'vector_sum',
2896
+ max: 'vector_max',
2897
+ attention: 'vector_attention_pool',
2898
+ }[poolingMethod] ?? 'vector_avg';
2899
+
2900
+ return `
2901
+ SELECT ruvector.${poolFunction}(
2902
+ (SELECT array_agg(${nodeColumn}) FROM "${schema}"."${tableName}")
2903
+ ) AS graph_embedding;`.trim();
2904
+ }
2905
+ }
2906
+
2907
+ // ============================================================================
2908
+ // Embedding Cache Manager
2909
+ // ============================================================================
2910
+
2911
+ /**
2912
+ * Manager for caching computed GNN embeddings.
2913
+ */
2914
+ export class GNNEmbeddingCache {
2915
+ private cache: Map<string, { embedding: number[]; timestamp: number; version: number }> =
2916
+ new Map();
2917
+ private maxSize: number;
2918
+ private ttlMs: number;
2919
+
2920
+ constructor(maxSize: number = 10000, ttlMs: number = 3600000) {
2921
+ this.maxSize = maxSize;
2922
+ this.ttlMs = ttlMs;
2923
+ }
2924
+
2925
+ /**
2926
+ * Get cached embedding.
2927
+ */
2928
+ get(nodeId: NodeId, version?: number): number[] | undefined {
2929
+ const key = String(nodeId);
2930
+ const entry = this.cache.get(key);
2931
+
2932
+ if (!entry) return undefined;
2933
+
2934
+ // Check TTL
2935
+ if (Date.now() - entry.timestamp > this.ttlMs) {
2936
+ this.cache.delete(key);
2937
+ return undefined;
2938
+ }
2939
+
2940
+ // Check version
2941
+ if (version !== undefined && entry.version !== version) {
2942
+ return undefined;
2943
+ }
2944
+
2945
+ return entry.embedding;
2946
+ }
2947
+
2948
+ /**
2949
+ * Set cached embedding.
2950
+ */
2951
+ set(nodeId: NodeId, embedding: number[], version: number = 1): void {
2952
+ // Evict if at capacity
2953
+ if (this.cache.size >= this.maxSize) {
2954
+ this.evictOldest();
2955
+ }
2956
+
2957
+ this.cache.set(String(nodeId), {
2958
+ embedding,
2959
+ timestamp: Date.now(),
2960
+ version,
2961
+ });
2962
+ }
2963
+
2964
+ /**
2965
+ * Batch get embeddings.
2966
+ */
2967
+ getBatch(nodeIds: NodeId[], version?: number): Map<NodeId, number[]> {
2968
+ const result = new Map<NodeId, number[]>();
2969
+
2970
+ for (const id of nodeIds) {
2971
+ const embedding = this.get(id, version);
2972
+ if (embedding) {
2973
+ result.set(id, embedding);
2974
+ }
2975
+ }
2976
+
2977
+ return result;
2978
+ }
2979
+
2980
+ /**
2981
+ * Batch set embeddings.
2982
+ */
2983
+ setBatch(embeddings: Map<NodeId, number[]>, version: number = 1): void {
2984
+ for (const [id, embedding] of embeddings.entries()) {
2985
+ this.set(id, embedding, version);
2986
+ }
2987
+ }
2988
+
2989
+ /**
2990
+ * Clear cache.
2991
+ */
2992
+ clear(): void {
2993
+ this.cache.clear();
2994
+ }
2995
+
2996
+ /**
2997
+ * Get cache statistics.
2998
+ */
2999
+ getStats(): { size: number; maxSize: number; hitRate: number } {
3000
+ return {
3001
+ size: this.cache.size,
3002
+ maxSize: this.maxSize,
3003
+ hitRate: 0, // Would need to track hits/misses
3004
+ };
3005
+ }
3006
+
3007
+ private evictOldest(): void {
3008
+ let oldestKey: string | null = null;
3009
+ let oldestTimestamp = Infinity;
3010
+
3011
+ for (const [key, entry] of this.cache.entries()) {
3012
+ if (entry.timestamp < oldestTimestamp) {
3013
+ oldestTimestamp = entry.timestamp;
3014
+ oldestKey = key;
3015
+ }
3016
+ }
3017
+
3018
+ if (oldestKey) {
3019
+ this.cache.delete(oldestKey);
3020
+ }
3021
+ }
3022
+ }
3023
+
3024
+ // ============================================================================
3025
+ // Factory and Default Instance
3026
+ // ============================================================================
3027
+
3028
+ /**
3029
+ * Create a default GNN layer registry with all built-in layers.
3030
+ */
3031
+ export function createGNNLayerRegistry(): GNNLayerRegistry {
3032
+ return new GNNLayerRegistry();
3033
+ }
3034
+
3035
+ /**
3036
+ * Create a GNN layer with the default registry.
3037
+ */
3038
+ export function createGNNLayer(type: GNNLayerType, config: Partial<GNNLayerConfig>): IGNNLayer {
3039
+ const registry = createGNNLayerRegistry();
3040
+ return registry.createLayer(type, config);
3041
+ }
3042
+
3043
+ /**
3044
+ * Create graph operations instance.
3045
+ */
3046
+ export function createGraphOperations(): GraphOperations {
3047
+ return new GraphOperations();
3048
+ }
3049
+
3050
+ // GNN_DEFAULTS and GNN_SQL_FUNCTIONS are already exported via export const above