@risingwave/wavelet-server 0.2.1 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,7 +1,16 @@
1
1
  {
2
2
  "name": "@risingwave/wavelet-server",
3
- "version": "0.2.1",
3
+ "version": "0.2.4",
4
4
  "description": "Wavelet server - WebSocket fanout layer for RisingWave",
5
+ "homepage": "https://github.com/risingwavelabs/wavelet",
6
+ "repository": {
7
+ "type": "git",
8
+ "url": "https://github.com/risingwavelabs/wavelet.git",
9
+ "directory": "packages/server"
10
+ },
11
+ "bugs": {
12
+ "url": "https://github.com/risingwavelabs/wavelet/issues"
13
+ },
5
14
  "main": "./dist/index.js",
6
15
  "types": "./dist/index.d.ts",
7
16
  "scripts": {
@@ -13,7 +22,7 @@
13
22
  "pg": "^8.13.0",
14
23
  "ws": "^8.18.0",
15
24
  "jose": "^6.0.0",
16
- "@risingwave/wavelet": "0.2.1"
25
+ "@risingwave/wavelet": "0.2.4"
17
26
  },
18
27
  "devDependencies": {
19
28
  "@types/pg": "^8.11.0",
@@ -188,10 +188,12 @@ describe.runIf(process.env.WAVELET_INTEGRATION === '1')('Integration: Full Serve
188
188
  it('pushes diffs via WebSocket', async () => {
189
189
  const diff = await new Promise<any>((resolve, reject) => {
190
190
  const ws = new WebSocket(`ws://localhost:${port}/subscribe/${VIEW_NAME}`)
191
+ let sawSnapshot = false
191
192
 
192
193
  ws.on('message', (data: Buffer) => {
193
194
  const msg = JSON.parse(data.toString())
194
- if (msg.type === 'connected') {
195
+ if (msg.type === 'snapshot') {
196
+ sawSnapshot = true
195
197
  // Write event to trigger diff
196
198
  fetch(`http://localhost:${port}/v1/events/${STREAM_NAME}`, {
197
199
  method: 'POST',
@@ -200,6 +202,7 @@ describe.runIf(process.env.WAVELET_INTEGRATION === '1')('Integration: Full Serve
200
202
  })
201
203
  }
202
204
  if (msg.type === 'diff') {
205
+ expect(sawSnapshot).toBe(true)
203
206
  ws.close()
204
207
  resolve(msg)
205
208
  }
@@ -214,4 +217,33 @@ describe.runIf(process.env.WAVELET_INTEGRATION === '1')('Integration: Full Serve
214
217
  // Should have at least one insert or update
215
218
  expect(diff.inserted.length + diff.updated.length).toBeGreaterThan(0)
216
219
  }, 20000)
220
+
221
it('pushes a snapshot immediately on connect', async () => {
  // Seed a row that the snapshot is expected to contain.
  await fetch(`http://localhost:${port}/v1/events/${STREAM_NAME}`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ user_id: 'snapshot_user', value: 11 }),
  })

  // Presumably gives the event time to propagate into the view before connecting.
  await new Promise(r => setTimeout(r, 2000))

  const snapshot = await new Promise<any>((resolve, reject) => {
    const ws = new WebSocket(`ws://localhost:${port}/subscribe/${VIEW_NAME}`)

    // Watchdog for a snapshot that never arrives. It is cleared as soon as the
    // promise settles. Otherwise it would still fire later, close the socket,
    // and keep the event loop alive for the full 15 seconds.
    const timer = setTimeout(() => {
      ws.close()
      reject(new Error('WebSocket snapshot timeout'))
    }, 15000)

    ws.on('message', (data: Buffer) => {
      const msg = JSON.parse(data.toString())
      if (msg.type === 'snapshot') {
        clearTimeout(timer)
        ws.close()
        resolve(msg)
      }
    })

    ws.on('error', (err: Error) => {
      clearTimeout(timer)
      reject(err)
    })
  })

  expect(snapshot.type).toBe('snapshot')
  expect(Array.isArray(snapshot.rows)).toBe(true)
  expect(snapshot.rows.some((row: any) => row.user_id === 'snapshot_user')).toBe(true)
}, 20000)
217
249
  })
@@ -0,0 +1,143 @@
1
+ import { EventEmitter } from 'node:events'
2
+ import { describe, it, expect, vi } from 'vitest'
3
+ import { WebSocket } from 'ws'
4
+ import { WebSocketFanout } from '../ws-fanout.js'
5
+ import type { BootstrapResult, ViewDiff } from '../cursor-manager.js'
6
+
7
/**
 * Minimal stand-in for a `ws` WebSocket used by the fanout tests.
 * Outgoing frames are recorded in `sent` instead of being written anywhere,
 * and close() flips the ready state before emitting 'close'.
 */
class MockSocket extends EventEmitter {
  /** Every payload handed to send(), in call order. */
  sent: string[]
  /** Starts as WebSocket.OPEN so the fanout treats the socket as writable. */
  readyState: number

  constructor() {
    super()
    this.sent = []
    this.readyState = WebSocket.OPEN
  }

  /** Records the outgoing frame. */
  send(message: string): void {
    this.sent.push(message)
  }

  /** Keep-alive pings are accepted and ignored. */
  ping(): void {
    // intentionally a no-op
  }

  /** Marks the socket closed, then notifies 'close' listeners. */
  close(): void {
    this.readyState = WebSocket.CLOSED
    this.emit('close')
  }
}
22
+
23
/**
 * Creates a promise together with its resolve/reject functions, so a test
 * can decide exactly when (and how) the promise settles.
 */
function deferred<T>() {
  let settleOk!: (value: T) => void
  let settleErr!: (reason?: unknown) => void

  const promise = new Promise<T>((onResolve, onReject) => {
    settleOk = onResolve
    settleErr = onReject
  })

  return { promise, resolve: settleOk, reject: settleErr }
}
32
+
33
/**
 * Unit tests for WebSocketFanout's bootstrap/handoff path. The private
 * handleConnection method is driven with mocked CursorManager and JWT objects.
 */
describe('WebSocketFanout', () => {
  it('drops queued shared diffs that are already covered by bootstrap', async () => {
    // Keep bootstrap pending so the broadcasts below are issued before the
    // bootstrap result is available.
    const bootstrapDeferred = deferred<BootstrapResult>()
    const cursorManager = {
      bootstrap: vi.fn().mockReturnValue(bootstrapDeferred.promise),
    } as any
    const jwt = {
      isConfigured: () => false,
    } as any

    const fanout = new WebSocketFanout(cursorManager, jwt, {
      leaderboard: {} as any,
    })

    const ws = new MockSocket()
    const req = {
      url: '/subscribe/leaderboard',
      headers: { host: 'localhost' },
    } as any

    const connectionPromise = (fanout as any).handleConnection(ws, req)

    // Cursor 150 is not newer than the bootstrap's lastCursor (200) and is
    // expected to be dropped. Cursor 300 is newer and is expected to follow
    // the bootstrap diffs.
    fanout.broadcast('leaderboard', {
      cursor: '150',
      inserted: [{ player_id: 'alice', score: 10 }],
      updated: [],
      deleted: [],
    })
    fanout.broadcast('leaderboard', {
      cursor: '300',
      inserted: [{ player_id: 'bob', score: 20 }],
      updated: [],
      deleted: [],
    })

    bootstrapDeferred.resolve({
      snapshotRows: [{ player_id: 'alice', score: 10 }],
      diffs: [{
        cursor: '200',
        inserted: [],
        updated: [{ player_id: 'alice', score: 15 }],
        deleted: [],
      }],
      lastCursor: '200',
    })

    await connectionPromise

    const messages = ws.sent.map((message) => JSON.parse(message))
    expect(messages).toEqual([
      { type: 'connected', query: 'leaderboard' },
      {
        type: 'snapshot',
        query: 'leaderboard',
        rows: [{ player_id: 'alice', score: 10 }],
      },
      {
        type: 'diff',
        query: 'leaderboard',
        cursor: '200',
        inserted: [],
        updated: [{ player_id: 'alice', score: 15 }],
        deleted: [],
      },
      {
        type: 'diff',
        query: 'leaderboard',
        cursor: '300',
        inserted: [{ player_id: 'bob', score: 20 }],
        updated: [],
        deleted: [],
      },
    ])
  })

  it('filters snapshot rows with the same claim rule as diffs', async () => {
    const cursorManager = {
      bootstrap: vi.fn().mockResolvedValue({
        snapshotRows: [
          { user_id: 'u1', total: 10 },
          { user_id: 'u2', total: 20 },
        ],
        diffs: [] as ViewDiff[],
        lastCursor: null,
      }),
    } as any
    const jwt = {
      isConfigured: () => true,
      verify: vi.fn().mockResolvedValue({ user_id: 'u1' }),
    } as any

    const fanout = new WebSocketFanout(cursorManager, jwt, {
      totals: { filterBy: 'user_id' } as any,
    })

    const ws = new MockSocket()
    const req = {
      url: '/subscribe/totals?token=test-token',
      headers: { host: 'localhost' },
    } as any

    await (fanout as any).handleConnection(ws, req)

    // Assert the whole frame sequence. If the snapshot were missing, reading
    // ws.sent[1] directly would fail with an opaque JSON.parse error instead.
    const messages = ws.sent.map((message) => JSON.parse(message))
    expect(messages).toEqual([
      { type: 'connected', query: 'totals' },
      {
        type: 'snapshot',
        query: 'totals',
        rows: [{ user_id: 'u1', total: 10 }],
      },
    ])
  })
})
@@ -16,6 +16,12 @@ export interface ViewDiff {
16
16
  deleted: Record<string, unknown>[]
17
17
  }
18
18
 
19
/**
 * Value returned by CursorManager.bootstrap(): the rows read ahead of the
 * first row carrying a cursor, plus the incremental changes read after them.
 */
export interface BootstrapResult {
  /** Rows read before the first row with a cursor; `op` and `rw_timestamp` are removed. */
  snapshotRows: Record<string, unknown>[]
  /** Incremental changes, one entry per consecutive run of rows sharing a cursor. */
  diffs: ViewDiff[]
  /** Cursor of the last entry in `diffs`, or null when `diffs` is empty. */
  lastCursor: string | null
}
24
+
19
25
  type DiffCallback = (queryName: string, diff: ViewDiff) => void
20
26
 
21
27
  /**
@@ -122,9 +128,12 @@ export class CursorManager {
122
128
  }
123
129
  }
124
130
 
125
- const diff = this.parseDiffs(allRows)
131
+ const diffs = this.parseDiffBatches(allRows)
126
132
 
127
- if (diff.inserted.length > 0 || diff.updated.length > 0 || diff.deleted.length > 0) {
133
+ for (const diff of diffs) {
134
+ if (diff.inserted.length === 0 && diff.updated.length === 0 && diff.deleted.length === 0) {
135
+ continue
136
+ }
128
137
  callback(queryName, diff)
129
138
  }
130
139
  } catch (err: any) {
@@ -173,6 +182,88 @@ export class CursorManager {
173
182
  return diff
174
183
  }
175
184
 
185
/**
 * Splits subscription-cursor rows into one ViewDiff per consecutive run of
 * rows that share the same normalized rw_timestamp cursor. Rows whose cursor
 * normalizes to null are skipped. Each run is converted with parseDiffs, in
 * input order.
 */
parseDiffBatches(rows: any[]): ViewDiff[] {
  const batches: ViewDiff[] = []
  let pending: any[] = []
  let activeCursor: string | null = null

  // Emit the accumulated run (if any) as one diff and start a new run.
  const flush = (): void => {
    if (pending.length === 0) return
    batches.push(this.parseDiffs(pending))
    pending = []
  }

  for (const row of rows) {
    const rowCursor = this.normalizeCursor(row.rw_timestamp)
    if (rowCursor === null) continue

    if (rowCursor !== activeCursor) flush()
    activeCursor = rowCursor
    pending.push(row)
  }

  flush()
  return batches
}
209
+
210
/**
 * Reads an initial snapshot, plus any incremental changes already available,
 * for `queryName` through a short-lived dedicated connection.
 *
 * A FULL subscription cursor is declared on the query's subscription and read
 * with `FETCH 1000`. Rows before the first row whose rw_timestamp normalizes to
 * a cursor become snapshot rows, with `op`/`rw_timestamp` removed. That row and
 * every row fetched afterwards are grouped into diffs by parseDiffBatches.
 * Reading stops once a FETCH returns no rows. On exit the cursor is closed and
 * the connection ended; errors from either cleanup step are ignored.
 *
 * @throws Error when no subscription has been recorded for `queryName`.
 */
async bootstrap(queryName: string): Promise<BootstrapResult> {
  const subName = this.subscriptions.get(queryName)
  if (!subName) {
    throw new Error(
      `Subscription for query '${queryName}' is not initialized. Start the server before accepting WebSocket clients.`
    )
  }

  const conn = new Client({ connectionString: this.connectionString })
  await conn.connect()

  // Randomized name for this bootstrap's cursor.
  const cursorName = `wavelet_boot_${Date.now()}_${Math.random().toString(36).slice(2, 8)}`
  const fetchBatch = async (): Promise<any[]> => {
    const result = await conn.query(`FETCH 1000 FROM ${cursorName}`)
    return result.rows
  }

  try {
    await conn.query(`DECLARE ${cursorName} SUBSCRIPTION CURSOR FOR ${subName} FULL`)

    const snapshotRows: Record<string, unknown>[] = []
    const incrementalRows: any[] = []

    // Phase 1: collect snapshot rows until a row carrying a cursor shows up.
    let snapshotDone = false
    while (!snapshotDone) {
      const batch = await fetchBatch()
      if (batch.length === 0) break

      const boundary = batch.findIndex((row) => this.normalizeCursor(row.rw_timestamp) !== null)
      const snapshotPart = boundary === -1 ? batch : batch.slice(0, boundary)
      snapshotPart.forEach((row) => {
        snapshotRows.push(this.stripSubscriptionMetadata(row))
      })

      if (boundary !== -1) {
        incrementalRows.push(...batch.slice(boundary))
        snapshotDone = true
      }
    }

    // Phase 2: drain the remaining incremental rows until a FETCH comes back empty.
    for (let batch = await fetchBatch(); batch.length > 0; batch = await fetchBatch()) {
      incrementalRows.push(...batch)
    }

    const diffs = this.parseDiffBatches(incrementalRows)
    const lastCursor = diffs.length === 0 ? null : diffs[diffs.length - 1].cursor
    return { snapshotRows, diffs, lastCursor }
  } finally {
    // Errors from closing the cursor or ending the connection are deliberately ignored.
    try {
      await conn.query(`CLOSE ${cursorName}`)
    } catch {}
    try {
      await conn.end()
    } catch {}
  }
}
266
+
176
267
  async query(sql: string): Promise<any[]> {
177
268
  if (!this.client) throw new Error('Not connected')
178
269
  const result = await this.client.query(sql)
@@ -206,4 +297,15 @@ export class CursorManager {
206
297
  } catch {}
207
298
  this.client = null
208
299
  }
300
+
301
/** Returns a shallow copy of `row` without the `op` and `rw_timestamp` columns. */
private stripSubscriptionMetadata(row: Record<string, unknown>): Record<string, unknown> {
  const data: Record<string, unknown> = { ...row }
  delete data['op']
  delete data['rw_timestamp']
  return data
}
305
+
306
/**
 * Converts a raw rw_timestamp value into a trimmed cursor string.
 * Returns null for null/undefined, or when the stringified value is blank.
 */
private normalizeCursor(value: unknown): string | null {
  if (value == null) return null
  const text = String(value).trim()
  return text.length > 0 ? text : null
}
209
311
  }
package/src/ws-fanout.ts CHANGED
@@ -8,6 +8,8 @@ interface Subscriber {
8
8
  ws: WebSocket
9
9
  queryName: string
10
10
  claims: JwtClaims | null
11
+ ready: boolean
12
+ pendingDiffs: ViewDiff[]
11
13
  }
12
14
 
13
15
  export class WebSocketFanout {
@@ -80,7 +82,13 @@ export class WebSocketFanout {
80
82
  claims = await this.jwt.verify(token)
81
83
  }
82
84
 
83
- const subscriber: Subscriber = { ws, queryName, claims }
85
+ const subscriber: Subscriber = {
86
+ ws,
87
+ queryName,
88
+ claims,
89
+ ready: false,
90
+ pendingDiffs: [],
91
+ }
84
92
 
85
93
  if (!this.subscribers.has(queryName)) {
86
94
  this.subscribers.set(queryName, new Set())
@@ -101,40 +109,84 @@ export class WebSocketFanout {
101
109
  ws.on('pong', () => { /* connection alive */ })
102
110
 
103
111
  ws.send(JSON.stringify({ type: 'connected', query: queryName }))
112
+
113
+ const bootstrap = await this.cursorManager.bootstrap(queryName)
114
+ const snapshotRows = this.filterSnapshotRows(queryName, bootstrap.snapshotRows, claims)
115
+ ws.send(JSON.stringify({
116
+ type: 'snapshot',
117
+ query: queryName,
118
+ rows: snapshotRows,
119
+ }))
120
+
121
+ for (const diff of bootstrap.diffs) {
122
+ const filteredDiff = this.filterDiffForSubscriber(queryName, diff, claims)
123
+ if (this.isEmptyDiff(filteredDiff)) continue
124
+ if (ws.readyState !== WebSocket.OPEN) break
125
+ ws.send(this.serializeDiffMessage(queryName, filteredDiff))
126
+ }
127
+
128
+ const handoffCursor = bootstrap.lastCursor
129
+ subscriber.ready = true
130
+ for (const diff of subscriber.pendingDiffs) {
131
+ if (ws.readyState !== WebSocket.OPEN) break
132
+ if (handoffCursor && this.compareCursor(diff.cursor, handoffCursor) <= 0) {
133
+ continue
134
+ }
135
+ ws.send(this.serializeDiffMessage(queryName, diff))
136
+ }
137
+ subscriber.pendingDiffs = []
104
138
  }
105
139
 
106
140
  broadcast(queryName: string, diff: ViewDiff): void {
107
141
  const subs = this.subscribers.get(queryName)
108
142
  if (!subs || subs.size === 0) return
109
143
 
110
- const queryDef = this.queries[queryName]
111
- const filterBy = this.getFilterBy(queryDef)
112
-
113
144
  for (const sub of subs) {
114
145
  if (sub.ws.readyState !== WebSocket.OPEN) continue
115
146
 
116
- let filteredDiff = diff
117
- if (filterBy && sub.claims) {
118
- filteredDiff = this.filterDiff(diff, filterBy, sub.claims)
119
- }
147
+ const filteredDiff = this.filterDiffForSubscriber(queryName, diff, sub.claims)
148
+ if (this.isEmptyDiff(filteredDiff)) continue
120
149
 
121
- if (
122
- filteredDiff.inserted.length === 0 &&
123
- filteredDiff.updated.length === 0 &&
124
- filteredDiff.deleted.length === 0
125
- ) {
150
+ if (!sub.ready) {
151
+ sub.pendingDiffs.push(filteredDiff)
126
152
  continue
127
153
  }
128
154
 
129
- sub.ws.send(JSON.stringify({
130
- type: 'diff',
131
- query: queryName,
132
- cursor: filteredDiff.cursor,
133
- inserted: filteredDiff.inserted,
134
- updated: filteredDiff.updated,
135
- deleted: filteredDiff.deleted,
136
- }))
155
+ sub.ws.send(this.serializeDiffMessage(queryName, filteredDiff))
156
+ }
157
+ }
158
+
159
/**
 * Applies the query's `filterBy` claim rule to snapshot rows.
 *
 * When the query has a filter column and the subscriber has claims, only rows
 * whose column value stringifies to the same text as the matching claim are
 * kept. If that claim is absent, no rows are kept. Otherwise `rows` is
 * returned unchanged.
 */
private filterSnapshotRows(
  queryName: string,
  rows: Record<string, unknown>[],
  claims: JwtClaims | null
): Record<string, unknown>[] {
  const filterBy = this.getFilterBy(this.queries[queryName])
  if (!filterBy || !claims) {
    return rows
  }

  const claimValue = claims[filterBy]
  if (claimValue === undefined) {
    return []
  }

  return rows.filter((row) => String(row[filterBy]) === String(claimValue))
}
176
+
177
/**
 * Narrows `diff` with filterDiff when the query defines a filter column and
 * the subscriber has claims. Otherwise returns `diff` itself.
 */
private filterDiffForSubscriber(
  queryName: string,
  diff: ViewDiff,
  claims: JwtClaims | null
): ViewDiff {
  const filterBy = this.getFilterBy(this.queries[queryName])
  return filterBy && claims ? this.filterDiff(diff, filterBy, claims) : diff
}
139
191
 
140
192
  private filterDiff(diff: ViewDiff, filterBy: string, claims: JwtClaims): ViewDiff {
@@ -160,6 +212,28 @@ export class WebSocketFanout {
160
212
  return (queryDef as QueryDef).filterBy
161
213
  }
162
214
 
215
/** True when the diff has no inserted, updated, or deleted rows. */
private isEmptyDiff(diff: ViewDiff): boolean {
  return ![diff.inserted, diff.updated, diff.deleted].some((changes) => changes.length > 0)
}
218
+
219
/** Encodes `diff` as the JSON `diff` frame sent to WebSocket clients. */
private serializeDiffMessage(queryName: string, diff: ViewDiff): string {
  const { cursor, inserted, updated, deleted } = diff
  const frame = { type: 'diff', query: queryName, cursor, inserted, updated, deleted }
  return JSON.stringify(frame)
}
229
+
230
/**
 * Orders two subscription cursors.
 *
 * Cursors that BigInt() accepts are compared numerically, so values beyond
 * Number.MAX_SAFE_INTEGER keep their exact order. Previously, any other value
 * made BigInt() throw. Called from the async pending-diff handoff, that throw
 * stopped delivery of the remaining queued diffs and rejected the connection
 * promise. Such values now fall back to plain string ordering instead.
 *
 * @returns -1 if left sorts before right, 0 if equal, 1 otherwise.
 */
private compareCursor(left: string, right: string): number {
  // BigInt() throws SyntaxError/TypeError/RangeError on unparsable input; map that to null.
  const toBigInt = (value: string): bigint | null => {
    try {
      return BigInt(value)
    } catch {
      return null
    }
  }

  const leftValue = toBigInt(left)
  const rightValue = toBigInt(right)
  if (leftValue !== null && rightValue !== null) {
    if (leftValue === rightValue) return 0
    return leftValue < rightValue ? -1 : 1
  }

  // At least one cursor is not an integer: order textually rather than throw.
  const leftText = String(left)
  const rightText = String(right)
  if (leftText === rightText) return 0
  return leftText < rightText ? -1 : 1
}
236
+
163
237
  closeAll(): void {
164
238
  for (const [, subs] of this.subscribers) {
165
239
  for (const sub of subs) {