@mhingston5/conduit 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87) hide show
  1. package/.env.example +13 -0
  2. package/.github/workflows/ci.yml +88 -0
  3. package/.github/workflows/pr-checks.yml +90 -0
  4. package/.tool-versions +2 -0
  5. package/README.md +177 -0
  6. package/conduit.yaml.test +3 -0
  7. package/docs/ARCHITECTURE.md +35 -0
  8. package/docs/CODE_MODE.md +33 -0
  9. package/docs/SECURITY.md +52 -0
  10. package/logo.png +0 -0
  11. package/package.json +74 -0
  12. package/src/assets/deno-shim.ts +93 -0
  13. package/src/assets/python-shim.py +21 -0
  14. package/src/core/asset.utils.ts +42 -0
  15. package/src/core/concurrency.service.ts +70 -0
  16. package/src/core/config.service.ts +147 -0
  17. package/src/core/execution.context.ts +37 -0
  18. package/src/core/execution.service.ts +209 -0
  19. package/src/core/interfaces/app.config.ts +17 -0
  20. package/src/core/interfaces/executor.interface.ts +31 -0
  21. package/src/core/interfaces/middleware.interface.ts +12 -0
  22. package/src/core/interfaces/url.validator.interface.ts +3 -0
  23. package/src/core/logger.ts +64 -0
  24. package/src/core/metrics.service.ts +112 -0
  25. package/src/core/middleware/auth.middleware.ts +56 -0
  26. package/src/core/middleware/error.middleware.ts +21 -0
  27. package/src/core/middleware/logging.middleware.ts +25 -0
  28. package/src/core/middleware/middleware.builder.ts +24 -0
  29. package/src/core/middleware/ratelimit.middleware.ts +31 -0
  30. package/src/core/network.policy.service.ts +106 -0
  31. package/src/core/ops.server.ts +74 -0
  32. package/src/core/otel.service.ts +41 -0
  33. package/src/core/policy.service.ts +77 -0
  34. package/src/core/registries/executor.registry.ts +26 -0
  35. package/src/core/request.controller.ts +297 -0
  36. package/src/core/security.service.ts +68 -0
  37. package/src/core/session.manager.ts +44 -0
  38. package/src/core/types.ts +47 -0
  39. package/src/executors/deno.executor.ts +342 -0
  40. package/src/executors/isolate.executor.ts +281 -0
  41. package/src/executors/pyodide.executor.ts +327 -0
  42. package/src/executors/pyodide.worker.ts +195 -0
  43. package/src/gateway/auth.service.ts +104 -0
  44. package/src/gateway/gateway.service.ts +345 -0
  45. package/src/gateway/schema.cache.ts +46 -0
  46. package/src/gateway/upstream.client.ts +244 -0
  47. package/src/index.ts +92 -0
  48. package/src/sdk/index.ts +2 -0
  49. package/src/sdk/sdk-generator.ts +245 -0
  50. package/src/sdk/tool-binding.ts +86 -0
  51. package/src/transport/socket.transport.ts +203 -0
  52. package/tests/__snapshots__/assets.test.ts.snap +97 -0
  53. package/tests/assets.test.ts +50 -0
  54. package/tests/auth.service.test.ts +78 -0
  55. package/tests/code-mode-lite-execution.test.ts +84 -0
  56. package/tests/code-mode-lite-gateway.test.ts +150 -0
  57. package/tests/concurrency.service.test.ts +50 -0
  58. package/tests/concurrency.test.ts +41 -0
  59. package/tests/config.service.test.ts +70 -0
  60. package/tests/contract.test.ts +43 -0
  61. package/tests/deno.executor.test.ts +68 -0
  62. package/tests/deno_hardening.test.ts +45 -0
  63. package/tests/dynamic.tool.test.ts +237 -0
  64. package/tests/e2e_stdio_upstream.test.ts +197 -0
  65. package/tests/fixtures/stdio-server.ts +42 -0
  66. package/tests/gateway.manifest.test.ts +82 -0
  67. package/tests/gateway.service.test.ts +58 -0
  68. package/tests/gateway.strict.unit.test.ts +74 -0
  69. package/tests/gateway.validation.unit.test.ts +89 -0
  70. package/tests/gateway_validation.test.ts +86 -0
  71. package/tests/hardening.test.ts +139 -0
  72. package/tests/hardening_v1.test.ts +72 -0
  73. package/tests/isolate.executor.test.ts +100 -0
  74. package/tests/log-limit.test.ts +55 -0
  75. package/tests/middleware.test.ts +106 -0
  76. package/tests/ops.server.test.ts +65 -0
  77. package/tests/policy.service.test.ts +90 -0
  78. package/tests/pyodide.executor.test.ts +101 -0
  79. package/tests/reference_mcp.ts +40 -0
  80. package/tests/remediation.test.ts +119 -0
  81. package/tests/routing.test.ts +148 -0
  82. package/tests/schema.cache.test.ts +27 -0
  83. package/tests/sdk/sdk-generator.test.ts +205 -0
  84. package/tests/socket.transport.test.ts +182 -0
  85. package/tests/stdio_upstream.test.ts +54 -0
  86. package/tsconfig.json +25 -0
  87. package/tsup.config.ts +22 -0
@@ -0,0 +1,327 @@
1
+ import { Worker } from 'node:worker_threads';
2
+ import fs from 'node:fs';
3
+ import path from 'node:path';
4
+ import { fileURLToPath } from 'node:url';
5
+ import { ExecutionContext } from '../core/execution.context.js';
6
+ import { ResourceLimits as ConduitResourceLimits } from '../core/config.service.js';
7
+ import { ConduitError } from '../core/types.js';
8
+ import { resolveAssetPath } from '../core/asset.utils.js';
9
+
10
+ const __dirname = path.dirname(fileURLToPath(import.meta.url));
11
+
12
+ import { Executor, ExecutorConfig, ExecutionResult } from '../core/interfaces/executor.interface.js';
13
+
14
+ export { ExecutionResult };
15
+
16
/**
 * Legacy description of the IPC connection details passed to an executor.
 *
 * @deprecated use ExecutorConfig
 */
export interface IPCInfo {
    /** Address of the IPC endpoint. */
    ipcAddress: string;
    /** Token presented when talking to the IPC endpoint. */
    ipcToken: string;
    /** Optional generated SDK source code. */
    sdkCode?: string;
}
22
+
23
+
24
/** Bookkeeping record for one worker thread held in the executor pool. */
interface PooledWorker {
    /** The underlying worker thread. */
    worker: Worker;
    /** True while the worker is reserved by a caller. */
    busy: boolean;
    /** Number of executions this worker has completed. */
    runs: number;
    /** `Date.now()` timestamp recorded at creation and after each completed run. */
    lastUsed: number;
}
30
+
31
/** A caller parked in the wait queue until a worker can be handed to it. */
interface WorkerWaiter {
    resolve: (worker: PooledWorker) => void;
    reject: (reason: unknown) => void;
    /** Limits the caller asked for; used if a replacement worker must be spawned for it. */
    limits?: ConduitResourceLimits;
}

/** How long a freshly spawned worker may take to post its `ready` message. */
const WORKER_INIT_TIMEOUT_MS = 10000;

/** How long a pooled worker may take to answer a health-check `ping`. */
const HEALTH_CHECK_TIMEOUT_MS = 2000;

/**
 * Runs Python code inside Pyodide worker threads taken from a bounded pool.
 *
 * Pool invariants maintained by this class:
 * - every entry in `pool` is either idle (`busy === false`) or reserved by exactly one caller;
 * - a worker leaves the pool only through `retire`, which also terminates it;
 * - whenever a worker becomes idle or a pool slot frees up, queued callers are served
 *   through `dispatch`, so nobody is left waiting on a worker that no longer exists.
 */
export class PyodideExecutor implements Executor {
    private shimContent: string = '';
    private pool: PooledWorker[] = [];
    private maxPoolSize: number;
    private maxRunsPerWorker = 1;
    private waitQueue: WorkerWaiter[] = [];

    /**
     * @param maxPoolSize Maximum number of worker threads alive at the same time.
     */
    constructor(maxPoolSize = 3) {
        this.maxPoolSize = maxPoolSize;
    }

    /**
     * Returns the Python shim source, reading it from disk on first use and caching it.
     *
     * @throws Error when the shim asset cannot be read.
     */
    private getShim(): string {
        if (this.shimContent) return this.shimContent;
        try {
            const assetPath = resolveAssetPath('python-shim.py');
            this.shimContent = fs.readFileSync(assetPath, 'utf-8');
            return this.shimContent;
        } catch (err: unknown) {
            const reason = err instanceof Error ? err.message : String(err);
            throw new Error(`Failed to load Python shim: ${reason}`);
        }
    }

    /**
     * Reserves a worker for the caller: an idle one if available, a new one if the pool
     * has room, otherwise the caller is queued until a worker is handed over.
     *
     * The returned worker is always marked busy. Rejects when a newly spawned worker
     * fails to initialise or when the executor is shut down while the caller is queued.
     */
    private async getWorker(logger: any, limits?: ConduitResourceLimits): Promise<PooledWorker> {
        const idle = this.pool.find(w => !w.busy);
        if (idle) {
            idle.busy = true;
            return idle;
        }

        if (this.pool.length < this.maxPoolSize) {
            logger.info('Creating new Pyodide worker for pool');
            return this.spawnWorker(limits);
        }

        return new Promise<PooledWorker>((resolve, reject) => {
            this.waitQueue.push({ resolve, reject, limits });
        });
    }

    /**
     * Creates a worker, adds it to the pool (reserved) and waits for its `ready` signal.
     * On any initialisation failure the worker is removed from the pool and terminated
     * before the error is rethrown, so a failed start never leaks a pool slot.
     */
    private async spawnWorker(limits?: ConduitResourceLimits): Promise<PooledWorker> {
        const worker = this.createWorker(limits);
        const pooled: PooledWorker = { worker, busy: true, runs: 0, lastUsed: Date.now() };
        this.pool.push(pooled);

        // Permanent backstops: a worker that errors or exits must leave the pool, and an
        // 'error' event without any listener would otherwise be thrown as an uncaught exception.
        const evict = () => this.retire(pooled);
        worker.on('error', evict);
        worker.on('exit', evict);

        try {
            await this.waitForReady(worker);
            return pooled;
        } catch (err) {
            this.retire(pooled);
            throw err;
        }
    }

    /**
     * Resolves once the worker posts `{ type: 'ready' }`. Rejects if it errors, exits, or
     * stays silent for WORKER_INIT_TIMEOUT_MS. The timer and the temporary listeners are
     * always removed, so nothing fires after the promise has settled.
     */
    private waitForReady(worker: Worker): Promise<void> {
        return new Promise<void>((resolve, reject) => {
            const cleanup = () => {
                clearTimeout(timer);
                worker.off('message', onMessage);
                worker.off('error', onError);
                worker.off('exit', onExit);
            };
            const onMessage = (msg: any) => {
                if (msg?.type !== 'ready') return;
                cleanup();
                resolve();
            };
            const onError = (err: unknown) => {
                cleanup();
                reject(err);
            };
            const onExit = () => {
                cleanup();
                reject(new Error('Worker exited before becoming ready'));
            };
            const timer = setTimeout(() => {
                cleanup();
                reject(new Error('Worker init timeout'));
            }, WORKER_INIT_TIMEOUT_MS);

            worker.on('message', onMessage);
            worker.on('error', onError);
            worker.on('exit', onExit);
        });
    }

    /** Marks a worker idle again and immediately offers it to queued callers. */
    private release(pooled: PooledWorker): void {
        pooled.busy = false;
        this.dispatch();
    }

    /**
     * Removes a worker from the pool and terminates it, then lets queued callers use the
     * freed slot. Safe to call more than once for the same worker.
     */
    private retire(pooled: PooledWorker): void {
        this.pool = this.pool.filter(w => w !== pooled);
        void pooled.worker.terminate();
        this.dispatch();
    }

    /**
     * Serves queued callers for as long as there is an idle worker or a free pool slot.
     * A caller served from a free slot gets a freshly spawned worker; if that spawn fails
     * the caller is rejected with the spawn error.
     */
    private dispatch(): void {
        while (this.waitQueue.length > 0) {
            const idle = this.pool.find(w => !w.busy);
            if (!idle && this.pool.length >= this.maxPoolSize) return;

            const waiter = this.waitQueue.shift();
            if (!waiter) return;

            if (idle) {
                idle.busy = true;
                waiter.resolve(idle);
            } else {
                void this.spawnWorker(waiter.limits).then(waiter.resolve, waiter.reject);
            }
        }
    }

    /**
     * Starts a worker thread running the Pyodide worker script (compiled `.js` if present,
     * otherwise the `.ts` source next to this module).
     *
     * @param limits When given, `memoryLimitMb` is applied as the worker's V8 old-generation cap.
     */
    private createWorker(limits?: ConduitResourceLimits): Worker {
        let workerPath = path.resolve(__dirname, './pyodide.worker.js');
        if (!fs.existsSync(workerPath)) {
            workerPath = path.resolve(__dirname, './pyodide.worker.ts');
        }

        return new Worker(workerPath, {
            execArgv: process.execArgv.includes('--loader') ? process.execArgv : [],
            resourceLimits: limits
                ? {
                    // `maxOldGenerationSizeMb` is the key Node's Worker recognises; the previous
                    // `maxOldSpaceSizeMb` was ignored, so the memory limit was never applied.
                    // Stack size and young generation are usually fine with defaults.
                    maxOldGenerationSizeMb: limits.memoryLimitMb,
                }
                : undefined,
        });
    }

    /**
     * Pre-fills the pool up to `maxPoolSize` so the first executions do not pay the
     * worker start-up cost. Workers that fail to start are skipped.
     */
    async warmup(limits: ConduitResourceLimits): Promise<void> {
        const needed = this.maxPoolSize - this.pool.length;
        if (needed <= 0) return;

        console.info(`Pre-warming ${needed} Pyodide workers...`);
        const pending: Array<Promise<void>> = [];
        for (let i = 0; i < needed; i++) {
            pending.push(this.createAndPoolWorker(limits));
        }
        await Promise.all(pending);
        console.info(`Pyodide pool pre-warmed with ${this.pool.length} workers.`);
    }

    /**
     * Spawns one worker for warm-up and makes it available (or hands it straight to a
     * queued caller). Start-up failures are reported and otherwise ignored: warm-up is
     * best-effort and the failed worker has already been removed from the pool.
     */
    private async createAndPoolWorker(limits: ConduitResourceLimits): Promise<void> {
        // Small optimization: don't double-fill if racing
        if (this.pool.length >= this.maxPoolSize) return;

        try {
            const pooled = await this.spawnWorker(limits);
            this.release(pooled);
        } catch (err: unknown) {
            const reason = err instanceof Error ? err.message : String(err);
            console.warn(`Failed to pre-warm Pyodide worker: ${reason}`);
        }
    }

    /**
     * Executes Python `code` in a pooled worker.
     *
     * The returned promise resolves for every outcome of the run itself (success, Python
     * error, limit breach, timeout, worker crash); it rejects only when no worker could be
     * obtained or the shim could not be loaded.
     *
     * @param code Python source to run.
     * @param limits Resource limits; `timeoutMs` bounds the run, the rest is forwarded to the worker.
     * @param context Execution context; its `logger` is used for diagnostics.
     * @param config Optional executor config; `sdkCode` is injected into the shim and the whole
     *               object is forwarded to the worker as `ipcInfo`.
     */
    async execute(code: string, limits: ConduitResourceLimits, context: ExecutionContext, config?: ExecutorConfig): Promise<ExecutionResult> {
        const { logger } = context;

        // Prepare the shim before reserving a worker so a shim failure cannot leak a busy worker.
        let shim = this.getShim();
        const sdkCode = config?.sdkCode;
        if (sdkCode) {
            // Replacer function: a plain string replacement would interpret `$&`, `$1`, `$$`
            // sequences that may legitimately appear in the SDK code.
            shim = shim.replace('# __CONDUIT_SDK_INJECTION__', () => sdkCode);
        }

        const pooledWorker = await this.getWorker(logger, limits);
        const worker = pooledWorker.worker;

        return new Promise<ExecutionResult>((resolve) => {
            const detach = () => {
                clearTimeout(timeout);
                worker.off('message', onMessage);
                worker.off('error', onError);
            };

            const timeout = setTimeout(() => {
                detach();
                logger.warn('Python execution timed out, terminating worker');
                this.retire(pooledWorker);
                resolve({
                    stdout: '',
                    stderr: 'Execution timed out',
                    exitCode: null,
                    error: {
                        code: ConduitError.RequestTimeout,
                        message: 'Execution timed out',
                    },
                });
            }, limits.timeoutMs);

            const onMessage = (msg: any) => {
                if (msg.type === 'ready' || msg.type === 'pong') return;

                detach();
                pooledWorker.runs++;
                pooledWorker.lastUsed = Date.now();

                if (msg.success) {
                    // Decide between recycling and reuse BEFORE serving queued callers, so a
                    // worker about to be terminated is never handed to the next caller.
                    if (pooledWorker.runs >= this.maxRunsPerWorker) {
                        logger.info('Recycling Pyodide worker after max runs');
                        this.retire(pooledWorker);
                    } else {
                        this.release(pooledWorker);
                    }
                    resolve({
                        stdout: msg.stdout,
                        stderr: msg.stderr,
                        exitCode: 0,
                    });
                    return;
                }

                logger.warn({ error: msg.error }, 'Python execution failed or limit breached, terminating worker');
                this.retire(pooledWorker);

                logger.debug({ error: msg.error }, 'Python execution error from worker');
                const normalizedError = (msg.error || '').toLowerCase();
                const limitBreached = msg.limitBreached || '';

                const isLogLimit = limitBreached === 'log' || normalizedError.includes('[limit_log]');
                const isOutputLimit = limitBreached === 'output' || normalizedError.includes('[limit_output]');
                const isAmbiguousLimit = !isOutputLimit && !isLogLimit && (normalizedError.includes('i/o error') || normalizedError.includes('errno 29') || normalizedError.includes('limit exceeded'));

                resolve({
                    stdout: msg.stdout,
                    stderr: msg.stderr,
                    exitCode: 1,
                    error: {
                        code: isLogLimit ? ConduitError.LogLimitExceeded : ((isOutputLimit || isAmbiguousLimit) ? ConduitError.OutputLimitExceeded : ConduitError.InternalError),
                        message: isLogLimit ? 'Log entry limit exceeded' : ((isOutputLimit || isAmbiguousLimit) ? 'Output limit exceeded' : msg.error),
                    },
                });
            };

            const onError = (err: any) => {
                detach();
                logger.error({ err }, 'Pyodide worker error');
                this.retire(pooledWorker);

                resolve({
                    stdout: '',
                    stderr: err.message,
                    exitCode: null,
                    error: {
                        code: ConduitError.InternalError,
                        message: err.message,
                    },
                });
            };

            worker.on('message', onMessage);
            worker.on('error', onError);

            worker.postMessage({
                type: 'execute',
                data: { code, limits, ipcInfo: config, shim }
            });
        });
    }

    /**
     * Terminates every pooled worker and rejects callers still waiting for one, so no
     * pending `execute` call is left hanging on a pool that no longer exists.
     */
    async shutdown(): Promise<void> {
        // Empty the queue first: retiring workers frees slots, and queued callers must not
        // trigger new spawns while shutting down.
        const waiters = this.waitQueue;
        this.waitQueue = [];
        for (const waiter of waiters) {
            waiter.reject(new Error('Pyodide executor is shutting down'));
        }

        const workers = this.pool;
        this.pool = [];
        await Promise.all(workers.map(pooled => pooled.worker.terminate()));
    }

    /**
     * Pings one worker (reusing an idle one or spawning one if the pool has room) and
     * reports whether it answered within HEALTH_CHECK_TIMEOUT_MS.
     *
     * A worker that answers is returned to the pool; one that does not is retired, since an
     * idle worker that cannot answer a ping cannot run code either.
     */
    async healthCheck(): Promise<{ status: string; workers: number; detail?: string }> {
        try {
            const pooled = await this.getWorker(console, {
                timeoutMs: 5000,
                memoryLimitMb: 128,
                maxOutputBytes: 1024,
                maxLogEntries: 10
            });

            return await new Promise<{ status: string; workers: number; detail?: string }>((resolve) => {
                const cleanup = () => {
                    clearTimeout(timeout);
                    pooled.worker.off('message', onMessage);
                };

                const onMessage = (msg: any) => {
                    if (msg.type !== 'pong') return;
                    cleanup();
                    this.release(pooled);
                    resolve({ status: 'ok', workers: this.pool.length });
                };

                const timeout = setTimeout(() => {
                    cleanup();
                    this.retire(pooled);
                    resolve({ status: 'error', workers: this.pool.length, detail: 'Health check timeout' });
                }, HEALTH_CHECK_TIMEOUT_MS);

                pooled.worker.on('message', onMessage);
                pooled.worker.postMessage({ type: 'ping' });
            });
        } catch (err: unknown) {
            const detail = err instanceof Error ? err.message : String(err);
            return { status: 'error', workers: this.pool.length, detail };
        }
    }
}
327
+
@@ -0,0 +1,195 @@
1
+ import { parentPort, workerData } from 'node:worker_threads';
2
+ import { loadPyodide, type PyodideInterface } from 'pyodide';
3
+ import net from 'node:net';
4
+
5
// --- Per-worker state shared by init() and handleTask() ---

/** Lazily created Pyodide runtime; stays null until init() has loaded it. */
let pyodide: PyodideInterface | null = null;

/** Text captured from Python stdout for the task currently being handled. */
let currentStdout = '';

/** Text captured from Python stderr for the task currently being handled. */
let currentStderr = '';

/** Running output size counter for the current task (stdout and stderr combined). */
let totalOutputBytes = 0;

/** Running count of captured output entries for the current task. */
let totalLogEntries = 0;

/** Limits object of the current task; null before the first task arrives. */
let currentLimits: any = null;
11
+
12
/**
 * True once the current task has exceeded its output-size or log-entry limit.
 * Missing (or zero) limits fall back to 1 MiB of output and 10000 entries.
 */
function outputLimitBreached(): boolean {
    if (!currentLimits) return false;
    return totalOutputBytes > (currentLimits.maxOutputBytes || 1024 * 1024)
        || totalLogEntries > (currentLimits.maxLogEntries || 10000);
}

/**
 * Accounts for one chunk of captured output and returns the text to append, or null
 * when the chunk must be dropped because a limit has already been breached.
 */
function accountOutput(text: string): string | null {
    if (outputLimitBreached()) {
        return null; // Stop processing logs once limit breached
    }
    // The limit is expressed in bytes, so count encoded bytes: `text.length` counts UTF-16
    // code units and under-reports non-ASCII output.
    totalOutputBytes += Buffer.byteLength(text, 'utf8') + 1;
    totalLogEntries++;
    return text + '\n';
}

/**
 * Loads the Pyodide runtime on first call and returns the cached instance afterwards.
 *
 * The runtime's stdout/stderr callbacks append to `currentStdout` / `currentStderr` and
 * update the shared counters until a limit is breached, after which output is discarded.
 */
async function init() {
    if (pyodide) return pyodide;

    pyodide = await loadPyodide({
        stdout: (text) => {
            const chunk = accountOutput(text);
            if (chunk !== null) currentStdout += chunk;
        },
        stderr: (text) => {
            const chunk = accountOutput(text);
            if (chunk !== null) currentStderr += chunk;
        },
    });

    return pyodide;
}
36
+
37
/**
 * Sends one JSON-RPC request to the Conduit IPC endpoint and resolves with its result.
 *
 * The address is treated as `host:port` when it contains a colon (wildcard and IPv6
 * loopback hosts are mapped to 127.0.0.1), otherwise as a socket path. The request and
 * the response are newline-delimited JSON.
 *
 * The returned promise always settles: it rejects on socket errors, on an error response,
 * and when the connection closes before a matching response has arrived.
 */
function sendIpcRequest(ipcInfo: any, method: string, params: any): Promise<unknown> {
    if (!ipcInfo?.ipcAddress) {
        return Promise.reject(new Error('Conduit IPC address not configured'));
    }

    return new Promise((resolve, reject) => {
        const address: string = ipcInfo.ipcAddress;
        let client: net.Socket;

        if (address.includes(':')) {
            const lastColon = address.lastIndexOf(':');
            const port = address.substring(lastColon + 1);

            let targetHost = address.substring(0, lastColon).replace(/[\[\]]/g, '');
            if (targetHost === '0.0.0.0' || targetHost === '::' || targetHost === '::1' || targetHost === '') {
                targetHost = '127.0.0.1';
            }

            client = net.createConnection({
                host: targetHost,
                port: parseInt(port, 10)
            });
        } else {
            client = net.createConnection({ path: address });
        }

        // Decode as UTF-8 at the stream level so a multi-byte character split across two
        // chunks is not corrupted.
        client.setEncoding('utf8');

        const id = Math.random().toString(36).substring(7);
        const request = {
            jsonrpc: '2.0',
            id,
            method,
            params: params || {},
            auth: { bearerToken: ipcInfo.ipcToken }
        };

        let settled = false;

        // Parses one line; settles the promise and returns true when it is our response.
        const tryConsume = (line: string): boolean => {
            if (settled || !line.trim()) return false;

            let response: any;
            try {
                response = JSON.parse(line);
            } catch {
                // Not valid JSON: Conduit sends one JSON-RPC message per line, so anything
                // else is noise and is skipped.
                return false;
            }
            if (response?.id !== id) return false;

            settled = true;
            if (response.error) {
                reject(new Error(response.error.message));
            } else {
                resolve(response.result);
            }
            return true;
        };

        client.on('error', (err) => {
            settled = true;
            reject(err);
            client.destroy();
        });

        client.write(JSON.stringify(request) + '\n');

        let buffer = '';
        client.on('data', (chunk) => {
            buffer += chunk.toString();
            const lines = buffer.split('\n');
            buffer = lines.pop() || ''; // Keep the last partial line

            for (const line of lines) {
                if (tryConsume(line)) {
                    client.end();
                    return;
                }
            }
        });

        client.on('end', () => {
            // The peer may close without a trailing newline; give the remainder a chance.
            tryConsume(buffer);
        });

        client.on('close', () => {
            if (settled) return;
            settled = true;
            reject(new Error('Conduit IPC connection closed before a response was received'));
        });
    });
}

/**
 * Runs one execution task and posts exactly one result message to the parent thread.
 *
 * Resets the per-task output state, exposes the MCP bridge functions to Python, runs the
 * optional shim followed by the user code, and enforces the output and log limits.
 * On failure the posted message carries `success: false`, the error text and, when a
 * limit was the cause, `limitBreached` set to `'output'` or `'log'`.
 *
 * @param data Task payload with `code`, `limits`, `ipcInfo` and `shim`.
 */
async function handleTask(data: any) {
    const { code, limits, ipcInfo, shim } = data;
    currentStdout = '';
    currentStderr = '';
    totalOutputBytes = 0;
    totalLogEntries = 0;
    currentLimits = limits;

    try {
        const p = await init();

        (p as any).globals.set('discover_mcp_tools_js', (options: any) => {
            return sendIpcRequest(ipcInfo, 'mcp.discoverTools', options);
        });

        (p as any).globals.set('call_mcp_tool_js', (name: string, args: any) => {
            return sendIpcRequest(ipcInfo, 'mcp.callTool', { name, arguments: args });
        });

        if (shim) {
            await p.runPythonAsync(shim);
        }

        const result = await p.runPythonAsync(code);

        if (totalOutputBytes > (limits.maxOutputBytes || 1024 * 1024)) {
            throw new Error('[LIMIT_OUTPUT]');
        }
        if (totalLogEntries > (limits.maxLogEntries || 10000)) {
            throw new Error('[LIMIT_LOG]');
        }

        parentPort?.postMessage({
            stdout: currentStdout,
            stderr: currentStderr,
            result: String(result),
            success: true,
        });
    } catch (err: unknown) {
        // Anything can be thrown; never assume an Error here, otherwise a failure inside
        // this handler would prevent the result message from being posted at all.
        const message = err instanceof Error ? err.message : String(err);

        let isOutput = message.includes('[LIMIT_OUTPUT]');
        let isLog = message.includes('[LIMIT_LOG]');

        // Fallback: check counters if message doesn't match (e.g. wrapped in OSError)
        if (!isOutput && !isLog && currentLimits) {
            if (totalOutputBytes > (currentLimits.maxOutputBytes || 1024 * 1024)) {
                isOutput = true;
            }
            // Check specific log limit breach
            if (totalLogEntries > (currentLimits.maxLogEntries || 10000)) {
                isLog = true;
            }
        }

        parentPort?.postMessage({
            stdout: currentStdout,
            stderr: currentStderr,
            error: message,
            limitBreached: isOutput ? 'output' : (isLog ? 'log' : undefined),
            success: false,
        });
    }
}
184
+
185
// Dispatch messages from the parent thread: 'execute' runs a task, 'ping' is answered
// with 'pong'; any other message type is ignored.
parentPort?.on('message', async (msg) => {
    switch (msg.type) {
        case 'execute':
            await handleTask(msg.data);
            break;
        case 'ping':
            parentPort?.postMessage({ type: 'pong' });
            break;
        default:
            break;
    }
});

// Tell the parent thread that the message handler is installed.
parentPort?.postMessage({ type: 'ready' });
195
+
@@ -0,0 +1,104 @@
1
+ import { Logger } from 'pino';
2
+ import axios from 'axios';
3
+
4
/** Authentication schemes supported for upstream servers. */
export type AuthType =
    | 'apiKey'
    | 'oauth2'
    | 'bearer';
5
+
6
/**
 * Credentials used to authenticate against an upstream server.
 * `type` selects which of the optional fields is read.
 */
export interface UpstreamCredentials {
    /** Scheme to apply. */
    type: AuthType;
    /** Key sent in the `X-API-Key` header when `type` is `'apiKey'`. */
    apiKey?: string;
    /** Token sent as `Authorization: Bearer …` when `type` is `'bearer'`. */
    bearerToken?: string;
    /** Refresh-token grant parameters used when `type` is `'oauth2'`. */
    oauth2?: {
        clientId: string;
        clientSecret: string;
        tokenUrl: string;
        refreshToken: string;
    };
}
17
+
18
/** An OAuth2 access token together with its absolute expiry time. */
interface CachedToken {
    /** The raw access token, without the `Bearer ` prefix. */
    accessToken: string;
    /** Expiry as a `Date.now()`-style millisecond timestamp. */
    expiresAt: number;
}
22
+
23
/**
 * Builds authentication headers for upstream requests.
 *
 * OAuth2 access tokens are obtained with the refresh-token grant, cached per
 * `clientId:tokenUrl`, and refreshed at most once at a time per cache key.
 */
export class AuthService {
    private logger: Logger;
    // Cache tokens separately from credentials to avoid mutation
    private tokenCache = new Map<string, CachedToken>();
    // Prevent concurrent refresh requests for the same client
    private refreshLocks = new Map<string, Promise<string>>();

    constructor(logger: Logger) {
        this.logger = logger;
    }

    /**
     * Returns the HTTP headers that authenticate a request with the given credentials.
     *
     * @throws Error for an unsupported `type`, for missing OAuth2 settings, or when the
     *         OAuth2 token refresh fails.
     */
    async getAuthHeaders(creds: UpstreamCredentials): Promise<Record<string, string>> {
        switch (creds.type) {
            case 'apiKey':
                return { 'X-API-Key': creds.apiKey || '' };
            case 'bearer':
                return { 'Authorization': `Bearer ${creds.bearerToken}` };
            case 'oauth2':
                return { 'Authorization': await this.getOAuth2Token(creds) };
            default:
                throw new Error(`Unsupported auth type: ${creds.type}`);
        }
    }

    /**
     * Returns a `Bearer …` header value, served from the cache while the token has more
     * than 30 seconds of lifetime left; otherwise joins or starts a refresh.
     */
    private async getOAuth2Token(creds: UpstreamCredentials): Promise<string> {
        if (!creds.oauth2) throw new Error('OAuth2 credentials missing');

        const { oauth2 } = creds;
        const cacheKey = `${oauth2.clientId}:${oauth2.tokenUrl}`;

        // Check cache first (with 30s buffer)
        const cached = this.tokenCache.get(cacheKey);
        if (cached && cached.expiresAt > Date.now() + 30000) {
            return `Bearer ${cached.accessToken}`;
        }

        // Check if refresh is already in progress
        const existingRefresh = this.refreshLocks.get(cacheKey);
        if (existingRefresh) {
            return existingRefresh;
        }

        // Start refresh with lock
        const refreshPromise = this.doRefresh(creds, cacheKey);
        this.refreshLocks.set(cacheKey, refreshPromise);

        try {
            return await refreshPromise;
        } finally {
            this.refreshLocks.delete(cacheKey);
        }
    }

    /**
     * Performs the refresh-token grant against `tokenUrl` and caches the new token.
     *
     * @throws Error (`OAuth2 refresh failed: …`) when the request fails or the response
     *         does not contain a usable `access_token`.
     */
    private async doRefresh(creds: UpstreamCredentials, cacheKey: string): Promise<string> {
        const { oauth2 } = creds;
        if (!oauth2) throw new Error('OAuth2 credentials missing');

        this.logger.info('Refreshing OAuth2 token');

        try {
            // RFC 6749 requires token requests to be `application/x-www-form-urlencoded`;
            // axios sends that content type when given URLSearchParams (a plain object
            // would be sent as JSON, which many token endpoints reject).
            const form = new URLSearchParams({
                grant_type: 'refresh_token',
                refresh_token: oauth2.refreshToken,
                client_id: oauth2.clientId,
                client_secret: oauth2.clientSecret,
            });
            const response = await axios.post(oauth2.tokenUrl, form);

            const { access_token, expires_in } = response.data ?? {};
            if (typeof access_token !== 'string' || access_token === '') {
                // Without this check the caller would receive the header "Bearer undefined".
                throw new Error('token endpoint response did not include an access_token');
            }

            // Cache the token (don't mutate the input credentials). `expires_in` is optional
            // in the token response; without a positive lifetime the token is not cached
            // and the next call refreshes again.
            const lifetimeSeconds = Number(expires_in);
            if (Number.isFinite(lifetimeSeconds) && lifetimeSeconds > 0) {
                this.tokenCache.set(cacheKey, {
                    accessToken: access_token,
                    expiresAt: Date.now() + (lifetimeSeconds * 1000),
                });
            }

            return `Bearer ${access_token}`;
        } catch (err: unknown) {
            const reason = err instanceof Error ? err.message : String(err);
            this.logger.error({ err: reason }, 'Failed to refresh OAuth2 token');
            throw new Error(`OAuth2 refresh failed: ${reason}`);
        }
    }
}