@sparkleideas/plugins 3.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +401 -0
- package/__tests__/collection-manager.test.ts +332 -0
- package/__tests__/dependency-graph.test.ts +434 -0
- package/__tests__/enhanced-plugin-registry.test.ts +488 -0
- package/__tests__/plugin-registry.test.ts +368 -0
- package/__tests__/ruvector-bridge.test.ts +2429 -0
- package/__tests__/ruvector-integration.test.ts +1602 -0
- package/__tests__/ruvector-migrations.test.ts +1099 -0
- package/__tests__/ruvector-quantization.test.ts +846 -0
- package/__tests__/ruvector-streaming.test.ts +1088 -0
- package/__tests__/sdk.test.ts +325 -0
- package/__tests__/security.test.ts +348 -0
- package/__tests__/utils/ruvector-test-utils.ts +860 -0
- package/examples/plugin-creator/index.ts +636 -0
- package/examples/plugin-creator/plugin-creator.test.ts +312 -0
- package/examples/ruvector/README.md +288 -0
- package/examples/ruvector/attention-patterns.ts +394 -0
- package/examples/ruvector/basic-usage.ts +288 -0
- package/examples/ruvector/docker-compose.yml +75 -0
- package/examples/ruvector/gnn-analysis.ts +501 -0
- package/examples/ruvector/hyperbolic-hierarchies.ts +557 -0
- package/examples/ruvector/init-db.sql +119 -0
- package/examples/ruvector/quantization.ts +680 -0
- package/examples/ruvector/self-learning.ts +447 -0
- package/examples/ruvector/semantic-search.ts +576 -0
- package/examples/ruvector/streaming-large-data.ts +507 -0
- package/examples/ruvector/transactions.ts +594 -0
- package/examples/ruvector-plugins/hook-pattern-library.ts +486 -0
- package/examples/ruvector-plugins/index.ts +79 -0
- package/examples/ruvector-plugins/intent-router.ts +354 -0
- package/examples/ruvector-plugins/mcp-tool-optimizer.ts +424 -0
- package/examples/ruvector-plugins/reasoning-bank.ts +657 -0
- package/examples/ruvector-plugins/ruvector-plugins.test.ts +518 -0
- package/examples/ruvector-plugins/semantic-code-search.ts +498 -0
- package/examples/ruvector-plugins/shared/index.ts +20 -0
- package/examples/ruvector-plugins/shared/vector-utils.ts +257 -0
- package/examples/ruvector-plugins/sona-learning.ts +445 -0
- package/package.json +97 -0
- package/src/collections/collection-manager.ts +661 -0
- package/src/collections/index.ts +56 -0
- package/src/collections/official/index.ts +1040 -0
- package/src/core/base-plugin.ts +416 -0
- package/src/core/plugin-interface.ts +215 -0
- package/src/hooks/index.ts +685 -0
- package/src/index.ts +378 -0
- package/src/integrations/agentic-flow.ts +743 -0
- package/src/integrations/index.ts +88 -0
- package/src/integrations/ruvector/ARCHITECTURE.md +1245 -0
- package/src/integrations/ruvector/attention-advanced.ts +1040 -0
- package/src/integrations/ruvector/attention-executor.ts +782 -0
- package/src/integrations/ruvector/attention-mechanisms.ts +757 -0
- package/src/integrations/ruvector/attention.ts +1063 -0
- package/src/integrations/ruvector/gnn.ts +3050 -0
- package/src/integrations/ruvector/hyperbolic.ts +1948 -0
- package/src/integrations/ruvector/index.ts +394 -0
- package/src/integrations/ruvector/migrations/001_create_extension.sql +135 -0
- package/src/integrations/ruvector/migrations/002_create_vector_tables.sql +259 -0
- package/src/integrations/ruvector/migrations/003_create_indices.sql +328 -0
- package/src/integrations/ruvector/migrations/004_create_functions.sql +598 -0
- package/src/integrations/ruvector/migrations/005_create_attention_functions.sql +654 -0
- package/src/integrations/ruvector/migrations/006_create_gnn_functions.sql +728 -0
- package/src/integrations/ruvector/migrations/007_create_hyperbolic_functions.sql +762 -0
- package/src/integrations/ruvector/migrations/index.ts +35 -0
- package/src/integrations/ruvector/migrations/migrations.ts +647 -0
- package/src/integrations/ruvector/quantization.ts +2036 -0
- package/src/integrations/ruvector/ruvector-bridge.ts +2000 -0
- package/src/integrations/ruvector/self-learning.ts +2376 -0
- package/src/integrations/ruvector/streaming.ts +1737 -0
- package/src/integrations/ruvector/types.ts +1945 -0
- package/src/providers/index.ts +643 -0
- package/src/registry/dependency-graph.ts +568 -0
- package/src/registry/enhanced-plugin-registry.ts +994 -0
- package/src/registry/plugin-registry.ts +604 -0
- package/src/sdk/index.ts +563 -0
- package/src/security/index.ts +594 -0
- package/src/types/index.ts +446 -0
- package/src/workers/index.ts +700 -0
- package/tmp.json +0 -0
- package/tsconfig.json +25 -0
- package/vitest.config.ts +23 -0
|
@@ -0,0 +1,3050 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* RuVector PostgreSQL Bridge - Graph Neural Network (GNN) Layers Module
|
|
3
|
+
*
|
|
4
|
+
* Comprehensive GNN support for RuVector PostgreSQL vector database integration.
|
|
5
|
+
* Implements GCN, GAT, GraphSAGE, GIN, MPNN, EdgeConv, and more.
|
|
6
|
+
*
|
|
7
|
+
* @module @sparkleideas/plugins/integrations/ruvector/gnn
|
|
8
|
+
* @version 1.0.0
|
|
9
|
+
*/
|
|
10
|
+
|
|
11
|
+
import type {
|
|
12
|
+
GNNLayerType,
|
|
13
|
+
GNNLayer,
|
|
14
|
+
GraphData,
|
|
15
|
+
GNNOutput,
|
|
16
|
+
GNNAggregation,
|
|
17
|
+
GNNStats,
|
|
18
|
+
ActivationFunction,
|
|
19
|
+
} from './types.js';
|
|
20
|
+
|
|
21
|
+
// ============================================================================
|
|
22
|
+
// Constants and Configuration
|
|
23
|
+
// ============================================================================
|
|
24
|
+
|
|
25
|
+
/**
 * Baseline hyper-parameter values shared by the GNN layers in this module.
 */
export const GNN_DEFAULTS = {
  dropout: 0,
  addSelfLoops: true,
  normalize: true,
  useBias: true,
  activation: 'relu' as ActivationFunction,
  aggregation: 'mean' as GNNAggregation,
  numHeads: 1,
  /** Negative slope of the LeakyReLU used by GAT. */
  negativeSlope: 0.2,
  /** Epsilon term used by GIN. */
  eps: 0,
  /** Neighbour sample size used by GraphSAGE. */
  sampleSize: 10,
  /** Number of nearest neighbours used by EdgeConv k-NN. */
  k: 20,
} as const;
|
|
41
|
+
|
|
42
|
+
/**
 * Lookup table from GNN layer type to the schema-qualified SQL function
 * that implements it.
 */
export const GNN_SQL_FUNCTIONS = {
  // Convolution / attention layers
  gcn: 'ruvector.gcn_layer',
  gat: 'ruvector.gat_layer',
  gat_v2: 'ruvector.gat_v2_layer',
  sage: 'ruvector.sage_layer',
  gin: 'ruvector.gin_layer',
  mpnn: 'ruvector.mpnn_layer',
  edge_conv: 'ruvector.edge_conv_layer',
  point_conv: 'ruvector.point_conv_layer',
  transformer: 'ruvector.graph_transformer_layer',
  pna: 'ruvector.pna_layer',
  film: 'ruvector.film_layer',
  // Relational / heterogeneous layers
  rgcn: 'ruvector.rgcn_layer',
  hgt: 'ruvector.hgt_layer',
  han: 'ruvector.han_layer',
  metapath: 'ruvector.metapath_layer',
} as const;
|
|
62
|
+
|
|
63
|
+
// ============================================================================
|
|
64
|
+
// Core Interfaces
|
|
65
|
+
// ============================================================================
|
|
66
|
+
|
|
67
|
+
/**
 * Identifier of a graph node; either a numeric or a string key.
 */
export type NodeId = number | string;
|
|
71
|
+
|
|
72
|
+
/**
 * Feature data for a set of graph nodes.
 */
export interface NodeFeatures {
  /** Identifiers of the nodes. */
  readonly ids: NodeId[];

  /** One feature vector per node, laid out as [num_nodes, feature_dim]. */
  readonly features: number[][];

  /** Node types; only supplied for heterogeneous graphs. */
  readonly types?: string[];

  /** Node labels, when available. */
  readonly labels?: number[];
}
|
|
85
|
+
|
|
86
|
+
/**
 * Feature data for a set of graph edges.
 */
export interface EdgeFeatures {
  /** Identifiers of the source nodes. */
  readonly sources: NodeId[];

  /** Identifiers of the target nodes. */
  readonly targets: NodeId[];

  /** Optional per-edge feature vectors, laid out as [num_edges, edge_dim]. */
  readonly features?: number[][];

  /** Optional per-edge weights. */
  readonly weights?: number[];

  /** Optional edge types, for heterogeneous graphs. */
  readonly types?: string[];
}
|
|
101
|
+
|
|
102
|
+
/**
 * A single message exchanged between two nodes during message passing.
 */
export interface Message {
  /** ID of the node the message comes from. */
  readonly source: NodeId;

  /** ID of the node the message is addressed to. */
  readonly target: NodeId;

  /** The message payload vector. */
  readonly vector: number[];

  /** Features of the edge carrying the message, if applicable. */
  readonly edgeFeatures?: number[];

  /** Weight attached to the message. */
  readonly weight?: number;
}
|
|
117
|
+
|
|
118
|
+
/**
 * Aggregation methods accepted by this module: every `GNNAggregation` value
 * plus the extra options `concat`, `weighted_mean` and `multi_head`.
 */
export type AggregationMethod = GNNAggregation | 'concat' | 'weighted_mean' | 'multi_head';
|
|
126
|
+
|
|
127
|
+
/**
 * A path produced by graph traversal.
 */
export interface Path {
  /** Node IDs in the order they are visited. */
  readonly nodes: NodeId[];

  /** Total weight/distance of the path. */
  readonly weight: number;

  /** Types of the edges along the path (heterogeneous graphs only). */
  readonly edgeTypes?: string[];
}
|
|
138
|
+
|
|
139
|
+
/**
 * One community found by community detection.
 */
export interface Community {
  /** Identifier of the community. */
  readonly id: number;

  /** IDs of the nodes that belong to the community. */
  readonly members: NodeId[];

  /** Centroid of the community (average of member features). */
  readonly centroid?: number[];

  /** Modularity score. */
  readonly modularity?: number;

  /** Density of edges internal to the community. */
  readonly density?: number;
}
|
|
154
|
+
|
|
155
|
+
/**
 * Options controlling a PageRank computation.
 */
export interface PageRankOptions {
  /** Damping factor; defaults to 0.85. */
  readonly damping?: number;

  /** Upper bound on iterations; defaults to 100. */
  readonly maxIterations?: number;

  /** Convergence tolerance; defaults to 1e-6. */
  readonly tolerance?: number;

  /** Personalization vector: teleport probability per node. */
  readonly personalization?: Map<NodeId, number>;

  /** Whether edge weights are taken into account. */
  readonly weighted?: boolean;
}
|
|
170
|
+
|
|
171
|
+
/**
 * Options controlling community detection.
 */
export interface CommunityOptions {
  /** Algorithm used to detect communities. */
  readonly algorithm: 'louvain' | 'label_propagation' | 'girvan_newman' | 'spectral';

  /** Resolution parameter (used by Louvain). */
  readonly resolution?: number;

  /** Upper bound on iterations. */
  readonly maxIterations?: number;

  /** Smallest community size to keep. */
  readonly minSize?: number;

  /** Seed for the random number generator, for reproducible results. */
  readonly seed?: number;
}
|
|
186
|
+
|
|
187
|
+
/**
 * Configuration of a GNN layer: everything in `GNNLayer` plus a few
 * module-specific extras.
 */
export interface GNNLayerConfig extends GNNLayer {
  /** Name/identifier of the layer. */
  readonly name?: string;

  /** Whether intermediate results should be cached. */
  readonly cache?: boolean;

  /** Bit width used when quantizing for memory efficiency. */
  readonly quantizeBits?: 8 | 16 | 32;
}
|
|
198
|
+
|
|
199
|
+
/**
 * Signature of a function that builds a GNN layer from its configuration.
 */
export type GNNLayerFactory = (config: GNNLayerConfig) => IGNNLayer;
|
|
203
|
+
|
|
204
|
+
/**
 * Contract every GNN layer implementation fulfils.
 */
export interface IGNNLayer {
  /** Type of this layer. */
  readonly type: GNNLayerType;

  /** Configuration this layer was built with. */
  readonly config: GNNLayerConfig;

  /**
   * Run the layer's forward pass.
   *
   * @param graph - Graph data fed into the layer
   * @returns Promise that resolves to the layer output
   */
  forward(graph: GraphData): Promise<GNNOutput>;

  /**
   * Perform one message-passing step.
   *
   * @param nodes - Features of the nodes
   * @param edges - Features of the edges
   * @returns Promise that resolves to the updated node features
   */
  messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures>;

  /**
   * Combine a list of messages into a single vector.
   *
   * @param messages - Messages to combine
   * @param method - How the messages are combined
   * @returns Promise that resolves to the aggregated vector
   */
  aggregate(messages: Message[], method: AggregationMethod): Promise<number[]>;

  /**
   * Clear any state held by the layer (relevant for stateful layers).
   */
  reset(): void;

  /**
   * Produce the SQL for this layer.
   *
   * @param tableName - Name of the target table
   * @param options - Options for SQL generation
   * @returns The SQL text
   */
  toSQL(tableName: string, options?: SQLGenerationOptions): string;
}
|
|
249
|
+
|
|
250
|
+
/**
 * Options accepted when generating SQL for a layer.
 */
export interface SQLGenerationOptions {
  /** Name of the schema. */
  readonly schema?: string;

  /** Column that holds the node features. */
  readonly nodeColumn?: string;

  /** Name of the edge table. */
  readonly edgeTable?: string;

  /** Whether a prepared statement should be generated. */
  readonly prepared?: boolean;

  /** Prefix of the parameter placeholder in prepared statements. */
  readonly paramPrefix?: string;
}
|
|
265
|
+
|
|
266
|
+
// ============================================================================
|
|
267
|
+
// GNN Layer Registry
|
|
268
|
+
// ============================================================================
|
|
269
|
+
|
|
270
|
+
/**
 * Registry that maps GNN layer type identifiers to factory functions and
 * optional per-type default configuration.
 *
 * All built-in layer types are registered when the registry is constructed.
 *
 * @example
 * ```typescript
 * const registry = new GNNLayerRegistry();
 * registry.registerLayer('custom_gnn', CustomGNNFactory);
 * const layer = registry.createLayer('gcn', { inputDim: 64, outputDim: 32 });
 * ```
 */
export class GNNLayerRegistry {
  private readonly factories: Map<GNNLayerType | string, GNNLayerFactory> = new Map();
  private readonly defaultConfigs: Map<GNNLayerType | string, Partial<GNNLayerConfig>> = new Map();

  constructor() {
    // Register built-in layer factories
    this.registerBuiltinLayers();
  }

  /**
   * Register a GNN layer factory.
   *
   * An existing registration for the same type is replaced. Previously
   * registered defaults are kept when `defaultConfig` is omitted.
   *
   * @param type - Layer type identifier
   * @param factory - Factory function
   * @param defaultConfig - Optional default configuration
   */
  registerLayer(
    type: GNNLayerType | string,
    factory: GNNLayerFactory,
    defaultConfig?: Partial<GNNLayerConfig>
  ): void {
    this.factories.set(type, factory);
    if (defaultConfig) {
      this.defaultConfigs.set(type, defaultConfig);
    }
  }

  /**
   * Unregister a GNN layer factory together with its default configuration.
   *
   * @param type - Layer type to remove
   * @returns Whether a factory was registered for `type` and has been removed
   */
  unregisterLayer(type: GNNLayerType | string): boolean {
    this.defaultConfigs.delete(type);
    return this.factories.delete(type);
  }

  /**
   * Create a GNN layer instance.
   *
   * The final configuration is resolved with this precedence (highest first):
   * 1. properties of `config` that are not `undefined`,
   * 2. the defaults registered for `type` (including `numHeads` and `params`),
   * 3. `GNN_DEFAULTS`, and 64 for `inputDim` / `outputDim`.
   *
   * `params` is replaced as a whole, not merged key by key.
   *
   * @param type - Layer type
   * @param config - Layer configuration overrides
   * @returns IGNNLayer instance
   * @throws Error if layer type is not registered
   */
  createLayer(type: GNNLayerType, config: Partial<GNNLayerConfig>): IGNNLayer {
    const factory = this.factories.get(type);
    if (!factory) {
      throw new Error(`Unknown GNN layer type: ${type}. Available types: ${this.getLayerTypes().join(', ')}`);
    }

    const registeredDefaults = this.defaultConfigs.get(type) ?? {};
    const fullConfig: GNNLayerConfig = {
      type,
      inputDim: 64,
      outputDim: 64,
      dropout: GNN_DEFAULTS.dropout,
      aggregation: GNN_DEFAULTS.aggregation,
      addSelfLoops: GNN_DEFAULTS.addSelfLoops,
      normalize: GNN_DEFAULTS.normalize,
      useBias: GNN_DEFAULTS.useBias,
      activation: GNN_DEFAULTS.activation,
      // Undefined-valued keys are stripped first so that an explicit
      // `undefined` can never clobber a lower-precedence default.
      ...GNNLayerRegistry.withoutUndefined(registeredDefaults),
      ...GNNLayerRegistry.withoutUndefined(config),
    };

    return factory(fullConfig);
  }

  /**
   * Check if a layer type is registered.
   *
   * @param type - Layer type to check
   * @returns Whether the layer is registered
   */
  hasLayer(type: GNNLayerType | string): boolean {
    return this.factories.has(type);
  }

  /**
   * Get all registered layer types.
   *
   * @returns Array of layer type identifiers, in registration order
   */
  getLayerTypes(): string[] {
    return Array.from(this.factories.keys());
  }

  /**
   * Get default configuration for a layer type.
   *
   * @param type - Layer type
   * @returns Default configuration or undefined
   */
  getDefaultConfig(type: GNNLayerType | string): Partial<GNNLayerConfig> | undefined {
    return this.defaultConfigs.get(type);
  }

  /**
   * Return a shallow copy of `partial` without the keys whose value is
   * `undefined`.
   */
  private static withoutUndefined(partial: Partial<GNNLayerConfig>): Partial<GNNLayerConfig> {
    const defined: Record<string, unknown> = {};
    for (const [key, value] of Object.entries(partial)) {
      if (value !== undefined) {
        defined[key] = value;
      }
    }
    return defined as Partial<GNNLayerConfig>;
  }

  /**
   * Register all built-in GNN layer factories.
   */
  private registerBuiltinLayers(): void {
    // GCN - Graph Convolutional Network
    this.registerLayer('gcn', (config) => new GCNLayer(config), {
      normalize: true,
      addSelfLoops: true,
    });

    // GAT - Graph Attention Network
    this.registerLayer('gat', (config) => new GATLayer(config), {
      numHeads: 8,
      params: { negativeSlope: 0.2, concat: true },
    });

    // GAT v2 - Improved Graph Attention
    this.registerLayer('gat_v2', (config) => new GATv2Layer(config), {
      numHeads: 8,
      params: { negativeSlope: 0.2, concat: true },
    });

    // GraphSAGE - Sampling and Aggregation
    this.registerLayer('sage', (config) => new GraphSAGELayer(config), {
      aggregation: 'mean',
      params: { sampleSize: 10, samplingStrategy: 'uniform' },
    });

    // GIN - Graph Isomorphism Network
    this.registerLayer('gin', (config) => new GINLayer(config), {
      params: { eps: 0, trainEps: false },
    });

    // MPNN - Message Passing Neural Network
    this.registerLayer('mpnn', (config) => new MPNNLayer(config), {
      aggregation: 'sum',
    });

    // EdgeConv - Dynamic Edge Convolution
    this.registerLayer('edge_conv', (config) => new EdgeConvLayer(config), {
      params: { k: 20, dynamic: true },
    });

    // Point Convolution
    this.registerLayer('point_conv', (config) => new PointConvLayer(config), {
      params: { k: 16 },
    });

    // Graph Transformer
    this.registerLayer('transformer', (config) => new GraphTransformerLayer(config), {
      numHeads: 8,
      params: { numLayers: 1 },
    });

    // PNA - Principal Neighbourhood Aggregation
    this.registerLayer('pna', (config) => new PNALayer(config), {
      params: {
        aggregators: ['mean', 'sum', 'max', 'min'],
        scalers: ['identity', 'amplification', 'attenuation'],
      },
    });

    // FiLM - Feature-wise Linear Modulation
    this.registerLayer('film', (config) => new FiLMLayer(config), {});

    // RGCN - Relational Graph Convolutional Network
    this.registerLayer('rgcn', (config) => new RGCNLayer(config), {
      params: { numRelations: 1 },
    });

    // HGT - Heterogeneous Graph Transformer
    this.registerLayer('hgt', (config) => new HGTLayer(config), {
      numHeads: 8,
    });

    // HAN - Heterogeneous Attention Network
    this.registerLayer('han', (config) => new HANLayer(config), {
      numHeads: 8,
    });

    // MetaPath aggregation
    this.registerLayer('metapath', (config) => new MetaPathLayer(config), {
      params: { metapaths: [] },
    });
  }
}
|
|
458
|
+
|
|
459
|
+
// ============================================================================
|
|
460
|
+
// Base GNN Layer Implementation
|
|
461
|
+
// ============================================================================
|
|
462
|
+
|
|
463
|
+
/**
|
|
464
|
+
* Abstract base class for GNN layer implementations.
|
|
465
|
+
*/
|
|
466
|
+
export abstract class BaseGNNLayer implements IGNNLayer {
|
|
467
|
+
readonly type: GNNLayerType;
|
|
468
|
+
readonly config: GNNLayerConfig;
|
|
469
|
+
|
|
470
|
+
/**
 * Store the configuration and its layer type, then validate it.
 *
 * @param config - Layer configuration
 */
constructor(config: GNNLayerConfig) {
  // The config must be stored before validateConfig() runs, since that
  // method reads it from `this.config`.
  this.config = config;
  this.type = config.type;
  this.validateConfig();
}
|
|
475
|
+
|
|
476
|
+
/**
 * Validate layer configuration.
 *
 * `inputDim` and `outputDim` must be finite numbers greater than zero.
 * `dropout`, when set, must lie in [0, 1]; `numHeads`, when set, must be a
 * finite number greater than zero. Each check is written so that `NaN` (and
 * a missing dimension) is rejected rather than slipping through a
 * comparison that evaluates to `false`.
 *
 * @throws Error if configuration is invalid
 */
protected validateConfig(): void {
  const { inputDim, outputDim, dropout, numHeads } = this.config;
  const isPositiveFinite = (value: number): boolean => Number.isFinite(value) && value > 0;

  if (!isPositiveFinite(inputDim)) {
    throw new Error(`Invalid inputDim: ${inputDim}. Must be positive.`);
  }
  if (!isPositiveFinite(outputDim)) {
    throw new Error(`Invalid outputDim: ${outputDim}. Must be positive.`);
  }
  // Negated range check: NaN fails both comparisons and is therefore rejected.
  if (dropout !== undefined && !(dropout >= 0 && dropout <= 1)) {
    throw new Error(`Invalid dropout: ${dropout}. Must be between 0 and 1.`);
  }
  if (numHeads !== undefined && !isPositiveFinite(numHeads)) {
    throw new Error(`Invalid numHeads: ${numHeads}. Must be positive.`);
  }
}
|
|
494
|
+
|
|
495
|
+
abstract forward(graph: GraphData): Promise<GNNOutput>;
|
|
496
|
+
abstract messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures>;
|
|
497
|
+
|
|
498
|
+
/**
 * Combine the given messages into a single vector.
 *
 * With no messages the result is a zero vector of length `outputDim`.
 * A method that is not recognised falls back to the mean.
 *
 * @param messages - Messages to combine
 * @param method - How the messages are combined
 * @returns Promise resolving to the aggregated vector
 */
async aggregate(messages: Message[], method: AggregationMethod): Promise<number[]> {
  if (messages.length === 0) {
    return new Array(this.config.outputDim).fill(0);
  }

  const payloads: number[][] = [];
  const messageWeights: number[] = [];
  for (const message of messages) {
    payloads.push(message.vector);
    // Messages without an explicit weight count with weight 1.
    messageWeights.push(message.weight ?? 1);
  }

  if (method === 'sum') return this.aggregateSum(payloads);
  if (method === 'mean') return this.aggregateMean(payloads);
  if (method === 'max') return this.aggregateMax(payloads);
  if (method === 'min') return this.aggregateMin(payloads);
  if (method === 'attention') return this.aggregateAttention(payloads, messageWeights);
  if (method === 'weighted_mean') return this.aggregateWeightedMean(payloads, messageWeights);
  if (method === 'softmax') return this.aggregateSoftmax(payloads);
  if (method === 'power_mean') return this.aggregatePowerMean(payloads, 2);
  if (method === 'std') return this.aggregateStd(payloads);
  if (method === 'var') return this.aggregateVar(payloads);
  if (method === 'concat') return this.aggregateConcat(payloads);
  if (method === 'lstm') return this.aggregateLSTM(payloads);
  if (method === 'multi_head') return this.aggregateMultiHead(payloads);

  // Unknown method: fall back to the mean.
  return this.aggregateMean(payloads);
}
|
|
540
|
+
|
|
541
|
+
/**
 * Reset layer state. The base implementation does nothing; stateful layers
 * override it.
 */
reset(): void {
  // Intentionally empty: override in stateful layers
}
|
|
547
|
+
|
|
548
|
+
/**
 * Generate the SQL statement that invokes this layer's SQL function.
 *
 * The statement passes three arguments to the function: the aggregated node
 * column of `tableName`, the aggregated `[source_id, target_id]` pairs of
 * the edge table, and the layer configuration as `jsonb`.
 *
 * Schema and table names are emitted as quoted identifiers with embedded
 * double quotes doubled. In non-prepared mode the configuration JSON is
 * emitted as a string literal with embedded single quotes doubled, so a
 * quote inside a config value cannot terminate the literal.
 *
 * @param tableName - Node table name
 * @param options - SQL generation options. Defaults: schema `public`,
 *   node column `embedding`, edge table `${tableName}_edges`. With
 *   `prepared`, the configuration is replaced by the placeholder
 *   `${paramPrefix}1` (prefix defaults to `$`).
 * @returns SQL string
 */
toSQL(tableName: string, options: SQLGenerationOptions = {}): string {
  const escapeIdentifier = (identifier: string): string => identifier.replace(/"/g, '""');
  const escapeLiteral = (literal: string): string => literal.replace(/'/g, "''");

  const schema = escapeIdentifier(options.schema ?? 'public');
  const nodeTable = escapeIdentifier(tableName);
  const edgeTable = escapeIdentifier(options.edgeTable ?? `${tableName}_edges`);
  // NOTE(review): nodeColumn is interpolated verbatim and unquoted, exactly
  // as before, so it must come from a trusted source.
  const nodeColumn = options.nodeColumn ?? 'embedding';
  const sqlFunction = GNN_SQL_FUNCTIONS[this.type] ?? 'ruvector.gnn_layer';

  let configArgument: string;
  if (options.prepared) {
    // The caller binds the configuration JSON to the first parameter.
    configArgument = `${options.paramPrefix ?? '$'}1::jsonb`;
  } else {
    const configJson = JSON.stringify({
      type: this.type,
      input_dim: this.config.inputDim,
      output_dim: this.config.outputDim,
      num_heads: this.config.numHeads,
      dropout: this.config.dropout,
      aggregation: this.config.aggregation,
      add_self_loops: this.config.addSelfLoops,
      normalize: this.config.normalize,
      use_bias: this.config.useBias,
      activation: this.config.activation,
      params: this.config.params,
    });
    configArgument = `'${escapeLiteral(configJson)}'::jsonb`;
  }

  return [
    `SELECT ${sqlFunction}(`,
    `(SELECT array_agg(${nodeColumn}) FROM "${schema}"."${nodeTable}"),`,
    `(SELECT array_agg(ARRAY[source_id, target_id]) FROM "${schema}"."${edgeTable}"),`,
    configArgument,
    ');',
  ].join('\n');
}
|
|
588
|
+
|
|
589
|
+
// Aggregation implementations
|
|
590
|
+
/**
 * Component-wise sum of the vectors.
 *
 * The output length is the length of the first vector; components missing
 * from shorter vectors count as 0.
 */
protected aggregateSum(vectors: number[][]): number[] {
  const width = vectors[0]?.length ?? 0;
  const totals: number[] = Array.from({ length: width }, () => 0);

  for (const row of vectors) {
    let col = 0;
    while (col < width) {
      totals[col] += row[col] ?? 0;
      col += 1;
    }
  }

  return totals;
}
|
|
600
|
+
|
|
601
|
+
/**
 * Component-wise arithmetic mean: the component-wise sum divided by the
 * number of vectors.
 */
protected aggregateMean(vectors: number[][]): number[] {
  const count = vectors.length;
  return this.aggregateSum(vectors).map((total) => total / count);
}
|
|
605
|
+
|
|
606
|
+
/**
 * Component-wise maximum of the vectors.
 *
 * The output length is the length of the first vector; components missing
 * from shorter vectors are treated as -Infinity.
 */
protected aggregateMax(vectors: number[][]): number[] {
  const width = vectors[0]?.length ?? 0;
  const peaks: number[] = Array.from({ length: width }, () => -Infinity);

  for (const row of vectors) {
    for (let col = 0; col < width; col += 1) {
      const candidate = row[col] ?? -Infinity;
      peaks[col] = Math.max(peaks[col], candidate);
    }
  }

  return peaks;
}
|
|
616
|
+
|
|
617
|
+
/**
 * Component-wise minimum of the vectors.
 *
 * The output length is the length of the first vector; components missing
 * from shorter vectors are treated as Infinity.
 */
protected aggregateMin(vectors: number[][]): number[] {
  const width = vectors[0]?.length ?? 0;
  const floors: number[] = Array.from({ length: width }, () => Infinity);

  for (const row of vectors) {
    for (let col = 0; col < width; col += 1) {
      const candidate = row[col] ?? Infinity;
      floors[col] = Math.min(floors[col], candidate);
    }
  }

  return floors;
}
|
|
627
|
+
|
|
628
|
+
/**
 * Component-wise weighted mean of the vectors.
 *
 * A vector without a matching weight gets weight 1. When the weights do not
 * sum to a value greater than zero, every component of the result is 0.
 */
protected aggregateWeightedMean(vectors: number[][], weights: number[]): number[] {
  const width = vectors[0]?.length ?? 0;
  const weightedTotals: number[] = Array.from({ length: width }, () => 0);
  let weightSum = 0;

  for (const [index, row] of vectors.entries()) {
    const weight = weights[index] ?? 1;
    weightSum += weight;
    for (let col = 0; col < width; col += 1) {
      weightedTotals[col] += (row?.[col] ?? 0) * weight;
    }
  }

  if (weightSum > 0) {
    return weightedTotals.map((total) => total / weightSum);
  }
  return weightedTotals.map(() => 0);
}
|
|
643
|
+
|
|
644
|
+
/**
 * Attention-style aggregation: apply a softmax to the weights, then take
 * the weighted mean of the vectors with the resulting coefficients.
 *
 * @param vectors - Vectors to aggregate
 * @param weights - Raw (pre-softmax) score per vector
 * @returns The aggregated vector
 */
protected aggregateAttention(vectors: number[][], weights: number[]): number[] {
  // Fold instead of Math.max(...weights): spreading a very large array into
  // a call can exceed the engine's argument limit and throw a RangeError.
  const maxWeight = weights.reduce((best, w) => Math.max(best, w), -Infinity);
  // Subtracting the maximum before exponentiating avoids overflow in exp().
  const expWeights = weights.map((w) => Math.exp(w - maxWeight));
  const sumExp = expWeights.reduce((a, b) => a + b, 0);
  const attentionWeights = expWeights.map((w) => w / sumExp);
  return this.aggregateWeightedMean(vectors, attentionWeights);
}
|
|
652
|
+
|
|
653
|
+
/**
 * Softmax aggregation across vectors.
 *
 * For each component, the values of all vectors are weighted by the softmax
 * of those same values and summed. The output length is the length of the
 * first vector; components missing from shorter vectors count as 0.
 *
 * @param vectors - Vectors to aggregate
 * @returns The aggregated vector
 */
protected aggregateSoftmax(vectors: number[][]): number[] {
  const dim = vectors[0]?.length ?? 0;
  const result: number[] = new Array(dim).fill(0);

  for (let i = 0; i < dim; i++) {
    const values = vectors.map((v) => v[i] ?? 0);
    // Fold instead of Math.max(...values): spreading a very large array into
    // a call can exceed the engine's argument limit and throw a RangeError.
    const maxVal = values.reduce((best, v) => Math.max(best, v), -Infinity);
    // Subtracting the maximum before exponentiating avoids overflow in exp().
    const expValues = values.map((v) => Math.exp(v - maxVal));
    const sumExp = expValues.reduce((a, b) => a + b, 0);
    result[i] = expValues.reduce((sum, exp, j) => sum + (exp / sumExp) * (values[j] ?? 0), 0);
  }

  return result;
}
|
|
668
|
+
|
|
669
|
+
/**
 * Component-wise power mean with exponent `p`: the p-th root of the mean of
 * the absolute values raised to the p-th power.
 *
 * The output length is the length of the first vector; components missing
 * from shorter vectors count as 0.
 */
protected aggregatePowerMean(vectors: number[][], p: number): number[] {
  const width = vectors[0]?.length ?? 0;
  const count = vectors.length;
  const means: number[] = Array.from({ length: width }, () => 0);

  for (let col = 0; col < width; col += 1) {
    let poweredTotal = 0;
    for (const row of vectors) {
      const magnitude = Math.abs(row[col] ?? 0);
      poweredTotal += Math.pow(magnitude, p);
    }
    means[col] = Math.pow(poweredTotal / count, 1 / p);
  }

  return means;
}
|
|
683
|
+
|
|
684
|
+
/**
 * Component-wise population standard deviation (divides by the number of
 * vectors, not by n - 1). Components missing from a vector count as 0.
 */
protected aggregateStd(vectors: number[][]): number[] {
  const centre = this.aggregateMean(vectors);
  const width = centre.length;
  const count = vectors.length;
  const squaredDeviations: number[] = Array.from({ length: width }, () => 0);

  for (const row of vectors) {
    for (let col = 0; col < width; col += 1) {
      const deviation = (row[col] ?? 0) - centre[col];
      squaredDeviations[col] += Math.pow(deviation, 2);
    }
  }

  return squaredDeviations.map((total) => Math.sqrt(total / count));
}
|
|
697
|
+
|
|
698
|
+
/**
 * Component-wise population variance (divides by the number of vectors, not
 * by n - 1). Components missing from a vector count as 0.
 */
protected aggregateVar(vectors: number[][]): number[] {
  const centre = this.aggregateMean(vectors);
  const width = centre.length;
  const count = vectors.length;
  const squaredDeviations: number[] = Array.from({ length: width }, () => 0);

  for (const row of vectors) {
    for (let col = 0; col < width; col += 1) {
      const deviation = (row[col] ?? 0) - centre[col];
      squaredDeviations[col] += Math.pow(deviation, 2);
    }
  }

  return squaredDeviations.map((total) => total / count);
}
|
|
711
|
+
|
|
712
|
+
/**
 * Concatenates all vectors, in order, into a single flat vector.
 *
 * @param vectors - Vectors to join.
 * @returns The entries of every vector laid end to end.
 */
protected aggregateConcat(vectors: number[][]): number[] {
  return vectors.flatMap((vec) => vec);
}
|
|
715
|
+
|
|
716
|
+
/**
 * Sequential, LSTM-style aggregation.
 *
 * Starts from a zero state of length `config.outputDim` and folds every
 * vector into it, in order, through `lstmCell`.
 *
 * @param vectors - Vectors to process sequentially.
 * @returns The state after the last vector (all zeros when `vectors` is empty).
 */
protected aggregateLSTM(vectors: number[][]): number[] {
  const initialState: number[] = new Array(this.config.outputDim).fill(0);
  return vectors.reduce((state, step) => this.lstmCell(step, state), initialState);
}
|
|
724
|
+
|
|
725
|
+
/**
 * Multi-head aggregation: slices every vector into `config.numHeads`
 * contiguous chunks (default 1 head), mean-aggregates each chunk across
 * vectors and concatenates the per-head results.
 *
 * The chunk width is `floor(width / numHeads)`, so trailing entries that do
 * not fill a whole chunk are not part of the output.
 *
 * @param vectors - Vectors to aggregate; the width is taken from the first one.
 * @returns Per-head means joined in head order.
 */
protected aggregateMultiHead(vectors: number[][]): number[] {
  const headCount = this.config.numHeads ?? 1;
  const width = vectors[0]?.length ?? 0;
  const sliceWidth = Math.floor(width / headCount);
  const pooled: number[] = [];

  for (let head = 0; head < headCount; head++) {
    const from = head * sliceWidth;
    const to = from + sliceWidth;
    const perHead = vectors.map((vec) => vec.slice(from, to));
    pooled.push(...this.aggregateMean(perHead));
  }

  return pooled;
}
|
|
740
|
+
|
|
741
|
+
/**
 * One step of the simplified, parameter-free gated update used by
 * `aggregateLSTM`.
 *
 * For every slot of `hidden`, a sigmoid gate of (input + hidden) blends the
 * input value with the previous hidden value. The input is read cyclically
 * (`slot % input.length`), so it may be shorter than the hidden state.
 *
 * @param input - Current input vector.
 * @param hidden - Previous hidden state; its length fixes the output length.
 * @returns The new hidden state.
 */
private lstmCell(input: number[], hidden: number[]): number[] {
  const period = input.length;

  return Array.from({ length: hidden.length }, (_, slot) => {
    // An empty input makes the index NaN, which falls back to 0 here.
    const fresh = input[slot % period] ?? 0;
    const carried = hidden[slot] ?? 0;
    const mix = 1 / (1 + Math.exp(-(fresh + carried)));
    return mix * fresh + (1 - mix) * carried;
  });
}
|
|
757
|
+
|
|
758
|
+
/**
 * Apply the activation function selected by `config.activation` to a scalar.
 *
 * `'softmax'`, `'none'` and any unrecognised value return `x` unchanged;
 * softmax is not computed by this element-wise helper.
 *
 * @param x - Pre-activation value.
 * @returns The activated value.
 */
protected applyActivation(x: number): number {
  switch (this.config.activation) {
    case 'relu':
      return Math.max(0, x);
    case 'gelu':
      // tanh-based approximation of GELU
      return 0.5 * x * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (x + 0.044715 * Math.pow(x, 3))));
    case 'silu':
    case 'swish':
      return x / (1 + Math.exp(-x));
    case 'leaky_relu':
      return x >= 0 ? x : 0.01 * x;
    case 'elu':
      return x >= 0 ? x : Math.exp(x) - 1;
    case 'selu': {
      // Braces keep these constants scoped to this case instead of leaking
      // into the whole switch block.
      const alpha = 1.6732632423543772;
      const scale = 1.0507009873554805;
      return scale * (x >= 0 ? x : alpha * (Math.exp(x) - 1));
    }
    case 'tanh':
      return Math.tanh(x);
    case 'sigmoid':
      return 1 / (1 + Math.exp(-x));
    case 'softmax':
    case 'none':
    default:
      return x;
  }
}
|
|
788
|
+
|
|
789
|
+
/**
 * Apply inverted dropout to a vector.
 *
 * Outside training, or when `config.dropout` is unset or 0, the input array
 * itself is returned. Otherwise every entry is independently either zeroed
 * or multiplied by `1 / (1 - dropout)`.
 *
 * @param vector - Values to apply dropout to.
 * @param training - Whether dropout is active; defaults to `false`.
 * @returns The input array (no dropout) or a new array with dropout applied.
 */
protected applyDropout(vector: number[], training: boolean = false): number[] {
  const rate = this.config.dropout;

  // `!rate` covers both an unset rate and a rate of 0.
  if (!training || !rate) {
    return vector;
  }

  const keepScale = 1 / (1 - rate);
  return vector.map((value) => (Math.random() > rate ? value * keepScale : 0));
}
|
|
800
|
+
|
|
801
|
+
/**
 * Scale a vector to unit Euclidean (L2) length.
 *
 * @param vector - Vector to normalize.
 * @returns A new unit-length vector, or the input array itself when its
 *          norm is not greater than 0.
 */
protected normalizeVector(vector: number[]): number[] {
  let sumOfSquares = 0;
  vector.forEach((component) => {
    sumOfSquares += component * component;
  });

  const length = Math.sqrt(sumOfSquares);
  if (length > 0) {
    return vector.map((component) => component / length);
  }
  return vector;
}
|
|
808
|
+
|
|
809
|
+
/**
 * Build the statistics record for a GNN computation.
 *
 * @param startTime - `Date.now()` timestamp taken when the computation began.
 * @param numNodes - Number of nodes processed.
 * @param numEdges - Number of edges processed.
 * @param numIterations - Number of message-passing rounds; defaults to 1.
 * @returns Stats with the elapsed time and a memory estimate of
 *          `numNodes * outputDim * 4 + numEdges * 8`
 *          (presumably 4 bytes per embedding value and 8 bytes per edge — confirm).
 */
protected createStats(
  startTime: number,
  numNodes: number,
  numEdges: number,
  numIterations: number = 1
): GNNStats {
  const elapsedMs = Date.now() - startTime;
  const embeddingBytes = numNodes * this.config.outputDim * 4;
  const edgeBytes = numEdges * 8;

  const stats: GNNStats = {
    forwardTimeMs: elapsedMs,
    numNodes,
    numEdges,
    memoryBytes: embeddingBytes + edgeBytes,
    numIterations,
  };
  return stats;
}
|
|
826
|
+
}
|
|
827
|
+
|
|
828
|
+
// ============================================================================
|
|
829
|
+
// GCN Layer Implementation
|
|
830
|
+
// ============================================================================
|
|
831
|
+
|
|
832
|
+
/**
 * Graph Convolutional Network (GCN) layer.
 *
 * Implements spectral graph convolution with first-order approximation.
 * Reference: Kipf & Welling, "Semi-Supervised Classification with Graph Convolutional Networks" (2017)
 *
 * The projection uses deterministic pseudo-weights rather than learned
 * parameters (see `projectFeatures`).
 */
export class GCNLayer extends BaseGNNLayer {
  /**
   * Run the layer over a whole graph.
   *
   * @param graph - Node features, edge index pairs and optional edge weights.
   * @returns Per-node embeddings, their mean as graph embedding, and stats.
   */
  async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures, edgeIndex, edgeWeights } = graph;
    const numNodes = nodeFeatures.length;
    const numEdges = edgeIndex[0].length;

    // Build adjacency (with self-loops when configured)
    const adj = this.buildAdjacency(numNodes, edgeIndex, edgeWeights);

    // Normalize adjacency (D^-0.5 * A * D^-0.5)
    const normAdj = this.config.normalize ? this.symmetricNormalize(adj, numNodes) : adj;

    // Message passing: H' = sigma(A_norm * H * W)
    const outputFeatures = this.convolve(nodeFeatures, normAdj);

    return {
      nodeEmbeddings: outputFeatures,
      graphEmbedding: this.poolGraph(outputFeatures),
      stats: this.createStats(startTime, numNodes, numEdges),
    };
  }

  /**
   * Run the layer on id-addressed nodes and edges.
   *
   * Edge endpoints are translated from node ids to positions in `nodes.ids`;
   * endpoints whose id is unknown map to -1 and are dropped by
   * `buildAdjacency`.
   *
   * @param nodes - Node ids, features and metadata.
   * @param edges - Edge source/target ids and optional weights.
   * @returns The nodes with their features replaced by the layer output.
   */
  async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
    const numNodes = nodes.ids.length;

    // One pass to index ids instead of an indexOf scan per edge endpoint.
    // Only the first occurrence of an id is recorded, matching indexOf.
    const indexById = new Map<unknown, number>();
    nodes.ids.forEach((id: unknown, idx: number) => {
      if (!indexById.has(id)) {
        indexById.set(id, idx);
      }
    });
    const toIndex = (id: unknown): number => indexById.get(id) ?? -1;

    const edgeIndex: [number[], number[]] = [
      edges.sources.map(toIndex),
      edges.targets.map(toIndex),
    ];

    const adj = this.buildAdjacency(numNodes, edgeIndex, edges.weights);
    const normAdj = this.config.normalize ? this.symmetricNormalize(adj, numNodes) : adj;
    const outputFeatures = this.convolve(nodes.features, normAdj);

    return {
      ids: nodes.ids,
      features: outputFeatures,
      types: nodes.types,
      labels: nodes.labels,
    };
  }

  /**
   * Build a sparse, symmetric adjacency map (row -> column -> weight).
   *
   * Every node gets a row; a self-loop of weight 1 is added when
   * `config.addSelfLoops` is set. Each in-range edge is stored in both
   * directions with weight `weights[i]` (default 1); a later edge between the
   * same pair overwrites an earlier one. Edges with an endpoint outside
   * `[0, numNodes)` are ignored.
   */
  private buildAdjacency(
    numNodes: number,
    edgeIndex: [number[], number[]],
    weights?: number[]
  ): Map<number, Map<number, number>> {
    const adj = new Map<number, Map<number, number>>();

    // Initialize rows, with self-loops if configured
    for (let i = 0; i < numNodes; i++) {
      const row = new Map<number, number>();
      if (this.config.addSelfLoops) {
        row.set(i, 1);
      }
      adj.set(i, row);
    }

    // Add edges
    const [sources, targets] = edgeIndex;
    for (let i = 0; i < sources.length; i++) {
      const src = sources[i];
      const tgt = targets[i];
      const weight = weights?.[i] ?? 1;
      if (src >= 0 && src < numNodes && tgt >= 0 && tgt < numNodes) {
        adj.get(src)?.set(tgt, weight);
        // Undirected: add reverse edge
        adj.get(tgt)?.set(src, weight);
      }
    }

    return adj;
  }

  /**
   * Symmetrically normalize an adjacency map: each weight w(i, j) becomes
   * `w / sqrt(deg(i) * deg(j) + 1e-10)`, where deg is the row sum of weights.
   * The small constant avoids division by zero.
   */
  private symmetricNormalize(
    adj: Map<number, Map<number, number>>,
    numNodes: number
  ): Map<number, Map<number, number>> {
    // Compute degree
    const degree: number[] = new Array(numNodes).fill(0);
    for (let i = 0; i < numNodes; i++) {
      for (const weight of adj.get(i)?.values() ?? []) {
        degree[i] += weight;
      }
    }

    // D^-0.5 * A * D^-0.5
    const normAdj = new Map<number, Map<number, number>>();
    for (let i = 0; i < numNodes; i++) {
      const normRow = new Map<number, number>();
      for (const [j, weight] of adj.get(i)?.entries() ?? []) {
        normRow.set(j, weight / Math.sqrt(degree[i] * degree[j] + 1e-10));
      }
      normAdj.set(i, normRow);
    }

    return normAdj;
  }

  /**
   * For every node: weighted sum of neighbour features (per `adj`),
   * projection from `config.inputDim` to `config.outputDim`, activation, and
   * dropout (called without the training flag, so it is a no-op here).
   */
  private convolve(features: number[][], adj: Map<number, Map<number, number>>): number[][] {
    const numNodes = features.length;
    const inputDim = this.config.inputDim;
    const outputDim = this.config.outputDim;
    const output: number[][] = [];

    for (let i = 0; i < numNodes; i++) {
      const aggregated: number[] = new Array(inputDim).fill(0);

      // Aggregate neighbor features; a node without an adjacency row
      // simply aggregates nothing.
      for (const [j, weight] of adj.get(i)?.entries() ?? []) {
        const neighborFeatures = features[j] ?? new Array(inputDim).fill(0);
        for (let k = 0; k < inputDim; k++) {
          aggregated[k] += weight * (neighborFeatures[k] ?? 0);
        }
      }

      // Project to output dimension
      const projected = this.projectFeatures(aggregated, inputDim, outputDim);

      // Apply activation
      const activated = projected.map((x) => this.applyActivation(x));

      // Apply dropout
      output.push(this.applyDropout(activated));
    }

    return output;
  }

  /**
   * Linear projection with deterministic pseudo-weights
   * `sin((i * inputDim + j) * 0.1) * 0.5` standing in for learned weights;
   * adds a constant 0.01 bias per output when `config.useBias` is set.
   */
  private projectFeatures(input: number[], inputDim: number, outputDim: number): number[] {
    const output: number[] = new Array(outputDim).fill(0);
    for (let i = 0; i < outputDim; i++) {
      for (let j = 0; j < inputDim; j++) {
        // Use a deterministic pseudo-weight based on position
        const weight = Math.sin((i * inputDim + j) * 0.1) * 0.5;
        output[i] += input[j] * weight;
      }
      if (this.config.useBias) {
        output[i] += 0.01; // Small bias term
      }
    }
    return output;
  }

  /** Mean-pool node embeddings into a graph embedding (empty for no nodes). */
  private poolGraph(features: number[][]): number[] {
    if (features.length === 0) return [];
    return this.aggregateMean(features);
  }
}
|
|
987
|
+
|
|
988
|
+
// ============================================================================
|
|
989
|
+
// GAT Layer Implementation
|
|
990
|
+
// ============================================================================
|
|
991
|
+
|
|
992
|
+
/**
 * Graph Attention Network (GAT) layer.
 *
 * Implements attention-based message passing.
 * Reference: Veličković et al., "Graph Attention Networks" (2018)
 *
 * Attention scores and head projections use deterministic pseudo-weights
 * rather than learned parameters.
 */
export class GATLayer extends BaseGNNLayer {
  /**
   * Run the layer over a whole graph.
   *
   * For every head and node, neighbour features are projected to the head
   * dimension and combined with softmax-normalized attention weights. Heads
   * are then concatenated (default) or averaged per `config.params.concat`.
   *
   * @param graph - Node features and edge index pairs.
   * @returns Node embeddings, mean graph embedding, a per-head attention
   *          summary and stats.
   */
  async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures, edgeIndex } = graph;
    const numNodes = nodeFeatures.length;
    const numEdges = edgeIndex[0].length;
    const numHeads = this.config.numHeads ?? 1;
    const negativeSlope = this.config.params?.negativeSlope ?? 0.2;
    const headDim = Math.floor(this.config.outputDim / numHeads);

    // Neighbourhoods do not depend on the head, so compute them once
    // instead of rescanning the edge list for every (head, node) pair.
    const neighborLists: number[][] = [];
    for (let i = 0; i < numNodes; i++) {
      neighborLists.push(this.getNeighbors(i, edgeIndex, numNodes));
    }

    // Compute attention for each head
    const headOutputs: number[][][] = [];

    for (let h = 0; h < numHeads; h++) {
      const headFeatures: number[][] = [];

      for (let i = 0; i < numNodes; i++) {
        const messages: { feature: number[]; attention: number }[] = [];

        // Compute attention for each neighbor
        for (const j of neighborLists[i] ?? []) {
          const attention = this.computeAttention(
            nodeFeatures[i],
            nodeFeatures[j],
            h,
            negativeSlope
          );
          messages.push({
            feature: this.projectHead(nodeFeatures[j], h, headDim),
            attention,
          });
        }

        // Softmax attention weights. Scores are shifted by their maximum
        // before exponentiation so large scores cannot overflow to
        // Infinity (which would turn the weights into NaN).
        const maxAttention = messages.reduce(
          (max, m) => Math.max(max, m.attention),
          -Infinity
        );
        const shift = Number.isFinite(maxAttention) ? maxAttention : 0;
        const expScores = messages.map((m) => Math.exp(m.attention - shift));
        const attentionSum = expScores.reduce((sum, e) => sum + e, 0);

        // Aggregate with attention weights
        const aggregated: number[] = new Array(headDim).fill(0);
        messages.forEach((m, idx) => {
          const weight = expScores[idx] / (attentionSum + 1e-10);
          for (let k = 0; k < headDim; k++) {
            aggregated[k] += weight * (m.feature[k] ?? 0);
          }
        });

        headFeatures.push(aggregated);
      }

      headOutputs.push(headFeatures);
    }

    // Combine heads (concat or average)
    const concat = this.config.params?.concat ?? true;
    const outputFeatures = this.combineHeads(headOutputs, concat);

    // Apply activation and dropout
    const finalFeatures = outputFeatures.map((f) =>
      this.applyDropout(f.map((x) => this.applyActivation(x)))
    );

    return {
      nodeEmbeddings: finalFeatures,
      graphEmbedding: this.aggregateMean(finalFeatures),
      attentionWeights: this.extractAttentionWeights(headOutputs),
      stats: this.createStats(startTime, numNodes, numEdges),
    };
  }

  /**
   * Run the layer on id-addressed nodes and edges by translating edge
   * endpoints to positions in `nodes.ids` and delegating to `forward`.
   * Endpoints with an unknown id map to -1 and are ignored by `getNeighbors`.
   *
   * @param nodes - Node ids, features and metadata.
   * @param edges - Edge source/target ids.
   * @returns The nodes with their features replaced by the layer output.
   */
  async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
    // One pass to index ids instead of an indexOf scan per edge endpoint.
    // Only the first occurrence of an id is recorded, matching indexOf.
    const indexById = new Map<unknown, number>();
    nodes.ids.forEach((id: unknown, idx: number) => {
      if (!indexById.has(id)) {
        indexById.set(id, idx);
      }
    });
    const toIndex = (id: unknown): number => indexById.get(id) ?? -1;

    const graph: GraphData = {
      nodeFeatures: nodes.features,
      edgeIndex: [edges.sources.map(toIndex), edges.targets.map(toIndex)],
    };

    const output = await this.forward(graph);

    return {
      ids: nodes.ids,
      features: output.nodeEmbeddings,
      types: nodes.types,
      labels: nodes.labels,
    };
  }

  /**
   * Collect the distinct neighbours of `nodeIdx`, treating edges as
   * undirected. The node itself is included first when
   * `config.addSelfLoops` is set. Endpoints outside `[0, numNodes)` are
   * skipped so that no neighbour index without features is returned.
   */
  private getNeighbors(
    nodeIdx: number,
    edgeIndex: [number[], number[]],
    numNodes: number
  ): number[] {
    const neighbors = new Set<number>();
    const inRange = (idx: number): boolean => idx >= 0 && idx < numNodes;

    // Add self-loop
    if (this.config.addSelfLoops) {
      neighbors.add(nodeIdx);
    }

    // Find neighbors from edges
    const [sources, targets] = edgeIndex;
    for (let i = 0; i < sources.length; i++) {
      const src = sources[i];
      const tgt = targets[i];
      if (src === nodeIdx && inRange(tgt)) {
        neighbors.add(tgt);
      }
      if (tgt === nodeIdx && inRange(src)) {
        neighbors.add(src);
      }
    }

    return Array.from(neighbors);
  }

  /**
   * Unnormalized attention score between two nodes for one head.
   *
   * The element-wise sum of both feature vectors is dotted with the
   * deterministic pseudo-weights `sin((head * dim + k) * 0.1)`, then passed
   * through LeakyReLU with the given negative slope.
   */
  protected computeAttention(
    nodeI: number[],
    nodeJ: number[],
    head: number,
    negativeSlope: number
  ): number {
    let score = 0;
    const dim = nodeI.length;

    for (let k = 0; k < dim; k++) {
      // Simple attention mechanism (in practice, uses learned attention weights)
      const combined = (nodeI[k] ?? 0) + (nodeJ[k] ?? 0);
      score += combined * Math.sin((head * dim + k) * 0.1);
    }

    // LeakyReLU
    return score >= 0 ? score : negativeSlope * score;
  }

  /**
   * Project a feature vector to `headDim` values for one head using the
   * deterministic pseudo-weights
   * `cos((head * headDim * inputDim + i * inputDim + j) * 0.05)`.
   */
  private projectHead(features: number[], head: number, headDim: number): number[] {
    const output: number[] = new Array(headDim).fill(0);
    const inputDim = features.length;

    for (let i = 0; i < headDim; i++) {
      for (let j = 0; j < inputDim; j++) {
        const weight = Math.cos((head * headDim * inputDim + i * inputDim + j) * 0.05);
        output[i] += (features[j] ?? 0) * weight;
      }
    }

    return output;
  }

  /**
   * Merge per-head node features (`heads[head][node]`) into one vector per
   * node, either by concatenating the heads or by averaging them.
   */
  private combineHeads(heads: number[][][], concat: boolean): number[][] {
    const numNodes = heads[0]?.length ?? 0;
    const result: number[][] = [];

    for (let i = 0; i < numNodes; i++) {
      if (concat) {
        // Concatenate all head outputs
        result.push(heads.flatMap((h) => h[i] ?? []));
      } else {
        // Average head outputs
        const headDim = heads[0]?.[0]?.length ?? 0;
        const averaged: number[] = new Array(headDim).fill(0);
        for (const head of heads) {
          for (let j = 0; j < headDim; j++) {
            averaged[j] += (head[i]?.[j] ?? 0) / heads.length;
          }
        }
        result.push(averaged);
      }
    }

    return result;
  }

  /**
   * Simplified attention representation: for every head and node, the mean
   * of that node's aggregated head features. These are not the softmax
   * attention coefficients themselves.
   */
  private extractAttentionWeights(heads: number[][][]): number[][] {
    return heads.map((h) => h.map((node) => node.reduce((a, b) => a + b, 0) / node.length));
  }
}
|
|
1181
|
+
|
|
1182
|
+
// ============================================================================
|
|
1183
|
+
// GAT v2 Layer Implementation
|
|
1184
|
+
// ============================================================================
|
|
1185
|
+
|
|
1186
|
+
/**
 * Graph Attention Network v2 layer.
 *
 * Improved attention mechanism with dynamic attention.
 * Reference: Brody et al., "How Attentive are Graph Attention Networks?" (2022)
 *
 * Differs from the parent layer only in how the attention score is
 * computed: LeakyReLU is applied to the combined features before the
 * scoring weights, not to the final score.
 */
export class GATv2Layer extends GATLayer {
  /**
   * Unnormalized GATv2 attention score between two nodes for one head.
   *
   * Each summed feature pair goes through LeakyReLU and is then weighted by
   * the deterministic pseudo-weight `sin((head * dim + k) * 0.1)`.
   */
  protected override computeAttention(
    nodeI: number[],
    nodeJ: number[],
    head: number,
    negativeSlope: number
  ): number {
    const width = nodeI.length;
    const offset = head * width;
    let total = 0;

    for (let k = 0; k < width; k++) {
      // Combine, rectify, then score, all in a single pass.
      const pairSum = (nodeI[k] ?? 0) + (nodeJ[k] ?? 0);
      const rectified = pairSum >= 0 ? pairSum : negativeSlope * pairSum;
      total += rectified * Math.sin((offset + k) * 0.1);
    }

    return total;
  }
}
|
|
1222
|
+
|
|
1223
|
+
// ============================================================================
|
|
1224
|
+
// GraphSAGE Layer Implementation
|
|
1225
|
+
// ============================================================================
|
|
1226
|
+
|
|
1227
|
+
/**
 * GraphSAGE (Sample and Aggregate) layer.
 *
 * Implements inductive representation learning with neighbor sampling.
 * Reference: Hamilton et al., "Inductive Representation Learning on Large Graphs" (2017)
 *
 * The projection uses deterministic pseudo-weights rather than learned
 * parameters, and neighbour sampling uses `Math.random()`, so outputs are
 * not reproducible once a node has more neighbours than the sample size.
 */
export class GraphSAGELayer extends BaseGNNLayer {
  /**
   * Run the layer over a whole graph.
   *
   * For each node: sample up to `config.params.sampleSize` (default 10)
   * neighbours, aggregate their features with `config.aggregation`
   * (default `'mean'`), concatenate with the node's own features, project to
   * `config.outputDim`, optionally L2-normalize, then activate.
   *
   * @param graph - Node features and edge index pairs.
   * @returns Node embeddings, mean graph embedding and stats.
   */
  async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures, edgeIndex } = graph;
    const numNodes = nodeFeatures.length;
    const numEdges = edgeIndex[0].length;
    const sampleSize = this.config.params?.sampleSize ?? 10;

    const outputFeatures: number[][] = [];

    for (let i = 0; i < numNodes; i++) {
      // Sample neighbors
      const allNeighbors = this.getNeighbors(i, edgeIndex, numNodes);
      const sampledNeighbors = this.sampleNeighbors(allNeighbors, sampleSize);

      // Aggregate neighbor features
      const neighborFeatures = sampledNeighbors.map((j) => nodeFeatures[j] ?? []);
      const aggregated = await this.aggregate(
        neighborFeatures.map((f) => ({ source: i, target: i, vector: f })),
        this.config.aggregation ?? 'mean'
      );

      // Concatenate with self features and project
      const selfFeatures = nodeFeatures[i] ?? [];
      const combined = [...selfFeatures, ...aggregated];
      const projected = this.projectFeatures(combined, combined.length, this.config.outputDim);

      // Normalize, activate, and apply dropout
      const normalized = this.config.normalize ? this.normalizeVector(projected) : projected;
      const activated = normalized.map((x) => this.applyActivation(x));
      outputFeatures.push(this.applyDropout(activated));
    }

    return {
      nodeEmbeddings: outputFeatures,
      graphEmbedding: this.aggregateMean(outputFeatures),
      stats: this.createStats(startTime, numNodes, numEdges),
    };
  }

  /**
   * Run the layer on id-addressed nodes and edges by translating edge
   * endpoints to positions in `nodes.ids` and delegating to `forward`.
   * Endpoints with an unknown id map to -1 and are ignored by `getNeighbors`.
   *
   * @param nodes - Node ids, features and metadata.
   * @param edges - Edge source/target ids.
   * @returns The nodes with their features replaced by the layer output.
   */
  async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
    // One pass to index ids instead of an indexOf scan per edge endpoint.
    // Only the first occurrence of an id is recorded, matching indexOf.
    const indexById = new Map<unknown, number>();
    nodes.ids.forEach((id: unknown, idx: number) => {
      if (!indexById.has(id)) {
        indexById.set(id, idx);
      }
    });
    const toIndex = (id: unknown): number => indexById.get(id) ?? -1;

    const graph: GraphData = {
      nodeFeatures: nodes.features,
      edgeIndex: [edges.sources.map(toIndex), edges.targets.map(toIndex)],
    };

    const output = await this.forward(graph);

    return {
      ids: nodes.ids,
      features: output.nodeEmbeddings,
      types: nodes.types,
      labels: nodes.labels,
    };
  }

  /**
   * Collect the distinct neighbours of `nodeIdx`, treating edges as
   * undirected. Endpoints outside `[0, numNodes)` are skipped so that an
   * unknown node (index -1) cannot contribute an empty feature vector to
   * the aggregation.
   */
  private getNeighbors(
    nodeIdx: number,
    edgeIndex: [number[], number[]],
    numNodes: number
  ): number[] {
    const neighbors = new Set<number>();
    const inRange = (idx: number): boolean => idx >= 0 && idx < numNodes;
    const [sources, targets] = edgeIndex;

    for (let i = 0; i < sources.length; i++) {
      const src = sources[i];
      const tgt = targets[i];
      if (src === nodeIdx && inRange(tgt)) {
        neighbors.add(tgt);
      }
      if (tgt === nodeIdx && inRange(src)) {
        neighbors.add(src);
      }
    }

    return Array.from(neighbors);
  }

  /**
   * Pick at most `k` neighbours uniformly at random without replacement.
   * When there are `k` or fewer, the input array itself is returned.
   */
  private sampleNeighbors(neighbors: number[], k: number): number[] {
    if (neighbors.length <= k) {
      return neighbors;
    }

    // Random sampling: draw from a shrinking copy so no index repeats
    const sampled: number[] = [];
    const available = [...neighbors];

    for (let i = 0; i < k && available.length > 0; i++) {
      const idx = Math.floor(Math.random() * available.length);
      sampled.push(available[idx]);
      available.splice(idx, 1);
    }

    return sampled;
  }

  /**
   * Linear projection with deterministic pseudo-weights
   * `sin((i * inputDim + j) * 0.1) * sqrt(2 / (inputDim + outputDim))`;
   * adds a constant 0.01 bias per output when `config.useBias` is set.
   */
  private projectFeatures(input: number[], inputDim: number, outputDim: number): number[] {
    const output: number[] = new Array(outputDim).fill(0);
    for (let i = 0; i < outputDim; i++) {
      for (let j = 0; j < inputDim; j++) {
        const weight = Math.sin((i * inputDim + j) * 0.1) * Math.sqrt(2 / (inputDim + outputDim));
        output[i] += (input[j] ?? 0) * weight;
      }
      if (this.config.useBias) {
        output[i] += 0.01;
      }
    }
    return output;
  }
}
|
|
1344
|
+
|
|
1345
|
+
// ============================================================================
|
|
1346
|
+
// GIN Layer Implementation
|
|
1347
|
+
// ============================================================================
|
|
1348
|
+
|
|
1349
|
+
/**
 * Graph Isomorphism Network (GIN) layer.
 *
 * Maximally powerful GNN for graph classification.
 * Reference: Xu et al., "How Powerful are Graph Neural Networks?" (2019)
 *
 * The two MLP layers use deterministic pseudo-weights rather than learned
 * parameters.
 */
export class GINLayer extends BaseGNNLayer {
  /**
   * Run the layer over a whole graph.
   *
   * For each node the neighbour features are summed, combined with
   * `(1 + eps)` times the node's own features (`eps` from
   * `config.params.eps`, default 0) and passed through a two-layer MLP.
   *
   * @param graph - Node features and edge index pairs.
   * @returns Node embeddings, sum-pooled graph embedding and stats.
   */
  async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures, edgeIndex } = graph;
    const numNodes = nodeFeatures.length;
    const numEdges = edgeIndex[0].length;
    const eps = this.config.params?.eps ?? 0;
    const inputDim = this.config.inputDim;

    const outputFeatures: number[][] = [];

    for (let i = 0; i < numNodes; i++) {
      const neighbors = this.getNeighbors(i, edgeIndex, numNodes);

      // Sum neighbor features
      const neighborSum: number[] = new Array(inputDim).fill(0);
      for (const j of neighbors) {
        const neighborFeatures = nodeFeatures[j] ?? [];
        for (let k = 0; k < inputDim; k++) {
          neighborSum[k] += neighborFeatures[k] ?? 0;
        }
      }

      // GIN update: h_v = MLP((1 + eps) * h_v + sum(h_u))
      const selfFeatures = nodeFeatures[i] ?? [];
      const combined: number[] = new Array(inputDim).fill(0);
      for (let k = 0; k < inputDim; k++) {
        combined[k] = (1 + eps) * (selfFeatures[k] ?? 0) + neighborSum[k];
      }

      // MLP (2-layer)
      const hidden = this.mlpLayer1(combined);
      const output = this.mlpLayer2(hidden);
      outputFeatures.push(this.applyDropout(output));
    }

    return {
      nodeEmbeddings: outputFeatures,
      graphEmbedding: this.aggregateSum(outputFeatures), // Sum pooling for graph classification
      stats: this.createStats(startTime, numNodes, numEdges),
    };
  }

  /**
   * Run the layer on id-addressed nodes and edges by translating edge
   * endpoints to positions in `nodes.ids` and delegating to `forward`.
   * Endpoints with an unknown id map to -1 and are ignored by `getNeighbors`.
   *
   * @param nodes - Node ids, features and metadata.
   * @param edges - Edge source/target ids.
   * @returns The nodes with their features replaced by the layer output.
   */
  async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
    // One pass to index ids instead of an indexOf scan per edge endpoint.
    // Only the first occurrence of an id is recorded, matching indexOf.
    const indexById = new Map<unknown, number>();
    nodes.ids.forEach((id: unknown, idx: number) => {
      if (!indexById.has(id)) {
        indexById.set(id, idx);
      }
    });
    const toIndex = (id: unknown): number => indexById.get(id) ?? -1;

    const graph: GraphData = {
      nodeFeatures: nodes.features,
      edgeIndex: [edges.sources.map(toIndex), edges.targets.map(toIndex)],
    };

    const output = await this.forward(graph);

    return {
      ids: nodes.ids,
      features: output.nodeEmbeddings,
      types: nodes.types,
      labels: nodes.labels,
    };
  }

  /**
   * Collect the distinct neighbours of `nodeIdx`, treating edges as
   * undirected. Endpoints outside `[0, numNodes)` are skipped.
   */
  private getNeighbors(
    nodeIdx: number,
    edgeIndex: [number[], number[]],
    numNodes: number
  ): number[] {
    const neighbors = new Set<number>();
    const inRange = (idx: number): boolean => idx >= 0 && idx < numNodes;
    const [sources, targets] = edgeIndex;

    for (let i = 0; i < sources.length; i++) {
      const src = sources[i];
      const tgt = targets[i];
      if (src === nodeIdx && inRange(tgt)) {
        neighbors.add(tgt);
      }
      if (tgt === nodeIdx && inRange(src)) {
        neighbors.add(src);
      }
    }

    return Array.from(neighbors);
  }

  /**
   * First MLP layer: projects to `config.hiddenDim` (falling back to
   * `config.inputDim`) with pseudo-weights
   * `sin((i * input.length + j) * 0.1) * 0.5`, then applies the configured
   * activation.
   */
  private mlpLayer1(input: number[]): number[] {
    const hiddenDim = this.config.hiddenDim ?? this.config.inputDim;
    const output: number[] = new Array(hiddenDim).fill(0);

    for (let i = 0; i < hiddenDim; i++) {
      for (let j = 0; j < input.length; j++) {
        const weight = Math.sin((i * input.length + j) * 0.1) * 0.5;
        output[i] += (input[j] ?? 0) * weight;
      }
      output[i] = this.applyActivation(output[i]);
    }

    return output;
  }

  /**
   * Second MLP layer: projects to `config.outputDim` with pseudo-weights
   * `cos((i * input.length + j) * 0.1) * 0.5`; no activation is applied.
   */
  private mlpLayer2(input: number[]): number[] {
    const outputDim = this.config.outputDim;
    const output: number[] = new Array(outputDim).fill(0);

    for (let i = 0; i < outputDim; i++) {
      for (let j = 0; j < input.length; j++) {
        const weight = Math.cos((i * input.length + j) * 0.1) * 0.5;
        output[i] += (input[j] ?? 0) * weight;
      }
    }

    return output;
  }
}
|
|
1464
|
+
|
|
1465
|
+
// ============================================================================
|
|
1466
|
+
// MPNN Layer Implementation
|
|
1467
|
+
// ============================================================================
|
|
1468
|
+
|
|
1469
|
+
/**
 * Message Passing Neural Network (MPNN) layer.
 *
 * General framework for GNN with customizable message and update functions.
 * Reference: Gilmer et al., "Neural Message Passing for Quantum Chemistry" (2017)
 *
 * Every edge is used in both directions: a node receives one message from the
 * opposite endpoint of each edge it appears in (a self-loop therefore yields
 * two messages). The number of message-passing rounds is
 * `config.params.numLayers` (default 1).
 */
export class MPNNLayer extends BaseGNNLayer {
  /**
   * Runs `numLayers` rounds of message passing over the graph.
   *
   * @param graph Node features, edge index (`[sources, targets]`) and optional
   *   per-edge features.
   * @returns Final node embeddings, their mean as the graph embedding, and stats.
   */
  async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures, edgeIndex, edgeFeatures } = graph;
    const numNodes = nodeFeatures.length;
    const [sources, targets] = edgeIndex;
    const numEdges = sources.length;

    // Build each node's incident-edge list once instead of rescanning the
    // whole edge list for every node in every round. Entries are appended in
    // edge order, target side first, so messages keep their original order.
    // Endpoints that are not valid node indices (e.g. -1) have no list and
    // are skipped, exactly as they never matched a node index before.
    const incident: { neighbor: number; edge: number }[][] = Array.from(
      { length: numNodes },
      () => []
    );
    for (let e = 0; e < numEdges; e++) {
      const src = sources[e];
      const tgt = targets[e];
      incident[tgt]?.push({ neighbor: src, edge: e });
      incident[src]?.push({ neighbor: tgt, edge: e });
    }

    // Work on a copy so the caller's feature arrays are never mutated.
    let currentFeatures = nodeFeatures.map((f) => [...f]);

    // Multiple rounds of message passing
    const numIterations = this.config.params?.numLayers ?? 1;

    for (let t = 0; t < numIterations; t++) {
      const newFeatures: number[][] = [];

      for (let i = 0; i < numNodes; i++) {
        const selfFeatures = currentFeatures[i] ?? [];

        // Collect messages from neighbors (features of the previous round)
        const messages: Message[] = [];
        for (const { neighbor, edge } of incident[i]) {
          const edgeFeat = edgeFeatures?.[edge];
          messages.push({
            source: neighbor,
            target: i,
            vector: this.messageFunction(currentFeatures[neighbor] ?? [], selfFeatures, edgeFeat),
            edgeFeatures: edgeFeat,
          });
        }

        // Aggregate messages
        const aggregated = await this.aggregate(messages, this.config.aggregation ?? 'sum');

        // Update node features
        const updated = this.updateFunction(selfFeatures, aggregated);
        newFeatures.push(this.applyDropout(updated));
      }

      currentFeatures = newFeatures;
    }

    return {
      nodeEmbeddings: currentFeatures,
      graphEmbedding: this.aggregateMean(currentFeatures),
      stats: this.createStats(startTime, numNodes, numEdges, numIterations),
    };
  }

  /**
   * Adapter from id-based node/edge records to the index-based `forward`.
   *
   * Edge endpoints are translated to the position of the first matching id in
   * `nodes.ids`; unknown ids become -1 and match no node.
   */
  async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
    // One pass to index ids (first occurrence wins, like indexOf) instead of
    // a linear indexOf scan per edge endpoint.
    const indexOfId = new Map<unknown, number>();
    nodes.ids.forEach((id, idx) => {
      if (!indexOfId.has(id)) indexOfId.set(id, idx);
    });
    const toIndex = (id: unknown): number => indexOfId.get(id) ?? -1;

    const graph: GraphData = {
      nodeFeatures: nodes.features,
      edgeIndex: [edges.sources.map(toIndex), edges.targets.map(toIndex)],
      edgeFeatures: edges.features,
    };

    const output = await this.forward(graph);

    return {
      ids: nodes.ids,
      features: output.nodeEmbeddings,
      types: nodes.types,
      labels: nodes.labels,
    };
  }

  /**
   * Builds a message of length `config.inputDim` as a fixed linear mix:
   * 0.5 * sender + 0.3 * receiver (+ 0.2 * edge feature where present).
   * Missing components count as 0.
   */
  private messageFunction(
    sourceFeatures: number[],
    targetFeatures: number[],
    edgeFeatures?: number[]
  ): number[] {
    const dim = this.config.inputDim;
    const message = new Array(dim).fill(0);

    for (let i = 0; i < dim; i++) {
      message[i] = (sourceFeatures[i] ?? 0) * 0.5 + (targetFeatures[i] ?? 0) * 0.3;
      if (edgeFeatures && edgeFeatures[i] !== undefined) {
        message[i] += edgeFeatures[i] * 0.2;
      }
    }

    return message;
  }

  /**
   * GRU-like update producing `config.outputDim` values: a sigmoid gate on
   * (node + aggregate) blends the aggregate with the node's own value, then
   * the activation is applied. Inputs shorter than `outputDim` wrap around
   * (modulo indexing); empty inputs contribute 0.
   */
  private updateFunction(nodeFeatures: number[], aggregated: number[]): number[] {
    const output = new Array(this.config.outputDim).fill(0);

    for (let i = 0; i < this.config.outputDim; i++) {
      const nodeVal = nodeFeatures[i % nodeFeatures.length] ?? 0;
      const aggVal = aggregated[i % aggregated.length] ?? 0;
      const gate = 1 / (1 + Math.exp(-(nodeVal + aggVal)));
      output[i] = this.applyActivation(gate * aggVal + (1 - gate) * nodeVal);
    }

    return output;
  }
}
|
|
1598
|
+
|
|
1599
|
+
// ============================================================================
|
|
1600
|
+
// EdgeConv Layer Implementation
|
|
1601
|
+
// ============================================================================
|
|
1602
|
+
|
|
1603
|
+
/**
 * EdgeConv layer for dynamic graph convolution.
 *
 * Uses k-NN graph construction and edge features.
 * Reference: Wang et al., "Dynamic Graph CNN for Learning on Point Clouds" (2019)
 *
 * Configuration read from `config.params`:
 * - `k` (default 20): neighbours per node when the graph is rebuilt.
 * - `dynamic` (default true): rebuild a k-NN graph from the node features;
 *   when false, the supplied `edgeIndex` is used as-is.
 */
export class EdgeConvLayer extends BaseGNNLayer {
  /**
   * For each node, forms `(x_j - x_i) || x_i` for every outgoing neighbour j,
   * max-pools those vectors, and passes the result through `edgeMLP`.
   */
  async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures } = graph;
    const numNodes = nodeFeatures.length;
    const k = this.config.params?.k ?? 20;
    const dynamic = this.config.params?.dynamic ?? true;

    // Build k-NN graph
    const knnGraph = dynamic
      ? this.buildKNNGraph(nodeFeatures, k)
      : graph.edgeIndex;

    const outputFeatures: number[][] = [];

    for (let i = 0; i < numNodes; i++) {
      const neighbors = this.getKNNNeighbors(i, knnGraph);
      const selfFeatures = nodeFeatures[i] ?? [];

      // Edge features: (x_j - x_i) || x_i
      const edgeFeatures: number[][] = [];
      for (const j of neighbors) {
        const neighborFeatures = nodeFeatures[j] ?? [];
        const diff = selfFeatures.map((v, idx) => (neighborFeatures[idx] ?? 0) - v);
        edgeFeatures.push([...diff, ...selfFeatures]);
      }

      // Max pooling over edge features
      const pooled = this.maxPoolEdges(edgeFeatures);

      // MLP on pooled features
      const output = this.edgeMLP(pooled);
      outputFeatures.push(this.applyDropout(output));
    }

    return {
      nodeEmbeddings: outputFeatures,
      graphEmbedding: this.aggregateMean(outputFeatures),
      // Report the edges actually used: the k-NN graph has fewer than
      // numNodes * k edges when k exceeds numNodes - 1, and a static graph's
      // edge count is unrelated to k.
      stats: this.createStats(startTime, numNodes, knnGraph[0].length),
    };
  }

  /**
   * Adapter from id-based node/edge records to the index-based `forward`.
   *
   * Edge endpoints are translated to the position of the first matching id in
   * `nodes.ids`; unknown ids become -1.
   */
  async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
    // One pass to index ids (first occurrence wins, like indexOf) instead of
    // a linear indexOf scan per edge endpoint.
    const indexOfId = new Map<unknown, number>();
    nodes.ids.forEach((id, idx) => {
      if (!indexOfId.has(id)) indexOfId.set(id, idx);
    });
    const toIndex = (id: unknown): number => indexOfId.get(id) ?? -1;

    const graph: GraphData = {
      nodeFeatures: nodes.features,
      edgeIndex: [edges.sources.map(toIndex), edges.targets.map(toIndex)],
    };

    const output = await this.forward(graph);

    return {
      ids: nodes.ids,
      features: output.nodeEmbeddings,
      types: nodes.types,
      labels: nodes.labels,
    };
  }

  /**
   * Brute-force k-NN: for every node, emits edges (i -> j) to its `k` nearest
   * other nodes by Euclidean distance, nearest first.
   */
  private buildKNNGraph(features: number[][], k: number): [number[], number[]] {
    const sources: number[] = [];
    const targets: number[] = [];

    for (let i = 0; i < features.length; i++) {
      const distances: { idx: number; dist: number }[] = [];

      for (let j = 0; j < features.length; j++) {
        if (i !== j) {
          const dist = this.euclideanDistance(features[i], features[j]);
          distances.push({ idx: j, dist });
        }
      }

      distances.sort((a, b) => a.dist - b.dist);
      const neighbors = distances.slice(0, k);

      for (const neighbor of neighbors) {
        sources.push(i);
        targets.push(neighbor.idx);
      }
    }

    return [sources, targets];
  }

  /**
   * Euclidean distance over the components of `a`; components missing from
   * either vector count as 0. Components of `b` beyond `a.length` are ignored.
   */
  private euclideanDistance(a: number[], b: number[]): number {
    let sum = 0;
    for (let i = 0; i < a.length; i++) {
      const diff = (a[i] ?? 0) - (b[i] ?? 0);
      sum += diff * diff;
    }
    return Math.sqrt(sum);
  }

  /** Returns the targets of all edges whose source is `nodeIdx`, in edge order. */
  private getKNNNeighbors(nodeIdx: number, edgeIndex: [number[], number[]]): number[] {
    const neighbors: number[] = [];
    const [sources, targets] = edgeIndex;

    for (let i = 0; i < sources.length; i++) {
      if (sources[i] === nodeIdx) {
        neighbors.push(targets[i]);
      }
    }

    return neighbors;
  }

  /**
   * Max-pools the edge feature vectors; a node without neighbours gets a zero
   * vector of length `2 * config.inputDim`.
   */
  private maxPoolEdges(edgeFeatures: number[][]): number[] {
    if (edgeFeatures.length === 0) {
      return new Array(this.config.inputDim * 2).fill(0);
    }
    return this.aggregateMax(edgeFeatures);
  }

  /**
   * Dense projection to `config.outputDim` followed by the activation.
   * Weights are a fixed sine pattern of the (output, input) position, not
   * learned parameters.
   */
  private edgeMLP(input: number[]): number[] {
    const output = new Array(this.config.outputDim).fill(0);

    for (let i = 0; i < this.config.outputDim; i++) {
      for (let j = 0; j < input.length; j++) {
        const weight = Math.sin((i * input.length + j) * 0.08) * 0.4;
        output[i] += (input[j] ?? 0) * weight;
      }
      output[i] = this.applyActivation(output[i]);
    }

    return output;
  }
}
|
|
1739
|
+
|
|
1740
|
+
// ============================================================================
|
|
1741
|
+
// Additional GNN Layer Implementations (Stubs)
|
|
1742
|
+
// ============================================================================
|
|
1743
|
+
|
|
1744
|
+
/**
 * Point Convolution layer for point cloud data.
 *
 * Currently a plain alias of {@link EdgeConvLayer}: it adds no members and
 * inherits every behaviour unchanged.
 */
export class PointConvLayer extends EdgeConvLayer {}
|
|
1750
|
+
|
|
1751
|
+
/**
 * Graph Transformer layer.
 *
 * Runs the inherited GAT forward pass and then layer-normalises every node
 * embedding. All other fields of the GAT result (including `graphEmbedding`
 * and `stats`) are passed through unchanged.
 */
export class GraphTransformerLayer extends GATLayer {
  override async forward(graph: GraphData): Promise<GNNOutput> {
    const attended = await super.forward(graph);

    // Normalise each node's embedding independently.
    const nodeEmbeddings = attended.nodeEmbeddings.map((row) => this.layerNorm(row));

    return { ...attended, nodeEmbeddings };
  }

  /**
   * Shifts `features` to zero mean and scales by sqrt(variance + 1e-6)
   * (population variance; the epsilon avoids division by zero).
   */
  private layerNorm(features: number[]): number[] {
    const count = features.length;

    let total = 0;
    for (const value of features) {
      total = total + value;
    }
    const mean = total / count;

    let squaredDeviation = 0;
    for (const value of features) {
      squaredDeviation = squaredDeviation + (value - mean) ** 2;
    }
    const std = Math.sqrt(squaredDeviation / count + 1e-6);

    return features.map((value) => (value - mean) / std);
  }
}
|
|
1778
|
+
|
|
1779
|
+
/**
 * Principal Neighbourhood Aggregation (PNA) layer.
 *
 * For every node, each configured aggregator is applied to the neighbour
 * features, each aggregate is passed through every configured scaler, and the
 * concatenation is projected to `config.outputDim`.
 *
 * Configuration read from `config.params`:
 * - `aggregators` (default mean, sum, max, min)
 * - `scalers` (default identity, amplification, attenuation)
 */
export class PNALayer extends BaseGNNLayer {
  async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures, edgeIndex } = graph;
    const numNodes = nodeFeatures.length;
    const numEdges = edgeIndex[0].length;

    const aggregators = this.config.params?.aggregators ?? ['mean', 'sum', 'max', 'min'];
    const scalers = this.config.params?.scalers ?? ['identity', 'amplification', 'attenuation'];

    const outputFeatures: number[][] = [];

    for (let nodeIdx = 0; nodeIdx < numNodes; nodeIdx++) {
      const neighborIds = this.getNeighbors(nodeIdx, edgeIndex, numNodes);
      const neighborFeatures = neighborIds.map((j) => nodeFeatures[j] ?? []);
      // Isolated nodes are scaled as if they had degree 1.
      const degree = neighborIds.length || 1;

      // One aggregate per aggregator, in configuration order.
      const perAggregator: number[][] = [];
      for (const agg of aggregators) {
        const messages = neighborFeatures.map((f) => ({
          source: 0,
          target: nodeIdx,
          vector: f,
        }));
        perAggregator.push(await this.aggregate(messages, agg as AggregationMethod));
      }

      // Every aggregate goes through every scaler (aggregator-major order).
      const perScaler: number[][] = [];
      for (const aggregated of perAggregator) {
        for (const scaler of scalers) {
          perScaler.push(this.applyScaler(aggregated, scaler, degree));
        }
      }

      // Concatenate, project, activate, then apply dropout.
      const projected = this.projectFeatures(perScaler.flat());
      const activated = projected.map((x) => this.applyActivation(x));
      outputFeatures.push(this.applyDropout(activated));
    }

    return {
      nodeEmbeddings: outputFeatures,
      graphEmbedding: this.aggregateMean(outputFeatures),
      stats: this.createStats(startTime, numNodes, numEdges),
    };
  }

  /**
   * Adapter from id-based node/edge records to the index-based `forward`.
   * Edge endpoints become positions in `nodes.ids` (-1 when the id is unknown).
   */
  async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
    const sourceIdx = edges.sources.map((s) => nodes.ids.indexOf(s));
    const targetIdx = edges.targets.map((t) => nodes.ids.indexOf(t));

    const graph: GraphData = {
      nodeFeatures: nodes.features,
      edgeIndex: [sourceIdx, targetIdx],
    };

    const { nodeEmbeddings } = await this.forward(graph);

    return {
      ids: nodes.ids,
      features: nodeEmbeddings,
      types: nodes.types,
      labels: nodes.labels,
    };
  }

  /**
   * Distinct neighbours of `nodeIdx`, treating edges as undirected. The
   * opposite endpoint is only accepted when it is below `numNodes`.
   */
  private getNeighbors(
    nodeIdx: number,
    edgeIndex: [number[], number[]],
    numNodes: number
  ): number[] {
    const [sources, targets] = edgeIndex;
    const found = new Set<number>();

    for (let e = 0; e < sources.length; e++) {
      const src = sources[e];
      const tgt = targets[e];
      if (src === nodeIdx && tgt < numNodes) {
        found.add(tgt);
      }
      if (tgt === nodeIdx && src < numNodes) {
        found.add(src);
      }
    }

    return Array.from(found);
  }

  /**
   * Degree scaler: `amplification` multiplies by log(degree + 1),
   * `attenuation` divides by it, anything else returns `features` as-is.
   */
  private applyScaler(
    features: number[],
    scaler: string,
    degree: number
  ): number[] {
    if (scaler === 'amplification') {
      return features.map((x) => x * Math.log(degree + 1));
    }
    if (scaler === 'attenuation') {
      return features.map((x) => x / Math.log(degree + 1));
    }
    return features;
  }

  /**
   * Projects `input` to `config.outputDim` using fixed sine-pattern weights.
   * Only the first 100 input components are read.
   */
  private projectFeatures(input: number[]): number[] {
    const outDim = this.config.outputDim;
    const output = new Array(outDim).fill(0);

    for (let row = 0; row < outDim; row++) {
      for (let col = 0; col < Math.min(input.length, 100); col++) {
        const weight = Math.sin((row * 100 + col) * 0.1) * 0.3;
        output[row] += (input[col] ?? 0) * weight;
      }
    }

    return output;
  }
}
|
|
1898
|
+
|
|
1899
|
+
/**
 * FiLM (Feature-wise Linear Modulation) layer.
 *
 * Each node feature vector is transformed as `gamma * x + beta`, where gamma
 * and beta are derived from the mean of all edge features in the graph (the
 * same modulation is applied to every node). Without edge features the
 * modulation is the identity (gamma = 1, beta = 0).
 */
export class FiLMLayer extends BaseGNNLayer {
  async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures, edgeIndex, edgeFeatures } = graph;
    const numNodes = nodeFeatures.length;
    const numEdges = edgeIndex[0].length;

    // The modulation depends only on the graph-wide edge features, so compute
    // it once rather than once per node.
    const { gamma, beta } = this.computeModulation(edgeFeatures ?? []);

    const outputFeatures: number[][] = [];

    for (let i = 0; i < numNodes; i++) {
      const selfFeatures = nodeFeatures[i] ?? [];

      // Apply FiLM: gamma * x + beta (parameters wrap around when the node
      // vector is longer than outputDim)
      const modulated = selfFeatures.map((x, idx) =>
        (gamma[idx % gamma.length] ?? 1) * x + (beta[idx % beta.length] ?? 0)
      );

      outputFeatures.push(this.applyDropout(modulated.map((x) => this.applyActivation(x))));
    }

    return {
      nodeEmbeddings: outputFeatures,
      graphEmbedding: this.aggregateMean(outputFeatures),
      stats: this.createStats(startTime, numNodes, numEdges),
    };
  }

  /**
   * Adapter from id-based node/edge records to the index-based `forward`.
   *
   * Edge endpoints are translated to the position of the first matching id in
   * `nodes.ids`; unknown ids become -1.
   */
  async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
    // One pass to index ids (first occurrence wins, like indexOf) instead of
    // a linear indexOf scan per edge endpoint.
    const indexOfId = new Map<unknown, number>();
    nodes.ids.forEach((id, idx) => {
      if (!indexOfId.has(id)) indexOfId.set(id, idx);
    });
    const toIndex = (id: unknown): number => indexOfId.get(id) ?? -1;

    const graph: GraphData = {
      nodeFeatures: nodes.features,
      edgeIndex: [edges.sources.map(toIndex), edges.targets.map(toIndex)],
      edgeFeatures: edges.features,
    };

    const output = await this.forward(graph);

    return {
      ids: nodes.ids,
      features: output.nodeEmbeddings,
      types: nodes.types,
      labels: nodes.labels,
    };
  }

  /**
   * Derives `outputDim`-length gamma/beta vectors from the mean edge feature:
   * gamma[i] = 1 + 0.1 * mean[i], beta[i] = 0.1 * mean[i + floor(dim / 2)],
   * with indices wrapping around the mean vector.
   */
  private computeModulation(edgeFeatures: number[][]): { gamma: number[]; beta: number[] } {
    const dim = this.config.outputDim;
    const gamma = new Array(dim).fill(1);
    const beta = new Array(dim).fill(0);

    if (edgeFeatures.length > 0) {
      const meanEdge = this.aggregateMean(edgeFeatures);
      // Integer offset: a fractional dim / 2 (odd dim) would index the array
      // with a non-integer and silently leave every beta at 0.
      const betaOffset = Math.floor(dim / 2);
      for (let i = 0; i < dim; i++) {
        gamma[i] = 1 + 0.1 * (meanEdge[i % meanEdge.length] ?? 0);
        beta[i] = 0.1 * (meanEdge[(i + betaOffset) % meanEdge.length] ?? 0);
      }
    }

    return { gamma, beta };
  }
}
|
|
1968
|
+
|
|
1969
|
+
/**
 * Relational Graph Convolutional Network (RGCN) layer.
 *
 * Runs the inherited GCN forward pass once per relation type
 * (`config.params.numRelations`, default 1) on the edges of that type and
 * averages the per-relation node embeddings.
 */
export class RGCNLayer extends GCNLayer {
  override async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const { nodeFeatures, edgeIndex, edgeTypes } = graph;
    const numRelations = this.config.params?.numRelations ?? 1;

    // One GCN pass per relation, each restricted to that relation's edges.
    const perRelation: number[][][] = [];
    for (let relation = 0; relation < numRelations; relation++) {
      const relationGraph: GraphData = {
        nodeFeatures,
        edgeIndex: this.filterEdgesByType(edgeIndex, edgeTypes ?? [], relation),
      };
      const { nodeEmbeddings } = await super.forward(relationGraph);
      perRelation.push(nodeEmbeddings);
    }

    const outputFeatures = this.combineRelationOutputs(perRelation);

    return {
      nodeEmbeddings: outputFeatures,
      graphEmbedding: this.aggregateMean(outputFeatures),
      stats: this.createStats(startTime, nodeFeatures.length, edgeIndex[0].length),
    };
  }

  /**
   * Keeps the edges whose type equals `targetType`. When no edge types are
   * supplied at all, every edge is kept.
   */
  private filterEdgesByType(
    edgeIndex: [number[], number[]],
    edgeTypes: number[],
    targetType: number
  ): [number[], number[]] {
    const [allSources, allTargets] = edgeIndex;
    const untyped = edgeTypes.length === 0;
    const keptSources: number[] = [];
    const keptTargets: number[] = [];

    allSources.forEach((src, e) => {
      if (untyped || edgeTypes[e] === targetType) {
        keptSources.push(src);
        keptTargets.push(allTargets[e]);
      }
    });

    return [keptSources, keptTargets];
  }

  /**
   * Averages the per-relation embeddings node by node. A single relation is
   * returned as-is; no relations yields an empty result.
   */
  private combineRelationOutputs(outputs: number[][][]): number[][] {
    switch (outputs.length) {
      case 0:
        return [];
      case 1:
        return outputs[0];
    }

    const combined: number[][] = [];
    const nodeCount = outputs[0].length;
    for (let node = 0; node < nodeCount; node++) {
      combined.push(this.aggregateMean(outputs.map((relation) => relation[node] ?? [])));
    }
    return combined;
  }
}
|
|
2041
|
+
|
|
2042
|
+
/**
 * Heterogeneous Graph Transformer (HGT) layer.
 *
 * Rescales every node's features with a node-type-dependent pattern and then
 * delegates to the inherited GAT forward pass.
 */
export class HGTLayer extends GATLayer {
  override async forward(graph: GraphData): Promise<GNNOutput> {
    const { nodeFeatures, nodeTypes } = graph;

    // Nodes without a type entry are treated as type 0.
    const typedFeatures = nodeFeatures.map((features, idx) =>
      this.typeSpecificTransform(features, nodeTypes?.[idx] ?? 0)
    );

    return super.forward({ ...graph, nodeFeatures: typedFeatures });
  }

  /**
   * Scales component i by `1 + 0.1 * sin((nodeType * inputDim + i) * 0.1)`,
   * a fixed (non-learned) per-type pattern.
   */
  private typeSpecificTransform(features: number[], nodeType: number): number[] {
    const base = nodeType * this.config.inputDim;
    return features.map((value, i) => value * (1 + 0.1 * Math.sin((base + i) * 0.1)));
  }
}
|
|
2072
|
+
|
|
2073
|
+
/**
 * Heterogeneous Attention Network (HAN) layer.
 *
 * Runs the inherited GAT forward pass once per configured metapath
 * (`config.params.metapaths`) and blends the per-metapath node embeddings
 * with softmax weights. Without metapaths it behaves exactly like the
 * inherited GAT layer.
 */
export class HANLayer extends GATLayer {
  override async forward(graph: GraphData): Promise<GNNOutput> {
    const startTime = Date.now();
    const metapaths = this.config.params?.metapaths ?? [];

    if (metapaths.length === 0) {
      return super.forward(graph);
    }

    // Process each metapath
    const metapathOutputs: number[][][] = [];

    for (const metapath of metapaths) {
      const metapathGraph = this.extractMetapathSubgraph(graph, metapath);
      const result = await super.forward(metapathGraph);
      metapathOutputs.push(result.nodeEmbeddings);
    }

    // Attention over metapaths
    const outputFeatures = this.attentionOverMetapaths(metapathOutputs);

    return {
      nodeEmbeddings: outputFeatures,
      graphEmbedding: this.aggregateMean(outputFeatures),
      // Use the shared stats helper like the other layers so the elapsed time
      // is actually measured instead of being reported as a constant 0.
      stats: this.createStats(
        startTime,
        graph.nodeFeatures.length,
        graph.edgeIndex[0].length,
        metapaths.length
      ),
    };
  }

  /**
   * Placeholder: returns the original graph unchanged. The `_metapath`
   * argument is not used yet; a full implementation would filter edges by it.
   */
  private extractMetapathSubgraph(graph: GraphData, _metapath: string[]): GraphData {
    return graph;
  }

  /**
   * Blends per-metapath embeddings into one embedding per node.
   *
   * Each metapath's importance is the sum of absolute values of all its
   * embeddings; weights are a softmax over importance / 10 (shifted by the
   * maximum for numerical stability). A single metapath is returned as-is.
   * The output dimension of each node follows the first metapath's embedding.
   */
  private attentionOverMetapaths(outputs: number[][][]): number[][] {
    if (outputs.length === 0) return [];
    if (outputs.length === 1) return outputs[0];

    const numNodes = outputs[0].length;
    const result: number[][] = [];

    // Compute attention weights for metapaths
    const metapathWeights = outputs.map((o) =>
      o.reduce((sum, node) => sum + node.reduce((s, v) => s + Math.abs(v), 0), 0)
    );

    const maxWeight = Math.max(...metapathWeights);
    const expWeights = metapathWeights.map((w) => Math.exp((w - maxWeight) / 10));
    const sumExp = expWeights.reduce((a, b) => a + b, 0);
    const normalizedWeights = expWeights.map((w) => w / sumExp);

    for (let i = 0; i < numNodes; i++) {
      const dim = outputs[0][i]?.length ?? 0;
      const combined = new Array(dim).fill(0);

      for (let m = 0; m < outputs.length; m++) {
        const nodeFeatures = outputs[m][i] ?? [];
        for (let j = 0; j < dim; j++) {
          combined[j] += normalizedWeights[m] * (nodeFeatures[j] ?? 0);
        }
      }

      result.push(combined);
    }

    return result;
  }
}
|
|
2154
|
+
|
|
2155
|
+
/**
 * MetaPath-based aggregation layer.
 *
 * Currently a plain alias of {@link HANLayer}: it adds no members and
 * inherits every behaviour unchanged.
 */
export class MetaPathLayer extends HANLayer {}
|
|
2161
|
+
|
|
2162
|
+
// ============================================================================
|
|
2163
|
+
// Graph Operations
|
|
2164
|
+
// ============================================================================
|
|
2165
|
+
|
|
2166
|
+
/**
|
|
2167
|
+
* Graph operations for advanced graph analytics.
|
|
2168
|
+
*
|
|
2169
|
+
* @example
|
|
2170
|
+
* ```typescript
|
|
2171
|
+
* const ops = new GraphOperations();
|
|
2172
|
+
* const neighbors = await ops.kHopNeighbors('node1', 2);
|
|
2173
|
+
* const path = await ops.shortestPath('source', 'target');
|
|
2174
|
+
* const ranks = await ops.pageRank({ damping: 0.85 });
|
|
2175
|
+
* const communities = await ops.communityDetection({ algorithm: 'louvain' });
|
|
2176
|
+
* ```
|
|
2177
|
+
*/
|
|
2178
|
+
export class GraphOperations {
|
|
2179
|
+
private adjacencyList: Map<NodeId, Set<NodeId>> = new Map();
|
|
2180
|
+
private weights: Map<string, number> = new Map();
|
|
2181
|
+
private nodeFeatures: Map<NodeId, number[]> = new Map();
|
|
2182
|
+
|
|
2183
|
+
/**
 * Replaces any previously loaded graph with `graph`.
 *
 * Node ids are the row indices of `graph.nodeFeatures`. Every edge is stored
 * in both directions, with weight `edgeWeights[i]` or 1 when absent.
 */
loadGraph(graph: GraphData): void {
  this.adjacencyList.clear();
  this.weights.clear();
  this.nodeFeatures.clear();

  const { nodeFeatures, edgeIndex, edgeWeights } = graph;
  const [sources, targets] = edgeIndex;

  // Register every node with an empty neighbour set.
  nodeFeatures.forEach((features, id) => {
    this.adjacencyList.set(id, new Set());
    this.nodeFeatures.set(id, features);
  });

  // Record each edge symmetrically. Endpoints without a registered node get
  // no adjacency entry, but the weight keys are still written.
  sources.forEach((src, e) => {
    const tgt = targets[e];
    const weight = edgeWeights?.[e] ?? 1;

    this.adjacencyList.get(src)?.add(tgt);
    this.adjacencyList.get(tgt)?.add(src);
    this.weights.set(`${src}-${tgt}`, weight);
    this.weights.set(`${tgt}-${src}`, weight);
  });
}
|
|
2212
|
+
|
|
2213
|
+
/**
 * Breadth-first search for every node within `k` hops of `nodeId`.
 *
 * @returns Reached nodes in BFS order, excluding `nodeId` itself.
 */
async kHopNeighbors(nodeId: NodeId, k: number): Promise<NodeId[]> {
  const seen = new Set<NodeId>();
  const reached: NodeId[] = [];
  const frontier: { node: NodeId; depth: number }[] = [{ node: nodeId, depth: 0 }];

  for (let entry = frontier.shift(); entry !== undefined; entry = frontier.shift()) {
    const { node, depth } = entry;

    // A node can be queued more than once; only its first visit counts.
    if (seen.has(node)) continue;
    seen.add(node);

    if (depth > 0) {
      reached.push(node);
    }

    // Stop expanding once the hop limit is hit.
    if (depth >= k) continue;

    const adjacent = this.adjacencyList.get(node);
    if (adjacent === undefined) continue;

    for (const next of adjacent) {
      if (!seen.has(next)) {
        frontier.push({ node: next, depth: depth + 1 });
      }
    }
  }

  return reached;
}
|
|
2243
|
+
|
|
2244
|
+
/**
 * Finds the lowest-weight path between two nodes with Dijkstra's algorithm
 * (linear scan for the closest unvisited node; missing edge weights count as 1).
 *
 * @returns The node sequence from `source` to `target` and its total weight,
 *   or `{ nodes: [], weight: Infinity }` when no path is found.
 */
async shortestPath(source: NodeId, target: NodeId): Promise<Path> {
  const dist = new Map<NodeId, number>();
  const cameFrom = new Map<NodeId, NodeId | null>();
  const pending = new Set<NodeId>(this.adjacencyList.keys());

  for (const node of pending) {
    dist.set(node, Infinity);
    cameFrom.set(node, null);
  }
  dist.set(source, 0);

  while (pending.size > 0) {
    // Pick the pending node with the smallest tentative distance
    // (first one wins on ties).
    let closest: NodeId | null = null;
    let best = Infinity;
    for (const candidate of pending) {
      const d = dist.get(candidate) ?? Infinity;
      if (d < best) {
        best = d;
        closest = candidate;
      }
    }

    // Nothing reachable remains, or the target's distance is now final.
    if (closest === null || closest === target) break;

    pending.delete(closest);

    // Relax every edge into a still-pending neighbour.
    const base = dist.get(closest) ?? Infinity;
    for (const neighbor of this.adjacencyList.get(closest) ?? []) {
      if (!pending.has(neighbor)) continue;

      const throughClosest = base + (this.weights.get(`${closest}-${neighbor}`) ?? 1);
      if (throughClosest < (dist.get(neighbor) ?? Infinity)) {
        dist.set(neighbor, throughClosest);
        cameFrom.set(neighbor, closest);
      }
    }
  }

  // Walk predecessor links back from the target, then flip the order.
  const backwards: NodeId[] = [];
  for (
    let step: NodeId | null = target;
    step !== null;
    step = cameFrom.get(step) ?? null
  ) {
    backwards.push(step);
  }
  const nodes = backwards.reverse();

  // The walk only ends at `source` when the target was actually reached.
  if (nodes[0] !== source) {
    return { nodes: [], weight: Infinity };
  }

  return { nodes, weight: dist.get(target) ?? Infinity };
}
|
|
2308
|
+
|
|
2309
|
+
/**
 * Computes PageRank scores for all loaded nodes by power iteration.
 *
 * Options (all optional):
 * - `damping` (default 0.85)
 * - `maxIterations` (default 100)
 * - `tolerance` (default 1e-6): stop once the summed absolute change of all
 *   ranks in one iteration drops below this value.
 * - `personalization`: per-node initial rank and teleport value; nodes without
 *   an entry use 1 / n.
 * - `weighted`: distribute each node's rank to its neighbours in proportion
 *   to edge weight instead of evenly.
 *
 * @returns Map from node id to rank; empty when no graph is loaded.
 */
async pageRank(options: PageRankOptions = {}): Promise<Map<NodeId, number>> {
  const damping = options.damping ?? 0.85;
  const maxIterations = options.maxIterations ?? 100;
  const tolerance = options.tolerance ?? 1e-6;

  const nodes = Array.from(this.adjacencyList.keys());
  const n = nodes.length;

  if (n === 0) return new Map();

  const uniform = 1 / n;
  const teleportOf = (node: NodeId): number =>
    options.personalization?.get(node) ?? uniform;

  // Weighted mode: a node must split its rank by weight / total outgoing
  // weight so that the shares it hands out sum to 1. Dividing by the
  // neighbour count (as in the unweighted case) would create or destroy rank
  // mass whenever weights differ from 1.
  const outStrength = new Map<NodeId, number>();
  if (options.weighted) {
    for (const [node, neighbors] of this.adjacencyList) {
      let total = 0;
      for (const neighbor of neighbors) {
        total += this.weights.get(`${node}-${neighbor}`) ?? 1;
      }
      outStrength.set(node, total);
    }
  }

  // Initialize ranks
  let ranks = new Map<NodeId, number>();
  for (const node of nodes) {
    ranks.set(node, teleportOf(node));
  }

  // Power iteration
  for (let iter = 0; iter < maxIterations; iter++) {
    const newRanks = new Map<NodeId, number>();
    let diff = 0;

    for (const node of nodes) {
      let sum = 0;
      const neighbors = this.adjacencyList.get(node) ?? new Set();

      for (const neighbor of neighbors) {
        const neighborRank = ranks.get(neighbor) ?? 0;

        if (options.weighted) {
          const weight = this.weights.get(`${neighbor}-${node}`) ?? 1;
          // A neighbour without its own adjacency entry is treated as having
          // only this edge. Zero total weight contributes nothing (avoids 0/0).
          const strength = outStrength.get(neighbor) ?? weight;
          if (strength !== 0) {
            sum += (neighborRank * weight) / strength;
          }
        } else {
          const neighborOutDegree = this.adjacencyList.get(neighbor)?.size ?? 1;
          sum += neighborRank / neighborOutDegree;
        }
      }

      const newRank = (1 - damping) * teleportOf(node) + damping * sum;
      newRanks.set(node, newRank);

      diff += Math.abs(newRank - (ranks.get(node) ?? 0));
    }

    ranks = newRanks;

    if (diff < tolerance) break;
  }

  return ranks;
}
|
|
2365
|
+
|
|
2366
|
+
/**
 * Detect communities in the graph.
 *
 * Dispatches on `options.algorithm` to the matching private implementation.
 * `'louvain'` and any unrecognised value both run the Louvain implementation.
 *
 * @param options - Community detection options; `algorithm` selects the strategy.
 * @returns The communities produced by the selected implementation.
 */
async communityDetection(options: CommunityOptions): Promise<Community[]> {
  const requested = options.algorithm;

  if (requested === 'label_propagation') {
    return this.labelPropagationCommunityDetection(options);
  }
  if (requested === 'girvan_newman') {
    return this.girvanNewmanCommunityDetection(options);
  }
  if (requested === 'spectral') {
    return this.spectralCommunityDetection(options);
  }

  // Explicit 'louvain' and the fallback case share the same implementation.
  return this.louvainCommunityDetection(options);
}
|
|
2383
|
+
|
|
2384
|
+
/**
 * Louvain-style community detection (local-moving phase only; no aggregation phase).
 *
 * Each node starts in its own community. In every pass, each node is moved to the
 * neighbouring community with the largest positive `modularityGain`; passes repeat
 * until nothing moves or `maxIterations` is reached.
 *
 * @param options - `resolution` (default 1.0), `maxIterations` (default 100) and
 *   `minSize` (default 1; smaller communities are dropped from the result).
 * @returns One `Community` per surviving group, with centroid, modularity and density.
 */
private async louvainCommunityDetection(options: CommunityOptions): Promise<Community[]> {
  const nodes = Array.from(this.adjacencyList.keys());
  const resolution = options.resolution ?? 1.0;
  const maxIterations = options.maxIterations ?? 100;

  // Initialize: each node is its own community
  const community = new Map<NodeId, number>();
  let nextCommunityId = 0;

  for (const node of nodes) {
    community.set(node, nextCommunityId++);
  }

  // Compute total edge weight
  let totalWeight = 0;
  for (const weight of this.weights.values()) {
    totalWeight += weight;
  }
  totalWeight /= 2; // Undirected edges counted twice

  if (totalWeight === 0) {
    // No usable stored weights. The gain/modularity helpers treat a missing weight
    // as 1, so fall back to the adjacency edge count (halved on the same
    // counted-twice assumption) instead of dividing by zero later.
    let adjacencyEntries = 0;
    for (const neighbors of this.adjacencyList.values()) {
      adjacencyEntries += neighbors.size;
    }
    totalWeight = adjacencyEntries / 2;
  }

  // Phase 1: Local moving
  for (let iter = 0; iter < maxIterations; iter++) {
    let improved = false;

    for (const node of nodes) {
      const currentCommunity = community.get(node)!;
      const neighbors = this.adjacencyList.get(node) ?? new Set();

      // Find neighbor communities
      const neighborCommunities = new Set<number>();
      for (const neighbor of neighbors) {
        neighborCommunities.add(community.get(neighbor)!);
      }

      // Find best community; only strictly positive gains trigger a move.
      let bestCommunity = currentCommunity;
      let bestModularityGain = 0;

      for (const targetCommunity of neighborCommunities) {
        if (targetCommunity === currentCommunity) continue;

        const gain = this.modularityGain(
          node,
          currentCommunity,
          targetCommunity,
          community,
          resolution,
          totalWeight
        );

        if (gain > bestModularityGain) {
          bestModularityGain = gain;
          bestCommunity = targetCommunity;
        }
      }

      if (bestCommunity !== currentCommunity) {
        community.set(node, bestCommunity);
        improved = true;
      }
    }

    if (!improved) break;
  }

  // Build communities
  const communityMembers = new Map<number, NodeId[]>();
  for (const [node, commId] of community.entries()) {
    if (!communityMembers.has(commId)) {
      communityMembers.set(commId, []);
    }
    communityMembers.get(commId)!.push(node);
  }

  // Filter by minimum size
  const minSize = options.minSize ?? 1;
  const communities: Community[] = [];

  for (const [id, members] of communityMembers.entries()) {
    if (members.length >= minSize) {
      communities.push({
        id,
        members,
        centroid: this.computeCentroid(members),
        modularity: this.computeModularity(members, community, totalWeight),
        density: this.computeDensity(members),
      });
    }
  }

  return communities;
}
|
|
2476
|
+
|
|
2477
|
+
/**
 * Score the benefit of moving `node` from `fromCommunity` to `toCommunity`.
 *
 * Sums the weights (missing weights count as 1) of the node's edges into each of the
 * two communities and combines their difference with the node's degree, the
 * resolution and the total edge weight.
 * NOTE(review): this is not the textbook Louvain delta-Q (no community degree totals
 * are involved) — confirm that the simplified score is intentional.
 *
 * @returns The gain score; the caller only moves a node for values greater than 0.
 */
private modularityGain(
  node: NodeId,
  fromCommunity: number,
  toCommunity: number,
  community: Map<NodeId, number>,
  resolution: number,
  totalWeight: number
): number {
  const adjacent = this.adjacencyList.get(node) ?? new Set();
  let towardTarget = 0;
  let towardCurrent = 0;

  for (const other of adjacent) {
    const otherCommunity = community.get(other)!;
    const edgeWeight = this.weights.get(`${node}-${other}`) ?? 1;

    // Two independent checks: both fire when the two community ids are equal.
    if (otherCommunity === toCommunity) towardTarget += edgeWeight;
    if (otherCommunity === fromCommunity) towardCurrent += edgeWeight;
  }

  const degree = adjacent.size;
  const linkDelta = towardTarget - towardCurrent;

  const attraction = linkDelta / totalWeight;
  const penalty = (resolution * degree * linkDelta) / (2 * totalWeight * totalWeight);

  return attraction - penalty;
}
|
|
2509
|
+
|
|
2510
|
+
/**
 * Average the feature vectors of `members` component-wise.
 *
 * Members without stored features contribute zeros. The centroid's dimension is the
 * longest feature vector among the members, so the result no longer depends on which
 * member happens to be listed first (previously only the first member decided both
 * the dimension and whether a centroid existed at all).
 *
 * @param members - Node ids belonging to one community.
 * @returns The centroid, or `undefined` when there are no members or none of them has
 *   any features.
 */
private computeCentroid(members: NodeId[]): number[] | undefined {
  const count = members.length;
  if (count === 0) return undefined;

  const features = members.map((m) => this.nodeFeatures.get(m) ?? []);

  let dim = 0;
  for (const f of features) {
    if (f.length > dim) dim = f.length;
  }
  if (dim === 0) return undefined;

  return Array.from({ length: dim }, (_unused, i) => {
    let component = 0;
    for (const f of features) {
      // Shorter or missing vectors are treated as zero-padded.
      component += (f[i] ?? 0) / count;
    }
    return component;
  });
}
|
|
2527
|
+
|
|
2528
|
+
/**
 * Compute one community's modularity contribution.
 *
 * Internal edge weight (missing weights count as 1, halved because each edge is seen
 * from both endpoints) is compared with the value expected from the members' summed
 * adjacency sizes, and the difference is normalised by `totalWeight`.
 *
 * @param members - Node ids in the community.
 * @param _community - Unused; kept for potential future use in computing
 *   inter-community edges.
 * @param totalWeight - Total edge weight of the graph.
 * @returns The modularity contribution, or 0 when `totalWeight` is 0 (the formula
 *   would otherwise divide by zero and yield NaN).
 */
private computeModularity(
  members: NodeId[],
  _community: Map<NodeId, number>,
  totalWeight: number
): number {
  if (totalWeight === 0) return 0;

  // Set lookup replaces the per-neighbour linear scan of `members`.
  const memberSet = new Set(members);
  let internalEdges = 0;
  let totalDegree = 0;

  for (const node of members) {
    const neighbors = this.adjacencyList.get(node) ?? new Set();
    totalDegree += neighbors.size;

    for (const neighbor of neighbors) {
      if (memberSet.has(neighbor)) {
        internalEdges += this.weights.get(`${node}-${neighbor}`) ?? 1;
      }
    }
  }

  internalEdges /= 2; // Counted twice
  const expected = (totalDegree * totalDegree) / (4 * totalWeight);

  return (internalEdges - expected) / totalWeight;
}
|
|
2553
|
+
|
|
2554
|
+
/**
 * Compute the edge density of the subgraph induced by `members`.
 *
 * Counts adjacency entries whose both endpoints are members, halves the count (each
 * edge is seen from both endpoints) and divides by the number of possible pairs.
 *
 * @param members - Node ids in the community.
 * @returns The density; 1 for communities with at most one member.
 */
private computeDensity(members: NodeId[]): number {
  if (members.length <= 1) return 1;

  // Set lookup replaces the per-neighbour linear scan of `members`.
  const memberSet = new Set(members);

  let edges = 0;
  for (const node of members) {
    const neighbors = this.adjacencyList.get(node) ?? new Set();
    for (const neighbor of neighbors) {
      if (memberSet.has(neighbor)) {
        edges++;
      }
    }
  }

  edges /= 2;
  const maxEdges = (members.length * (members.length - 1)) / 2;

  return edges / maxEdges;
}
|
|
2572
|
+
|
|
2573
|
+
/**
 * Community detection by label propagation.
 *
 * Every node starts with a unique label. In each pass the nodes are visited in a fresh
 * random order and each adopts the most frequent label among its neighbours (the first
 * label to reach the highest count wins ties). Passes repeat until no label changes or
 * `maxIterations` is reached. Nodes without neighbours keep their initial label.
 *
 * @param options - `maxIterations` (default 100) and `minSize` (default 1; smaller
 *   communities are dropped, matching the Louvain implementation).
 * @returns One `Community` per surviving label, with centroid and density.
 */
private async labelPropagationCommunityDetection(options: CommunityOptions): Promise<Community[]> {
  const nodes = Array.from(this.adjacencyList.keys());
  const maxIterations = options.maxIterations ?? 100;

  // Initialize labels
  const labels = new Map<NodeId, number>();
  let nextLabel = 0;
  for (const node of nodes) {
    labels.set(node, nextLabel++);
  }

  // Iterate
  for (let iter = 0; iter < maxIterations; iter++) {
    let changed = false;

    // Fisher-Yates shuffle: unbiased, unlike sorting with a random comparator,
    // which also violates the comparator-consistency contract of Array.prototype.sort.
    const shuffled = [...nodes];
    for (let i = shuffled.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      const held = shuffled[i]!;
      shuffled[i] = shuffled[j]!;
      shuffled[j] = held;
    }

    for (const node of shuffled) {
      const neighbors = this.adjacencyList.get(node) ?? new Set();
      if (neighbors.size === 0) continue;

      // Count neighbor labels
      const labelCounts = new Map<number, number>();
      for (const neighbor of neighbors) {
        const label = labels.get(neighbor)!;
        labelCounts.set(label, (labelCounts.get(label) ?? 0) + 1);
      }

      // Find most common label
      let maxCount = 0;
      let bestLabel = labels.get(node)!;
      for (const [label, count] of labelCounts.entries()) {
        if (count > maxCount) {
          maxCount = count;
          bestLabel = label;
        }
      }

      if (bestLabel !== labels.get(node)) {
        labels.set(node, bestLabel);
        changed = true;
      }
    }

    if (!changed) break;
  }

  // Build communities
  const communityMembers = new Map<number, NodeId[]>();
  for (const [node, label] of labels.entries()) {
    if (!communityMembers.has(label)) {
      communityMembers.set(label, []);
    }
    communityMembers.get(label)!.push(node);
  }

  // Honour minSize; the default of 1 keeps every community, as before.
  const minSize = options.minSize ?? 1;

  return Array.from(communityMembers.entries())
    .filter(([, members]) => members.length >= minSize)
    .map(([id, members]) => ({
      id,
      members,
      centroid: this.computeCentroid(members),
      density: this.computeDensity(members),
    }));
}
|
|
2637
|
+
|
|
2638
|
+
/**
 * Placeholder for Girvan-Newman community detection.
 *
 * No edge-betweenness computation is performed here; the call is forwarded unchanged
 * to the label propagation implementation.
 */
private async girvanNewmanCommunityDetection(options: CommunityOptions): Promise<Community[]> {
  // A full implementation would iteratively remove high-betweenness edges.
  const communities = await this.labelPropagationCommunityDetection(options);
  return communities;
}
|
|
2643
|
+
|
|
2644
|
+
/**
 * Placeholder for spectral community detection.
 *
 * No Laplacian eigendecomposition is performed here; the call is forwarded unchanged
 * to the label propagation implementation.
 */
private async spectralCommunityDetection(options: CommunityOptions): Promise<Community[]> {
  // A full implementation would use the eigendecomposition of the graph Laplacian.
  const communities = await this.labelPropagationCommunityDetection(options);
  return communities;
}
|
|
2649
|
+
|
|
2650
|
+
/**
 * Generate SQL for a k-hop neighbors query.
 *
 * Edges are followed in both directions. The statement uses a single recursive term:
 * PostgreSQL requires `WITH RECURSIVE` bodies of the form
 * `non-recursive-term UNION [ALL] recursive-term`, so the previous shape with two
 * recursive SELECTs chained by UNION ALL was rejected by the server. `UNION` (rather
 * than `UNION ALL`) drops repeated (node, depth) rows while expanding.
 *
 * @param nodeId - Start node; embedded as an escaped string literal.
 * @param k - Maximum depth; the direct neighbours (depth 1) are always included.
 * @param tableName - Node table name; the edge table defaults to `${tableName}_edges`.
 * @param options - Optional `schema` (default `public`) and `edgeTable` overrides.
 * @returns The SQL text selecting distinct neighbour ids, excluding the start node.
 * @throws RangeError when `k` is not a finite number.
 */
kHopNeighborsSQL(nodeId: string, k: number, tableName: string, options: SQLGenerationOptions = {}): string {
  if (!Number.isFinite(k)) {
    throw new RangeError(`k must be a finite number, got ${k}`);
  }

  // Double embedded quotes so names/values cannot terminate their quoted context.
  const ident = (name: string): string => `"${String(name).replace(/"/g, '""')}"`;
  const literal = (text: string): string => `'${String(text).replace(/'/g, "''")}'`;

  const schema = options.schema ?? 'public';
  const edgeTable = options.edgeTable ?? `${tableName}_edges`;
  const edges = `${ident(schema)}.${ident(edgeTable)}`;
  const start = literal(nodeId);

  return [
    'WITH RECURSIVE k_hop AS (',
    `SELECT CASE WHEN e.source_id = ${start} THEN e.target_id ELSE e.source_id END AS node_id, 1 AS depth`,
    `FROM ${edges} e`,
    `WHERE e.source_id = ${start} OR e.target_id = ${start}`,
    'UNION',
    'SELECT CASE WHEN e.source_id = kh.node_id THEN e.target_id ELSE e.source_id END AS node_id, kh.depth + 1',
    'FROM k_hop kh',
    `JOIN ${edges} e ON kh.node_id = e.source_id OR kh.node_id = e.target_id`,
    `WHERE kh.depth < ${k}`,
    ')',
    `SELECT DISTINCT node_id FROM k_hop WHERE node_id != ${start};`,
  ].join('\n');
}
|
|
2679
|
+
|
|
2680
|
+
/**
 * Generate SQL for a shortest path query.
 *
 * The recursive CTE extends paths along `source_id -> target_id` edges, never
 * revisiting a node already on the path, and returns the lowest-weight path that ends
 * at `target`.
 *
 * @param source - Start node; embedded as an escaped string literal.
 * @param target - End node; embedded as an escaped string literal.
 * @param tableName - Node table name; the edge table defaults to `${tableName}_edges`.
 * @param options - Optional `schema` (default `public`) and `edgeTable` overrides.
 * @param maxDepth - Maximum number of edges explored per path (default 10, the
 *   previously hard-coded limit).
 * @returns The SQL text selecting `path` and `total_weight` for the best path.
 * @throws RangeError when `maxDepth` is not a positive integer.
 */
shortestPathSQL(
  source: string,
  target: string,
  tableName: string,
  options: SQLGenerationOptions = {},
  maxDepth: number = 10
): string {
  if (!Number.isInteger(maxDepth) || maxDepth < 1) {
    throw new RangeError(`maxDepth must be a positive integer, got ${maxDepth}`);
  }

  // Double embedded quotes so names/values cannot terminate their quoted context.
  const ident = (name: string): string => `"${String(name).replace(/"/g, '""')}"`;
  const literal = (text: string): string => `'${String(text).replace(/'/g, "''")}'`;

  const schema = options.schema ?? 'public';
  const edgeTable = options.edgeTable ?? `${tableName}_edges`;
  const edges = `${ident(schema)}.${ident(edgeTable)}`;

  return [
    'WITH RECURSIVE path AS (',
    'SELECT',
    'source_id,',
    'target_id,',
    'ARRAY[source_id, target_id] AS path,',
    'weight AS total_weight,',
    '1 AS depth',
    `FROM ${edges}`,
    `WHERE source_id = ${literal(source)}`,
    'UNION ALL',
    'SELECT',
    'p.source_id,',
    'e.target_id,',
    'p.path || e.target_id,',
    'p.total_weight + e.weight,',
    'p.depth + 1',
    'FROM path p',
    `JOIN ${edges} e ON p.target_id = e.source_id`,
    'WHERE NOT e.target_id = ANY(p.path)',
    `AND p.depth < ${maxDepth}`,
    ')',
    'SELECT path, total_weight',
    'FROM path',
    `WHERE target_id = ${literal(target)}`,
    'ORDER BY total_weight',
    'LIMIT 1;',
  ].join('\n');
}
|
|
2715
|
+
|
|
2716
|
+
/**
 * Generate SQL for PageRank computation.
 *
 * Emits a call to `ruvector.page_rank` over all (source_id, target_id) pairs of the
 * edge table, passing the damping factor and iteration cap as plain numbers.
 *
 * @param tableName - Node table name; the edge table defaults to `${tableName}_edges`.
 * @param options - `schema` (default `public`), `edgeTable`, `damping` (default 0.85)
 *   and `maxIterations` (default 100).
 * @returns The SQL text.
 */
pageRankSQL(tableName: string, options: PageRankOptions & SQLGenerationOptions = {}): string {
  // Double embedded quotes so a name cannot terminate its quoted identifier.
  const ident = (name: string): string => `"${String(name).replace(/"/g, '""')}"`;

  const schema = options.schema ?? 'public';
  const edgeTable = options.edgeTable ?? `${tableName}_edges`;
  const damping = options.damping ?? 0.85;
  const maxIterations = options.maxIterations ?? 100;

  return [
    'SELECT ruvector.page_rank(',
    `(SELECT array_agg(ARRAY[source_id::text, target_id::text]) FROM ${ident(schema)}.${ident(edgeTable)}),`,
    `${damping},`,
    `${maxIterations}`,
    ');',
  ].join('\n');
}
|
|
2732
|
+
|
|
2733
|
+
/**
 * Generate SQL for community detection.
 *
 * Emits a call to `ruvector.community_detection` over all (source_id, target_id) pairs
 * of the edge table, passing the algorithm name as a string literal and the resolution
 * as a plain number.
 *
 * @param tableName - Node table name; the edge table defaults to `${tableName}_edges`.
 * @param options - `schema` (default `public`), `edgeTable`, `algorithm` (default
 *   `louvain`) and `resolution` (default 1.0).
 * @returns The SQL text.
 */
communityDetectionSQL(tableName: string, options: CommunityOptions & SQLGenerationOptions): string {
  // Double embedded quotes so names/values cannot terminate their quoted context.
  const ident = (name: string): string => `"${String(name).replace(/"/g, '""')}"`;
  const literal = (text: string): string => `'${String(text).replace(/'/g, "''")}'`;

  const schema = options.schema ?? 'public';
  const edgeTable = options.edgeTable ?? `${tableName}_edges`;
  const algorithm = options.algorithm ?? 'louvain';
  const resolution = options.resolution ?? 1.0;

  return [
    'SELECT ruvector.community_detection(',
    `(SELECT array_agg(ARRAY[source_id::text, target_id::text]) FROM ${ident(schema)}.${ident(edgeTable)}),`,
    `${literal(algorithm)},`,
    `${resolution}`,
    ');',
  ].join('\n');
}
|
|
2749
|
+
}
|
|
2750
|
+
|
|
2751
|
+
// ============================================================================
|
|
2752
|
+
// SQL Generator for GNN Operations
|
|
2753
|
+
// ============================================================================
|
|
2754
|
+
|
|
2755
|
+
/**
 * SQL generator for GNN operations in PostgreSQL with RuVector.
 *
 * All methods are static and only build SQL text; nothing is executed here.
 * Schema and table names are emitted as double-quoted identifiers with embedded
 * quotes doubled, and JSON payloads as single-quoted literals with embedded quotes
 * doubled. `options.nodeColumn` is inserted verbatim (unquoted), as before.
 */
export class GNNSQLGenerator {
  /** Render a double-quoted SQL identifier, doubling any embedded double quotes. */
  private static quoteIdent(name: string): string {
    return `"${String(name).replace(/"/g, '""')}"`;
  }

  /** Render a single-quoted SQL string literal, doubling any embedded single quotes. */
  private static quoteLiteral(text: string): string {
    return `'${String(text).replace(/'/g, "''")}'`;
  }

  /** Render `"schema"."table"` with both parts escaped. */
  private static qualified(schema: string, table: string): string {
    return `${GNNSQLGenerator.quoteIdent(schema)}.${GNNSQLGenerator.quoteIdent(table)}`;
  }

  /**
   * Generate SQL for GNN layer forward pass.
   *
   * Delegates entirely to the layer's own `toSQL`.
   */
  static layerForwardSQL(
    layer: IGNNLayer,
    tableName: string,
    options: SQLGenerationOptions = {}
  ): string {
    return layer.toSQL(tableName, options);
  }

  /**
   * Generate SQL for batch GNN operations.
   *
   * Serialises each layer's type and configuration to JSON and passes it, together
   * with all node embeddings and edge pairs, to `ruvector.batch_gnn_forward`.
   */
  static batchGNNSQL(
    layers: IGNNLayer[],
    tableName: string,
    options: SQLGenerationOptions = {}
  ): string {
    const schema = options.schema ?? 'public';
    const nodeColumn = options.nodeColumn ?? 'embedding';
    const edgeTable = options.edgeTable ?? `${tableName}_edges`;

    const layerConfigs = layers.map((l) => ({
      type: l.type,
      input_dim: l.config.inputDim,
      output_dim: l.config.outputDim,
      num_heads: l.config.numHeads,
      dropout: l.config.dropout,
      aggregation: l.config.aggregation,
      params: l.config.params,
    }));

    // The JSON may contain apostrophes (e.g. inside params); escape them so the
    // literal is not terminated early.
    const configLiteral = GNNSQLGenerator.quoteLiteral(JSON.stringify(layerConfigs));

    return [
      'SELECT ruvector.batch_gnn_forward(',
      `(SELECT array_agg(${nodeColumn}) FROM ${GNNSQLGenerator.qualified(schema, tableName)}),`,
      `(SELECT array_agg(ARRAY[source_id, target_id]) FROM ${GNNSQLGenerator.qualified(schema, edgeTable)}),`,
      `${configLiteral}::jsonb`,
      ');',
    ].join('\n');
  }

  /**
   * Generate SQL for caching computed embeddings.
   *
   * Upserts every row of the node table into the cache table keyed by `node_id`,
   * refreshing `embedding` and `computed_at` on conflict.
   */
  static cacheEmbeddingsSQL(
    tableName: string,
    cacheTable: string,
    options: SQLGenerationOptions = {}
  ): string {
    const schema = options.schema ?? 'public';

    return [
      `INSERT INTO ${GNNSQLGenerator.qualified(schema, cacheTable)} (node_id, embedding, computed_at)`,
      'SELECT',
      'id,',
      `${options.nodeColumn ?? 'embedding'},`,
      'NOW()',
      `FROM ${GNNSQLGenerator.qualified(schema, tableName)}`,
      'ON CONFLICT (node_id)',
      'DO UPDATE SET',
      'embedding = EXCLUDED.embedding,',
      'computed_at = NOW();',
    ].join('\n');
  }

  /**
   * Generate SQL for creating GNN cache table.
   *
   * Emits a `CREATE TABLE IF NOT EXISTS` for the cache table plus an index on
   * `computed_at`. `dimension` is inserted as-is into `vector(...)`.
   */
  static createCacheTableSQL(
    cacheTable: string,
    dimension: number,
    options: SQLGenerationOptions = {}
  ): string {
    const schema = options.schema ?? 'public';
    const table = GNNSQLGenerator.qualified(schema, cacheTable);

    return [
      `CREATE TABLE IF NOT EXISTS ${table} (`,
      'node_id TEXT PRIMARY KEY,',
      `embedding vector(${dimension}) NOT NULL,`,
      'computed_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),',
      'layer_config JSONB,',
      'version INTEGER DEFAULT 1',
      ');',
      '',
      `CREATE INDEX IF NOT EXISTS ${GNNSQLGenerator.quoteIdent(`${cacheTable}_computed_at_idx`)}`,
      `ON ${table} (computed_at);`,
    ].join('\n');
  }

  /**
   * Generate SQL for message passing operation.
   *
   * For every node, aggregates the embeddings of the sources of its incoming edges
   * with the `ruvector.vector_*` function mapped from `aggregation` (falling back to
   * `avg` for an unmapped value).
   */
  static messagePassingSQL(
    tableName: string,
    aggregation: GNNAggregation,
    options: SQLGenerationOptions = {}
  ): string {
    const schema = options.schema ?? 'public';
    const nodeColumn = options.nodeColumn ?? 'embedding';
    const edgeTable = options.edgeTable ?? `${tableName}_edges`;

    const aggFunctionMap: Record<GNNAggregation, string> = {
      mean: 'avg',
      sum: 'sum',
      max: 'max',
      min: 'min',
      attention: 'attention_avg',
      lstm: 'lstm_agg',
      softmax: 'softmax_avg',
      power_mean: 'power_mean',
      std: 'std',
      var: 'var',
    };
    const aggFunction = aggFunctionMap[aggregation] ?? 'avg';

    const nodes = GNNSQLGenerator.qualified(schema, tableName);
    const edges = GNNSQLGenerator.qualified(schema, edgeTable);

    return [
      'SELECT',
      'n.id,',
      `ruvector.vector_${aggFunction}(array_agg(neighbor.${nodeColumn})) AS aggregated_embedding`,
      `FROM ${nodes} n`,
      `LEFT JOIN ${edges} e ON n.id = e.target_id`,
      `LEFT JOIN ${nodes} neighbor ON e.source_id = neighbor.id`,
      'GROUP BY n.id;',
    ].join('\n');
  }

  /**
   * Generate SQL for graph pooling.
   *
   * Reduces all node embeddings of the table to a single `graph_embedding` with the
   * `ruvector` function mapped from `poolingMethod` (falling back to `vector_avg`).
   */
  static graphPoolingSQL(
    tableName: string,
    poolingMethod: 'mean' | 'sum' | 'max' | 'attention',
    options: SQLGenerationOptions = {}
  ): string {
    const schema = options.schema ?? 'public';
    const nodeColumn = options.nodeColumn ?? 'embedding';

    const poolFunction = {
      mean: 'vector_avg',
      sum: 'vector_sum',
      max: 'vector_max',
      attention: 'vector_attention_pool',
    }[poolingMethod] ?? 'vector_avg';

    return [
      `SELECT ruvector.${poolFunction}(`,
      `(SELECT array_agg(${nodeColumn}) FROM ${GNNSQLGenerator.qualified(schema, tableName)})`,
      ') AS graph_embedding;',
    ].join('\n');
  }
}
|
|
2906
|
+
|
|
2907
|
+
// ============================================================================
|
|
2908
|
+
// Embedding Cache Manager
|
|
2909
|
+
// ============================================================================
|
|
2910
|
+
|
|
2911
|
+
/**
 * Manager for caching computed GNN embeddings.
 *
 * Entries are keyed by the stringified node id and carry a timestamp and a version.
 * An entry expires once it is older than `ttlMs`; when the cache is full, inserting a
 * new key evicts the entry with the oldest timestamp.
 */
export class GNNEmbeddingCache {
  private cache: Map<string, { embedding: number[]; timestamp: number; version: number }> =
    new Map();
  private maxSize: number;
  private ttlMs: number;
  // Lookup counters backing getStats().hitRate.
  private hits = 0;
  private misses = 0;

  /**
   * @param maxSize - Maximum number of entries kept (default 10000).
   * @param ttlMs - Entry lifetime in milliseconds (default 3600000, i.e. one hour).
   */
  constructor(maxSize: number = 10000, ttlMs: number = 3600000) {
    this.maxSize = maxSize;
    this.ttlMs = ttlMs;
  }

  /**
   * Get cached embedding.
   *
   * @param nodeId - Node whose embedding is requested.
   * @param version - When given, only an entry stored with this exact version matches.
   * @returns The embedding, or `undefined` when the entry is absent, expired (it is
   *   then removed) or stored under a different version.
   */
  get(nodeId: NodeId, version?: number): number[] | undefined {
    const key = String(nodeId);
    const entry = this.cache.get(key);

    if (!entry) {
      this.misses++;
      return undefined;
    }

    // Check TTL
    if (Date.now() - entry.timestamp > this.ttlMs) {
      this.cache.delete(key);
      this.misses++;
      return undefined;
    }

    // Check version
    if (version !== undefined && entry.version !== version) {
      this.misses++;
      return undefined;
    }

    this.hits++;
    return entry.embedding;
  }

  /**
   * Set cached embedding.
   *
   * Overwriting an existing key never evicts another entry; eviction only happens
   * when a new key is added to a full cache.
   */
  set(nodeId: NodeId, embedding: number[], version: number = 1): void {
    const key = String(nodeId);

    // Evict only when this insert would grow a cache that is already at capacity.
    if (!this.cache.has(key) && this.cache.size >= this.maxSize) {
      this.evictOldest();
    }

    this.cache.set(key, {
      embedding,
      timestamp: Date.now(),
      version,
    });
  }

  /**
   * Batch get embeddings.
   *
   * @returns A map containing only the ids that produced a cache hit.
   */
  getBatch(nodeIds: NodeId[], version?: number): Map<NodeId, number[]> {
    const result = new Map<NodeId, number[]>();

    for (const id of nodeIds) {
      const embedding = this.get(id, version);
      if (embedding) {
        result.set(id, embedding);
      }
    }

    return result;
  }

  /**
   * Batch set embeddings, storing every entry under the same version.
   */
  setBatch(embeddings: Map<NodeId, number[]>, version: number = 1): void {
    for (const [id, embedding] of embeddings.entries()) {
      this.set(id, embedding, version);
    }
  }

  /**
   * Clear cache and reset the hit/miss counters.
   */
  clear(): void {
    this.cache.clear();
    this.hits = 0;
    this.misses = 0;
  }

  /**
   * Get cache statistics.
   *
   * @returns Current size, capacity, and the fraction of `get` lookups that returned
   *   an embedding (0 when no lookup has happened yet).
   */
  getStats(): { size: number; maxSize: number; hitRate: number } {
    const lookups = this.hits + this.misses;
    return {
      size: this.cache.size,
      maxSize: this.maxSize,
      hitRate: lookups === 0 ? 0 : this.hits / lookups,
    };
  }

  /** Remove the entry with the smallest timestamp (first one found on ties). */
  private evictOldest(): void {
    let oldestKey: string | null = null;
    let oldestTimestamp = Infinity;

    for (const [key, entry] of this.cache.entries()) {
      if (entry.timestamp < oldestTimestamp) {
        oldestTimestamp = entry.timestamp;
        oldestKey = key;
      }
    }

    // Compare against null explicitly: an empty-string key is falsy but valid.
    if (oldestKey !== null) {
      this.cache.delete(oldestKey);
    }
  }
}
|
|
3023
|
+
|
|
3024
|
+
// ============================================================================
|
|
3025
|
+
// Factory and Default Instance
|
|
3026
|
+
// ============================================================================
|
|
3027
|
+
|
|
3028
|
+
/**
 * Build a fresh `GNNLayerRegistry` using its no-argument constructor.
 *
 * @returns A new registry instance; nothing is shared between calls.
 */
export function createGNNLayerRegistry(): GNNLayerRegistry {
  const registry = new GNNLayerRegistry();
  return registry;
}
|
|
3034
|
+
|
|
3035
|
+
/**
 * Create a GNN layer through a newly created default registry.
 *
 * @param type - Layer type to instantiate.
 * @param config - Partial layer configuration forwarded to the registry.
 * @returns The layer produced by the registry's `createLayer`.
 */
export function createGNNLayer(type: GNNLayerType, config: Partial<GNNLayerConfig>): IGNNLayer {
  // A new registry is created for every call.
  return createGNNLayerRegistry().createLayer(type, config);
}
|
|
3042
|
+
|
|
3043
|
+
/**
 * Build a fresh `GraphOperations` instance using its no-argument constructor.
 *
 * @returns A new instance; nothing is shared between calls.
 */
export function createGraphOperations(): GraphOperations {
  const operations = new GraphOperations();
  return operations;
}
|
|
3049
|
+
|
|
3050
|
+
// GNN_DEFAULTS and GNN_SQL_FUNCTIONS are already exported via export const above
|