@directus/api 11.0.1 → 11.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. package/dist/emitter.d.ts +3 -2
  2. package/dist/emitter.js +12 -4
  3. package/dist/env.js +15 -4
  4. package/dist/messenger.d.ts +3 -3
  5. package/dist/messenger.js +14 -5
  6. package/dist/middleware/authenticate.js +2 -38
  7. package/dist/server.js +10 -0
  8. package/dist/services/graphql/index.d.ts +0 -6
  9. package/dist/services/graphql/index.js +98 -57
  10. package/dist/services/graphql/subscription.d.ts +16 -0
  11. package/dist/services/graphql/subscription.js +77 -0
  12. package/dist/services/server.js +24 -0
  13. package/dist/services/websocket.d.ts +14 -0
  14. package/dist/services/websocket.js +26 -0
  15. package/dist/utils/apply-diff.js +11 -2
  16. package/dist/utils/apply-query.js +5 -6
  17. package/dist/utils/get-accountability-for-token.d.ts +2 -0
  18. package/dist/utils/get-accountability-for-token.js +50 -0
  19. package/dist/utils/get-service.d.ts +7 -0
  20. package/dist/utils/get-service.js +49 -0
  21. package/dist/utils/redact.d.ts +4 -0
  22. package/dist/utils/redact.js +15 -1
  23. package/dist/utils/to-boolean.d.ts +4 -0
  24. package/dist/utils/to-boolean.js +6 -0
  25. package/dist/websocket/authenticate.d.ts +6 -0
  26. package/dist/websocket/authenticate.js +62 -0
  27. package/dist/websocket/controllers/base.d.ts +42 -0
  28. package/dist/websocket/controllers/base.js +276 -0
  29. package/dist/websocket/controllers/graphql.d.ts +12 -0
  30. package/dist/websocket/controllers/graphql.js +102 -0
  31. package/dist/websocket/controllers/hooks.d.ts +1 -0
  32. package/dist/websocket/controllers/hooks.js +122 -0
  33. package/dist/websocket/controllers/index.d.ts +10 -0
  34. package/dist/websocket/controllers/index.js +35 -0
  35. package/dist/websocket/controllers/rest.d.ts +9 -0
  36. package/dist/websocket/controllers/rest.js +47 -0
  37. package/dist/websocket/exceptions.d.ts +16 -0
  38. package/dist/websocket/exceptions.js +55 -0
  39. package/dist/websocket/handlers/heartbeat.d.ts +11 -0
  40. package/dist/websocket/handlers/heartbeat.js +72 -0
  41. package/dist/websocket/handlers/index.d.ts +4 -0
  42. package/dist/websocket/handlers/index.js +11 -0
  43. package/dist/websocket/handlers/items.d.ts +6 -0
  44. package/dist/websocket/handlers/items.js +103 -0
  45. package/dist/websocket/handlers/subscribe.d.ts +43 -0
  46. package/dist/websocket/handlers/subscribe.js +278 -0
  47. package/dist/websocket/messages.d.ts +311 -0
  48. package/dist/websocket/messages.js +96 -0
  49. package/dist/websocket/types.d.ts +34 -0
  50. package/dist/websocket/types.js +1 -0
  51. package/dist/websocket/utils/get-expires-at-for-token.d.ts +1 -0
  52. package/dist/websocket/utils/get-expires-at-for-token.js +8 -0
  53. package/dist/websocket/utils/message.d.ts +4 -0
  54. package/dist/websocket/utils/message.js +27 -0
  55. package/dist/websocket/utils/wait-for-message.d.ts +4 -0
  56. package/dist/websocket/utils/wait-for-message.js +45 -0
  57. package/package.json +19 -14
@@ -0,0 +1,26 @@
1
+ import { getWebSocketController } from '../websocket/controllers/index.js';
2
+ import emitter from '../emitter.js';
3
/**
 * Facade over the active WebSocket controller: exposes its connected clients,
 * broadcasting to them, and (un)subscribing to `websocket.*` action events.
 */
export class WebSocketService {
    controller;
    constructor() {
        this.controller = getWebSocketController();
    }
    /**
     * Register `callback` for the `websocket.<event>` action (e.g. `auth.success`).
     */
    on(event, callback) {
        emitter.onAction('websocket.' + event, callback);
    }
    /**
     * Remove a callback previously registered for `websocket.<event>`.
     */
    off(event, callback) {
        emitter.offAction('websocket.' + event, callback);
    }
    /**
     * Send `message` to every connected client, optionally restricted by accountability.
     *
     * @param message A pre-serialized string, or any value to be JSON-serialized.
     * @param filter Optional `{ user?, role? }`; a truthy `user`/`role` must equal the
     *               client's `accountability.user`/`accountability.role` for it to receive the message.
     */
    broadcast(message, filter) {
        // The payload is the same for every recipient, so serialize it at most once.
        // It is computed lazily so that nothing is serialized when no client matches.
        let payload;
        this.controller.clients.forEach((client) => {
            if (filter && filter.user && filter.user !== client.accountability?.user)
                return;
            if (filter && filter.role && filter.role !== client.accountability?.role)
                return;
            if (payload === undefined) {
                payload = typeof message === 'string' ? message : JSON.stringify(message);
            }
            client.send(payload);
        });
    }
    /**
     * Return the controller's live set of connected clients (not a copy).
     */
    clients() {
        return this.controller.clients;
    }
}
@@ -121,7 +121,12 @@ export async function applyDiff(currentSnapshot, snapshotDiff, options) {
121
121
  // then continue with nested collections recursively
122
122
  await createCollections(snapshotDiff.collections.filter(filterCollectionsForCreation));
123
123
  // delete top level collections (no group) first, then continue with nested collections recursively
124
- await deleteCollections(snapshotDiff.collections.filter(({ diff }) => diff[0]?.kind === DiffKind.DELETE && diff[0].lhs.meta?.group === null));
124
+ await deleteCollections(snapshotDiff.collections.filter(({ diff }) => {
125
+ if (diff.length === 0 || diff[0] === undefined)
126
+ return false;
127
+ const collectionDiff = diff[0];
128
+ return collectionDiff.kind === DiffKind.DELETE && collectionDiff.lhs?.meta?.group === null;
129
+ }));
125
130
  for (const { collection, diff } of snapshotDiff.collections) {
126
131
  if (diff?.[0]?.kind === DiffKind.EDIT || diff?.[0]?.kind === DiffKind.ARRAY) {
127
132
  const currentCollection = currentSnapshot.collections.find((field) => {
@@ -198,7 +203,11 @@ export async function applyDiff(currentSnapshot, snapshotDiff, options) {
198
203
  }
199
204
  if (diff?.[0]?.kind === DiffKind.NEW) {
200
205
  try {
201
- await relationsService.createOne(diff[0].rhs, mutationOptions);
206
+ await relationsService.createOne({
207
+ ...diff[0].rhs,
208
+ collection,
209
+ field,
210
+ }, mutationOptions);
202
211
  }
203
212
  catch (err) {
204
213
  logger.error(`Failed to create relation "${collection}.${field}"`);
@@ -1,5 +1,6 @@
1
1
  import { getFilterOperatorsForType, getOutputTypeForFunction } from '@directus/utils';
2
2
  import { clone, isPlainObject } from 'lodash-es';
3
+ import { customAlphabet } from 'nanoid/non-secure';
3
4
  import validate from 'uuid-validate';
4
5
  import { getHelpers } from '../database/helpers/index.js';
5
6
  import { InvalidQueryException } from '../exceptions/invalid-query.js';
@@ -7,8 +8,6 @@ import { getColumnPath } from './get-column-path.js';
7
8
  import { getColumn } from './get-column.js';
8
9
  import { getRelationInfo } from './get-relation-info.js';
9
10
  import { stripFunction } from './strip-function.js';
10
- // @ts-ignore
11
- import { customAlphabet } from 'nanoid/non-secure';
12
11
  export const generateAlias = customAlphabet('abcdefghijklmnopqrstuvwxyz', 5);
13
12
  /**
14
13
  * Apply the Query to a given Knex query builder instance
@@ -337,18 +336,18 @@ export function applyFilter(knex, schema, rootQuery, rootFilter, collection, ali
337
336
  // Knex supports "raw" in the columnName parameter, but isn't typed as such. Too bad..
338
337
  // See https://github.com/knex/knex/issues/4518 @TODO remove as any once knex is updated
339
338
  // These operators don't rely on a value, and can thus be used without one (eg `?filter[field][_null]`)
340
- if (operator === '_null' || (operator === '_nnull' && compareValue === false)) {
339
+ if ((operator === '_null' && compareValue !== false) || (operator === '_nnull' && compareValue === false)) {
341
340
  dbQuery[logical].whereNull(selectionRaw);
342
341
  }
343
- if (operator === '_nnull' || (operator === '_null' && compareValue === false)) {
342
+ if ((operator === '_nnull' && compareValue !== false) || (operator === '_null' && compareValue === false)) {
344
343
  dbQuery[logical].whereNotNull(selectionRaw);
345
344
  }
346
- if (operator === '_empty' || (operator === '_nempty' && compareValue === false)) {
345
+ if ((operator === '_empty' && compareValue !== false) || (operator === '_nempty' && compareValue === false)) {
347
346
  dbQuery[logical].andWhere((query) => {
348
347
  query.whereNull(key).orWhere(key, '=', '');
349
348
  });
350
349
  }
351
- if (operator === '_nempty' || (operator === '_empty' && compareValue === false)) {
350
+ if ((operator === '_nempty' && compareValue !== false) || (operator === '_empty' && compareValue === false)) {
352
351
  dbQuery[logical].andWhere((query) => {
353
352
  query.whereNotNull(key).andWhere(key, '!=', '');
354
353
  });
@@ -0,0 +1,2 @@
1
+ import type { Accountability } from '@directus/types';
2
/**
 * Resolve the accountability for an access token.
 * A Directus-issued JWT is verified with the configured secret and its claims are copied in.
 * Any other non-empty token is looked up as a static token of an active user.
 * When `accountability` is given, that object is mutated and returned. Otherwise a fresh
 * anonymous accountability (no user, no role, no admin/app access) is created.
 * Rejects with InvalidCredentialsException when a static token matches no active user.
 */
+ export declare function getAccountabilityForToken(token?: string | null, accountability?: Accountability): Promise<Accountability>;
@@ -0,0 +1,50 @@
1
+ import getDatabase from '../database/index.js';
2
+ import isDirectusJWT from './is-directus-jwt.js';
3
+ import { InvalidCredentialsException } from '../index.js';
4
+ import env from '../env.js';
5
+ import { verifyAccessJWT } from './jwt.js';
6
/**
 * Build the accountability for an access token.
 *
 * - Directus JWTs are verified with `SECRET`; role, admin/app access, share info and
 *   user id are taken from the token claims.
 * - Any other token is treated as a static user token. It is looked up in
 *   `directus_users` (status 'active') joined with the user's role.
 * - An empty/missing token leaves the accountability unchanged.
 *
 * @param token Access token to resolve (optional).
 * @param accountability Object to fill in and return; it is mutated in place. When omitted,
 *        an anonymous accountability is created.
 * @returns The (possibly mutated) accountability.
 * @throws InvalidCredentialsException when a static token matches no active user.
 */
export async function getAccountabilityForToken(token, accountability) {
    const result = accountability || {
        user: null,
        role: null,
        admin: false,
        app: false,
    };
    // Loose `== 1` on purpose: flags may arrive as `true`, `1` or `'1'`.
    const isFlagSet = (flag) => flag === true || flag == 1;
    if (!token)
        return result;
    if (isDirectusJWT(token)) {
        const claims = verifyAccessJWT(token, env['SECRET']);
        result.role = claims.role;
        result.admin = isFlagSet(claims.admin_access);
        result.app = isFlagSet(claims.app_access);
        if (claims.share)
            result.share = claims.share;
        if (claims.share_scope)
            result.share_scope = claims.share_scope;
        if (claims.id)
            result.user = claims.id;
        return result;
    }
    // Not a Directus JWT: look the token up as a static token on an active user.
    const row = await getDatabase()
        .select('directus_users.id', 'directus_users.role', 'directus_roles.admin_access', 'directus_roles.app_access')
        .from('directus_users')
        .leftJoin('directus_roles', 'directus_users.role', 'directus_roles.id')
        .where({
        'directus_users.token': token,
        status: 'active',
    })
        .first();
    if (!row)
        throw new InvalidCredentialsException();
    result.user = row.id;
    result.role = row.role;
    result.admin = isFlagSet(row.admin_access);
    result.app = isFlagSet(row.app_access);
    return result;
}
@@ -0,0 +1,7 @@
1
+ import { ItemsService } from '../index.js';
2
+ import type { AbstractServiceOptions } from '../types/services.js';
3
+ /**
4
+ * Select the correct service for the given collection. This allows the individual services to run
5
+ * their custom checks (f.e. it allows UsersService to prevent updating TFA secret from outside)
6
+ */
7
// Collections without a dedicated service get a generic ItemsService. This includes
// directus_collections, directus_fields and directus_relations, which have no mapping.
+ export declare function getService(collection: string, opts: AbstractServiceOptions): ItemsService;
@@ -0,0 +1,49 @@
1
+ import { ActivityService, DashboardsService, FilesService, FlowsService, FoldersService, ItemsService, NotificationsService, OperationsService, PanelsService, PermissionsService, PresetsService, RevisionsService, RolesService, SettingsService, SharesService, UsersService, WebhooksService, } from '../index.js';
2
/**
 * Instantiate the service responsible for `collection`.
 *
 * System collections with a dedicated service class get that class, so any
 * collection-specific checks it implements run. For example, UsersService prevents
 * the TFA secret from being updated from outside. Every other collection gets a
 * generic ItemsService for that collection.
 *
 * Not mapped here, and therefore served by ItemsService: directus_collections,
 * directus_fields, directus_relations.
 */
export function getService(collection, opts) {
    // Built inside the function so the imported service classes are only read at call time.
    const dedicatedServices = new Map([
        ['directus_activity', ActivityService],
        ['directus_dashboards', DashboardsService],
        ['directus_files', FilesService],
        ['directus_flows', FlowsService],
        ['directus_folders', FoldersService],
        ['directus_notifications', NotificationsService],
        ['directus_operations', OperationsService],
        ['directus_panels', PanelsService],
        ['directus_permissions', PermissionsService],
        ['directus_presets', PresetsService],
        ['directus_revisions', RevisionsService],
        ['directus_roles', RolesService],
        ['directus_settings', SettingsService],
        ['directus_shares', SharesService],
        ['directus_users', UsersService],
        ['directus_webhooks', WebhooksService],
    ]);
    const ServiceClass = dedicatedServices.get(collection);
    if (ServiceClass === undefined) {
        return new ItemsService(collection, opts);
    }
    return new ServiceClass(opts);
}
@@ -8,4 +8,8 @@ type Paths = string[][];
8
8
  * @returns Redacted object.
9
9
  */
10
10
  export declare function redact(input: UnknownObject, paths: Paths, replacement: string): UnknownObject;
11
+ /**
12
+ * Extract values from Error objects for use with JSON.stringify()
13
+ */
14
// Intended as the `replacer` argument of JSON.stringify. Error instances become plain
// { name, message, stack, cause } objects; every other value is returned unchanged.
+ export declare function errorReplacer(_key: string, value: unknown): unknown;
11
15
  export {};
@@ -8,7 +8,7 @@ import { isObject } from '@directus/utils';
8
8
  */
9
9
  export function redact(input, paths, replacement) {
10
10
  const wildcardChars = ['*', '**'];
11
- const clone = structuredClone(input);
11
+ const clone = JSON.parse(JSON.stringify(input, errorReplacer));
12
12
  const visited = new WeakSet();
13
13
  traverse(clone, paths);
14
14
  return clone;
@@ -73,3 +73,17 @@ export function redact(input, paths, replacement) {
73
73
  }
74
74
  }
75
75
  }
76
/**
 * `JSON.stringify` replacer that converts Error instances into plain objects.
 *
 * An Error's `message` and `stack` are not enumerable own properties, and `name` is
 * inherited. As a result, `JSON.stringify(new Error('x'))` alone produces `{}`.
 * This replacer copies `name`, `message`, `stack` and `cause` into a plain object.
 * JSON.stringify then runs the replacer on that object's values too, so an Error
 * `cause` is converted as well. Any non-Error value is passed through untouched.
 */
export function errorReplacer(_key, value) {
    if (!(value instanceof Error))
        return value;
    const { name, message, stack } = value;
    return { name, message, stack, cause: value.cause };
}
@@ -0,0 +1,4 @@
1
+ /**
2
+ * Convert environment variable to Boolean
3
+ */
4
// Only `true`, `1`, `'true'` and `'1'` (exact, case-sensitive) yield true; anything else is false.
+ export declare function toBoolean(value: any): boolean;
@@ -0,0 +1,6 @@
1
/** Values treated as "on" when reading a boolean environment variable. */
const TRUTHY_ENV_VALUES = [true, 1, 'true', '1'];
/**
 * Interpret an environment-variable value as a boolean.
 *
 * Matching is exact and case-sensitive: only `true`, `1`, `'true'` and `'1'` give
 * true. Everything else gives false, including `'TRUE'`, `'yes'` and `undefined`.
 */
export function toBoolean(value) {
    return TRUTHY_ENV_VALUES.includes(value);
}
@@ -0,0 +1,6 @@
1
+ import type { Accountability } from '@directus/types';
2
+ import type { BasicAuthMessage } from './messages.js';
3
+ import type { AuthenticationState } from './types.js';
4
// Authenticate from an auth message (email/password, refresh_token and/or access_token).
// Rejects with a WebSocketException whose code is 'TOKEN_EXPIRED' or 'AUTH_FAILED'.
+ export declare function authenticateConnection(message: BasicAuthMessage & Record<string, any>): Promise<AuthenticationState>;
5
// Rebuild accountability from the given accountability's role, keeping its user id (or null).
+ export declare function refreshAccountability(accountability: Accountability | null | undefined): Promise<Accountability>;
6
// Serialize a `{ type: 'auth', status: 'ok' }` reply; `uid` / `refresh_token` are added only when defined.
+ export declare function authenticationSuccess(uid?: string | number, refresh_token?: string): string;
@@ -0,0 +1,62 @@
1
+ import { DEFAULT_AUTH_PROVIDER } from '../constants.js';
2
+ import getDatabase from '../database/index.js';
3
+ import { InvalidCredentialsException } from '../exceptions/index.js';
4
+ import { AuthenticationService } from '../services/index.js';
5
+ import { getAccountabilityForRole } from '../utils/get-accountability-for-role.js';
6
+ import { getAccountabilityForToken } from '../utils/get-accountability-for-token.js';
7
+ import { getSchema } from '../utils/get-schema.js';
8
+ import { WebSocketException } from './exceptions.js';
9
+ import { getExpiresAtForToken } from './utils/get-expires-at-for-token.js';
10
/**
 * Authenticate a WebSocket connection from an `auth` message.
 *
 * The supported credentials are evaluated in this order:
 *   1. `email` + `password` — log in with the default auth provider;
 *   2. `refresh_token`      — exchange it for a new access/refresh token pair;
 *   3. `access_token`       — use the given access token directly.
 * A later step overwrites the access token of an earlier one. The refresh step also
 * overwrites the refresh token.
 *
 * @returns `{ accountability, expires_at, refresh_token }` for the final access token.
 * @throws WebSocketException with code 'TOKEN_EXPIRED' when an InvalidCredentialsException
 *         with message 'Token expired.' occurs, otherwise 'AUTH_FAILED' (also when no
 *         access token could be obtained). The message's `uid` is echoed in the error.
 */
export async function authenticateConnection(message) {
    let accessToken;
    let refreshToken;
    try {
        if ('email' in message && 'password' in message) {
            const auth = new AuthenticationService({ schema: await getSchema() });
            ({ accessToken, refreshToken } = await auth.login(DEFAULT_AUTH_PROVIDER, message));
        }
        if ('refresh_token' in message) {
            const auth = new AuthenticationService({ schema: await getSchema() });
            ({ accessToken, refreshToken } = await auth.refresh(message.refresh_token));
        }
        if ('access_token' in message) {
            accessToken = message.access_token;
        }
        if (!accessToken) {
            throw new Error();
        }
        return {
            accountability: await getAccountabilityForToken(accessToken),
            expires_at: getExpiresAtForToken(accessToken),
            refresh_token: refreshToken,
        };
    }
    catch (error) {
        const tokenExpired = error instanceof InvalidCredentialsException && error.message === 'Token expired.';
        const failure = tokenExpired
            ? new WebSocketException('auth', 'TOKEN_EXPIRED', 'Token expired.', message['uid'])
            : new WebSocketException('auth', 'AUTH_FAILED', 'Authentication failed.', message['uid']);
        throw failure;
    }
}
41
/**
 * Rebuild an accountability from its role via getAccountabilityForRole, using the
 * current schema and database. The original `user` is carried over, or null when absent.
 *
 * @param accountability Current accountability; may be null/undefined (treated as no role).
 */
export async function refreshAccountability(accountability) {
    const role = accountability?.role || null;
    const roleOptions = {
        accountability: accountability || null,
        schema: await getSchema(),
        database: getDatabase(),
    };
    const refreshed = await getAccountabilityForRole(role, roleOptions);
    refreshed.user = accountability?.user || null;
    return refreshed;
}
50
/**
 * Serialize the reply sent after a successful authentication.
 *
 * Produces `{"type":"auth","status":"ok"}`. `uid` and `refresh_token` are appended,
 * in that order, only when they are not `undefined` (a `null` uid is kept).
 */
export function authenticationSuccess(uid, refresh_token) {
    return JSON.stringify({
        type: 'auth',
        status: 'ok',
        ...(uid === undefined ? {} : { uid }),
        ...(refresh_token === undefined ? {} : { refresh_token }),
    });
}
@@ -0,0 +1,42 @@
1
+ /// <reference types="node" resolution-mode="require"/>
2
+ /// <reference types="node" resolution-mode="require"/>
3
+ /// <reference types="node" resolution-mode="require"/>
4
+ /// <reference types="node" resolution-mode="require"/>
5
+ import type { IncomingMessage, Server as httpServer } from 'http';
6
+ import type { ParsedUrlQuery } from 'querystring';
7
+ import type { RateLimiterAbstract } from 'rate-limiter-flexible';
8
+ import type internal from 'stream';
9
+ import WebSocket from 'ws';
10
+ import { AuthMode, WebSocketAuthMessage, WebSocketMessage } from '../messages.js';
11
+ import type { AuthenticationState, UpgradeContext, WebSocketClient } from '../types.js';
12
/**
 * Base class for WebSocket endpoints. It intercepts HTTP upgrades on `endpoint`,
 * authenticates connections according to `authentication.mode`, rate-limits
 * incoming messages and tracks access-token expiry per client.
 */
+ export default abstract class SocketController {
13
+ server: WebSocket.Server;
14
// Connected clients; a client is removed when its socket closes or errors.
+ clients: Set<WebSocketClient>;
15
// `timeout` is in milliseconds (the `<prefix>_AUTH_TIMEOUT` env value, in seconds, times 1000).
+ authentication: {
16
+ mode: AuthMode;
17
+ timeout: number;
18
+ };
19
+ endpoint: string;
20
// `<prefix>_CONN_LIMIT`, or Infinity when that variable is not set.
+ maxConnections: number;
21
// null when RATE_LIMITER_ENABLED is not true.
+ private rateLimiter;
22
// Interval handle of the periodic token check (checkClientTokens).
+ private authInterval;
23
+ constructor(httpServer: httpServer, configPrefix: string);
24
// Reads `<prefix>_PATH`, `_AUTH`, `_AUTH_TIMEOUT`, `_CONN_LIMIT`; throws InvalidConfigException for an invalid auth mode.
+ protected getEnvironmentConfig(configPrefix: string): {
25
+ endpoint: string;
26
+ authentication: {
27
+ mode: AuthMode;
28
+ timeout: number;
29
+ };
30
+ maxConnections: number;
31
+ };
32
+ protected getRateLimiter(): RateLimiterAbstract | null;
33
// Ignores other paths; answers 403 once maxConnections clients are connected.
+ protected handleUpgrade(request: IncomingMessage, socket: internal.Duplex, head: Buffer): Promise<void>;
34
// Requires an `access_token` query parameter that resolves to a user; answers 401 otherwise.
+ protected handleStrictUpgrade({ request, socket, head }: UpgradeContext, query: ParsedUrlQuery): Promise<void>;
35
// Upgrades first, then requires an `auth` message within authentication.timeout.
+ protected handleHandshakeUpgrade({ request, socket, head }: UpgradeContext): Promise<void>;
36
// Decorates the socket with accountability/expires_at/uid/auth_timer and registers it in `clients`.
+ createClient(ws: WebSocket, { accountability, expires_at }: AuthenticationState): WebSocketClient;
37
// Throws WebSocketException INVALID_PAYLOAD for non-JSON or schema-invalid input.
+ protected parseMessage(data: string): WebSocketMessage;
38
// On failure, resets accountability and closes the socket unless the mode is 'public'.
+ protected handleAuthRequest(client: WebSocketClient, message: WebSocketAuthMessage): Promise<void>;
39
// Arms an expiry timer only when the token expires within the 15-minute check interval.
+ setTokenExpireTimer(client: WebSocketClient): void;
40
// Starts the 15-minute interval that arms timers for clients that have none yet.
+ checkClientTokens(): void;
41
// Clears timers and terminates all open sockets.
+ terminate(): void;
42
+ }
@@ -0,0 +1,276 @@
1
+ import { parseJSON } from '@directus/utils';
2
+ import { parse } from 'url';
3
+ import { v4 as uuid } from 'uuid';
4
+ import WebSocket, { WebSocketServer } from 'ws';
5
+ import { fromZodError } from 'zod-validation-error';
6
+ import emitter from '../../emitter.js';
7
+ import env from '../../env.js';
8
+ import { InvalidConfigException, TokenExpiredException } from '../../exceptions/index.js';
9
+ import logger from '../../logger.js';
10
+ import { createRateLimiter } from '../../rate-limiter.js';
11
+ import { getAccountabilityForToken } from '../../utils/get-accountability-for-token.js';
12
+ import { toBoolean } from '../../utils/to-boolean.js';
13
+ import { authenticateConnection, authenticationSuccess } from '../authenticate.js';
14
+ import { WebSocketException, handleWebSocketException } from '../exceptions.js';
15
+ import { AuthMode, WebSocketAuthMessage, WebSocketMessage } from '../messages.js';
16
+ import { getExpiresAtForToken } from '../utils/get-expires-at-for-token.js';
17
+ import { getMessageType } from '../utils/message.js';
18
+ import { waitForAnyMessage, waitForMessageType } from '../utils/wait-for-message.js';
19
+ import { registerWebSocketEvents } from './hooks.js';
20
const TOKEN_CHECK_INTERVAL = 15 * 60 * 1000; // 15 minutes
/**
 * Cancel a client's pending token-expiry timer (if any) and reset the handle to null.
 */
function clearAuthTimer(client) {
    if (client.auth_timer)
        clearTimeout(client.auth_timer);
    client.auth_timer = null;
}
/**
 * Base class for the Directus WebSocket endpoints.
 *
 * Listens for HTTP `upgrade` requests on the configured path and authenticates them
 * according to `authentication.mode`:
 * - 'strict'    — the `access_token` query parameter must resolve to a user, otherwise 401;
 * - 'handshake' — upgrade first, then the first message must be a valid `auth` message
 *                 received within `authentication.timeout` ms;
 * - any other mode — upgrade without authentication.
 * Accepted sockets are announced with `server.emit('connection', ws, state)`.
 * NOTE(review): turning them into clients via createClient() presumably happens in
 * subclasses — confirm.
 */
export default class SocketController {
    server;
    /** Connected clients; a client is removed when its socket closes or errors. */
    clients;
    authentication;
    endpoint;
    maxConnections;
    /** Message rate limiter shared by all clients, or null when rate limiting is disabled. */
    rateLimiter;
    /** Handle of the periodic token check started by checkClientTokens(). */
    authInterval;
    /**
     * @param httpServer HTTP server whose `upgrade` events are intercepted.
     * @param configPrefix Environment prefix for `<prefix>_PATH`, `_AUTH`, `_AUTH_TIMEOUT`, `_CONN_LIMIT`.
     */
    constructor(httpServer, configPrefix) {
        this.server = new WebSocketServer({ noServer: true });
        this.clients = new Set();
        this.authInterval = null;
        const { endpoint, authentication, maxConnections } = this.getEnvironmentConfig(configPrefix);
        this.endpoint = endpoint;
        this.authentication = authentication;
        this.maxConnections = maxConnections;
        this.rateLimiter = this.getRateLimiter();
        httpServer.on('upgrade', this.handleUpgrade.bind(this));
        this.checkClientTokens();
        registerWebSocketEvents();
    }
    /**
     * Read this endpoint's settings from the environment.
     * `<prefix>_AUTH` is lower-cased and validated against AuthMode.
     * `<prefix>_AUTH_TIMEOUT` is in seconds and converted to ms.
     * `<prefix>_CONN_LIMIT` means unlimited when the variable is absent.
     *
     * @throws InvalidConfigException when `<prefix>_AUTH` is not a valid mode.
     */
    getEnvironmentConfig(configPrefix) {
        const endpoint = String(env[`${configPrefix}_PATH`]);
        const authMode = AuthMode.safeParse(String(env[`${configPrefix}_AUTH`]).toLowerCase());
        const authTimeout = Number(env[`${configPrefix}_AUTH_TIMEOUT`]) * 1000;
        const maxConnections = `${configPrefix}_CONN_LIMIT` in env ? Number(env[`${configPrefix}_CONN_LIMIT`]) : Number.POSITIVE_INFINITY;
        if (!authMode.success) {
            throw new InvalidConfigException(fromZodError(authMode.error, { prefix: `${configPrefix}_AUTH` }).message);
        }
        return {
            endpoint,
            maxConnections,
            authentication: {
                mode: authMode.data,
                timeout: authTimeout,
            },
        };
    }
    /**
     * Create the message rate limiter when RATE_LIMITER_ENABLED is true; otherwise null.
     */
    getRateLimiter() {
        if (toBoolean(env['RATE_LIMITER_ENABLED']) === true) {
            return createRateLimiter('RATE_LIMITER', {
                keyPrefix: 'websocket',
            });
        }
        return null;
    }
    /**
     * `upgrade` listener.
     * Ignores requests for other paths and refuses with 403 once `maxConnections`
     * clients are connected; otherwise dispatches on the auth mode.
     */
    async handleUpgrade(request, socket, head) {
        const { pathname, query } = parse(request.url, true);
        if (pathname !== this.endpoint)
            return;
        if (this.clients.size >= this.maxConnections) {
            logger.debug('WebSocket upgrade denied - max connections reached');
            socket.write('HTTP/1.1 403 Forbidden\r\n\r\n');
            socket.destroy();
            return;
        }
        const context = { request, socket, head };
        if (this.authentication.mode === 'strict') {
            await this.handleStrictUpgrade(context, query);
            return;
        }
        if (this.authentication.mode === 'handshake') {
            await this.handleHandshakeUpgrade(context);
            return;
        }
        this.server.handleUpgrade(request, socket, head, async (ws) => {
            const state = { accountability: null, expires_at: null };
            this.server.emit('connection', ws, state);
        });
    }
    /**
     * Strict mode: authenticate from the `access_token` query parameter before upgrading.
     * Refuses with 401 unless it resolves to an accountability that has a user.
     */
    async handleStrictUpgrade({ request, socket, head }, query) {
        let accountability, expires_at;
        try {
            const token = query['access_token'];
            accountability = await getAccountabilityForToken(token);
            expires_at = getExpiresAtForToken(token);
        }
        catch {
            accountability = null;
            expires_at = null;
        }
        if (!accountability || !accountability.user) {
            logger.debug('WebSocket upgrade denied - ' + JSON.stringify(accountability || 'invalid'));
            socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
            socket.destroy();
            return;
        }
        this.server.handleUpgrade(request, socket, head, async (ws) => {
            const state = { accountability, expires_at };
            this.server.emit('connection', ws, state);
        });
    }
    /**
     * Handshake mode: upgrade, then require the first message to be a valid `auth`
     * message within `authentication.timeout` ms. Otherwise report AUTH_FAILED and close.
     */
    async handleHandshakeUpgrade({ request, socket, head }) {
        this.server.handleUpgrade(request, socket, head, async (ws) => {
            try {
                const payload = await waitForAnyMessage(ws, this.authentication.timeout);
                if (getMessageType(payload) !== 'auth')
                    throw new Error();
                const state = await authenticateConnection(WebSocketAuthMessage.parse(payload));
                ws.send(authenticationSuccess(payload['uid'], state.refresh_token));
                this.server.emit('connection', ws, state);
            }
            catch {
                logger.debug('WebSocket authentication handshake failed');
                const error = new WebSocketException('auth', 'AUTH_FAILED', 'Authentication handshake failed.');
                handleWebSocketException(ws, error, 'auth');
                ws.close();
            }
        });
    }
    /**
     * Turn an accepted socket into a tracked client.
     * Attaches `accountability`, `expires_at`, a random `uid` and `auth_timer`, then wires
     * the message/close/error listeners, arms the expiry timer and adds it to `clients`.
     * Messages that are not `auth` messages are re-emitted on the socket as 'parsed-message'.
     */
    createClient(ws, { accountability, expires_at }) {
        const client = ws;
        client.accountability = accountability;
        client.expires_at = expires_at;
        client.uid = uuid();
        client.auth_timer = null;
        ws.on('message', async (data) => {
            if (this.rateLimiter !== null) {
                try {
                    await this.rateLimiter.consume(client.uid);
                }
                catch (limit) {
                    const timeout = limit?.msBeforeNext ?? this.rateLimiter.msDuration;
                    const error = new WebSocketException('server', 'REQUESTS_EXCEEDED', `Too many messages, retry after ${timeout}ms.`);
                    handleWebSocketException(client, error, 'server');
                    logger.debug(`WebSocket#${client.uid} is rate limited`);
                    return;
                }
            }
            let message;
            try {
                message = this.parseMessage(data.toString());
            }
            catch (err) {
                handleWebSocketException(client, err, 'server');
                return;
            }
            if (getMessageType(message) === 'auth') {
                try {
                    await this.handleAuthRequest(client, WebSocketAuthMessage.parse(message));
                }
                catch {
                    // ignore errors
                }
                return;
            }
            // this log cannot be higher in the function or it will leak credentials
            logger.trace(`WebSocket#${client.uid} - ${JSON.stringify(message)}`);
            ws.emit('parsed-message', message);
        });
        // Shared teardown for both 'error' and 'close': drop the expiry timer and forget the client.
        const disconnect = (reason) => {
            logger.debug(`WebSocket#${client.uid} connection ${reason}`);
            clearAuthTimer(client);
            this.clients.delete(client);
        };
        ws.on('error', () => disconnect('errored'));
        ws.on('close', () => disconnect('closed'));
        logger.debug(`WebSocket#${client.uid} connected`);
        if (accountability) {
            logger.trace(`WebSocket#${client.uid} authenticated as ${JSON.stringify(accountability)}`);
        }
        this.setTokenExpireTimer(client);
        this.clients.add(client);
        return client;
    }
    /**
     * Parse and validate a raw message.
     *
     * @throws WebSocketException INVALID_PAYLOAD when the data is not valid JSON or fails the schema.
     */
    parseMessage(data) {
        let message;
        try {
            message = WebSocketMessage.parse(parseJSON(data));
        }
        catch (err) {
            throw new WebSocketException('server', 'INVALID_PAYLOAD', 'Unable to parse the incoming message.');
        }
        return message;
    }
    /**
     * (Re-)authenticate an existing client from an `auth` message.
     * On success, its accountability and expiry are replaced and the expiry timer is re-armed.
     * On failure, the client becomes unauthenticated; the socket is closed unless the
     * mode is 'public'.
     */
    async handleAuthRequest(client, message) {
        try {
            const { accountability, expires_at, refresh_token } = await authenticateConnection(message);
            client.accountability = accountability;
            client.expires_at = expires_at;
            this.setTokenExpireTimer(client);
            emitter.emitAction('websocket.auth.success', { client });
            client.send(authenticationSuccess(message.uid, refresh_token));
            logger.trace(`WebSocket#${client.uid} authenticated as ${JSON.stringify(client.accountability)}`);
        }
        catch (error) {
            logger.trace(`WebSocket#${client.uid} failed authentication`);
            emitter.emitAction('websocket.auth.failure', { client });
            client.accountability = null;
            client.expires_at = null;
            // The client is no longer authenticated, so an expiry timer armed for its previous
            // token must not fire later; otherwise it would send a spurious TOKEN_EXPIRED.
            clearAuthTimer(client);
            const _error = error instanceof WebSocketException
                ? error
                : new WebSocketException('auth', 'AUTH_FAILED', 'Authentication failed.', message.uid);
            handleWebSocketException(client, _error, 'auth');
            if (this.authentication.mode !== 'public') {
                client.close();
            }
        }
    }
    /**
     * Schedule expiry handling for the client's access token.
     * A timer is only armed when the token expires within TOKEN_CHECK_INTERVAL; tokens that
     * expire later are picked up by the periodic checkClientTokens() sweep. When the timer
     * fires, the client is de-authenticated, told its token expired and given
     * `authentication.timeout` ms to re-authenticate (closed on timeout unless mode is 'public').
     */
    setTokenExpireTimer(client) {
        // Replace any previously scheduled check.
        clearAuthTimer(client);
        if (!client.expires_at)
            return;
        const expiresIn = client.expires_at * 1000 - Date.now();
        if (expiresIn > TOKEN_CHECK_INTERVAL)
            return;
        client.auth_timer = setTimeout(() => {
            // The timer has fired; drop its now-stale handle.
            client.auth_timer = null;
            client.accountability = null;
            client.expires_at = null;
            handleWebSocketException(client, new TokenExpiredException(), 'auth');
            waitForMessageType(client, 'auth', this.authentication.timeout).catch((msg) => {
                const error = new WebSocketException('auth', 'AUTH_TIMEOUT', 'Authentication timed out.', msg?.uid);
                handleWebSocketException(client, error, 'auth');
                if (this.authentication.mode !== 'public') {
                    client.close();
                }
            });
        }, expiresIn);
    }
    /**
     * Start the periodic sweep that arms expiry timers for authenticated clients
     * that do not have one yet (their token was too far from expiry when last checked).
     */
    checkClientTokens() {
        this.authInterval = setInterval(() => {
            if (this.clients.size === 0)
                return;
            // check the clients and set shorter timeouts if needed
            for (const client of this.clients) {
                if (client.expires_at === null || client.auth_timer !== null)
                    continue;
                this.setTokenExpireTimer(client);
            }
        }, TOKEN_CHECK_INTERVAL);
    }
    /**
     * Stop the periodic token check, cancel all client expiry timers and terminate every open socket.
     */
    terminate() {
        if (this.authInterval) {
            clearInterval(this.authInterval);
            this.authInterval = null;
        }
        this.clients.forEach((client) => clearAuthTimer(client));
        this.server.clients.forEach((ws) => {
            ws.terminate();
        });
    }
}
@@ -0,0 +1,12 @@
1
+ /// <reference types="node" resolution-mode="require"/>
2
+ import type { Server } from 'graphql-ws';
3
+ import type { Server as httpServer } from 'http';
4
+ import type { GraphQLSocket, UpgradeContext, WebSocketClient } from '../types.js';
5
+ import SocketController from './base.js';
6
/**
 * WebSocket controller for GraphQL subscriptions, built on a graphql-ws `Server`.
 * It extends SocketController and overrides its token-expiry and handshake handling.
 * NOTE(review): the config prefix and event wiring live in the implementation,
 * which is not shown here — confirm there.
 */
+ export declare class GraphQLSubscriptionController extends SocketController {
7
// graphql-ws server instance handling the subscription protocol.
+ gql: Server<GraphQLSocket>;
8
+ constructor(httpServer: httpServer);
9
+ private bindEvents;
10
// Overrides SocketController.setTokenExpireTimer.
+ setTokenExpireTimer(client: WebSocketClient): void;
11
// Overrides SocketController.handleHandshakeUpgrade.
+ protected handleHandshakeUpgrade({ request, socket, head }: UpgradeContext): Promise<void>;
12
+ }