@atproto/lex-server 0.0.11 → 0.0.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. package/CHANGELOG.md +27 -0
  2. package/README.md +38 -21
  3. package/dist/errors.d.ts +28 -58
  4. package/dist/errors.d.ts.map +1 -1
  5. package/dist/errors.js +72 -72
  6. package/dist/errors.js.map +1 -1
  7. package/dist/index.d.ts +1 -2
  8. package/dist/index.d.ts.map +1 -1
  9. package/dist/index.js +1 -4
  10. package/dist/index.js.map +1 -1
  11. package/dist/{lex-server.d.ts → lex-router.d.ts} +55 -21
  12. package/dist/lex-router.d.ts.map +1 -0
  13. package/dist/{lex-server.js → lex-router.js} +169 -73
  14. package/dist/lex-router.js.map +1 -0
  15. package/dist/lib/drain-websocket.d.ts +7 -0
  16. package/dist/lib/drain-websocket.d.ts.map +1 -1
  17. package/dist/lib/drain-websocket.js +11 -0
  18. package/dist/lib/drain-websocket.js.map +1 -1
  19. package/dist/lib/www-authenticate.d.ts +4 -3
  20. package/dist/lib/www-authenticate.d.ts.map +1 -1
  21. package/dist/lib/www-authenticate.js +29 -16
  22. package/dist/lib/www-authenticate.js.map +1 -1
  23. package/dist/nodejs.d.ts +1 -1
  24. package/dist/nodejs.d.ts.map +1 -1
  25. package/dist/nodejs.js +1 -1
  26. package/dist/nodejs.js.map +1 -1
  27. package/dist/service-auth.d.ts +1 -1
  28. package/dist/service-auth.d.ts.map +1 -1
  29. package/dist/service-auth.js.map +1 -1
  30. package/package.json +9 -8
  31. package/src/errors.test.ts +262 -0
  32. package/src/errors.ts +103 -78
  33. package/src/index.ts +1 -7
  34. package/src/{lex-server.test.ts → lex-router.test.ts} +591 -24
  35. package/src/{lex-server.ts → lex-router.ts} +275 -119
  36. package/src/lib/drain-websocket.ts +11 -0
  37. package/src/lib/www-authenticate.test.ts +134 -0
  38. package/src/lib/www-authenticate.ts +36 -17
  39. package/src/nodejs.ts +2 -2
  40. package/src/service-auth.ts +1 -1
  41. package/dist/lex-server.d.ts.map +0 -1
  42. package/dist/lex-server.js.map +0 -1
@@ -4,13 +4,18 @@ import { describe, expect, it, vi } from 'vitest'
4
4
  import { WebSocket } from 'ws'
5
5
  import { decodeAll } from '@atproto/lex-cbor'
6
6
  import { buildAgent, xrpc } from '@atproto/lex-client'
7
- import { LexError, parseCid } from '@atproto/lex-data'
7
+ import { parseCid } from '@atproto/lex-data'
8
8
  import { l } from '@atproto/lex-schema'
9
+ import { LexError, LexServerAuthError, LexServerError } from './errors.js'
9
10
  import {
11
+ ConnectionInfo,
12
+ HandlerErrorHook,
13
+ HealthCheckHandler,
10
14
  LexRouter,
11
15
  LexRouterAuth,
12
16
  LexRouterMethodHandler,
13
- } from './lex-server.js'
17
+ SocketErrorHook,
18
+ } from './lex-router.js'
14
19
  import { serve, upgradeWebSocket } from './nodejs.js'
15
20
 
16
21
  // ============================================================================
@@ -83,7 +88,7 @@ const handlers: {
83
88
  // Basic LexRouter Tests
84
89
  // ============================================================================
85
90
 
86
- describe('LexRouter', () => {
91
+ describe(LexRouter, () => {
87
92
  it('returns MethodNotImplemented when the route is not found', async () => {
88
93
  const router = new LexRouter()
89
94
  const request = new Request(`https://example.com/xrpc/foo.bar.baz`)
@@ -291,14 +296,24 @@ describe('Authentication', () => {
291
296
  return async ({ request }) => {
292
297
  const header = request.headers.get('authorization') ?? ''
293
298
  if (!header.startsWith('Basic ')) {
294
- throw new LexError('AuthenticationRequired', 'Authentication required')
299
+ throw new LexServerAuthError(
300
+ 'AuthenticationRequired',
301
+ 'Authentication required',
302
+ )
295
303
  }
296
304
  const original = header.slice(6)
297
- const [username, password] = Buffer.from(original, 'base64')
298
- .toString()
299
- .split(':')
305
+ const decoded = Buffer.from(original, 'base64').toString()
306
+ // @NOTE not using .split(':') to allow colons in password
307
+ const colonIndex = decoded.indexOf(':')
308
+ const [username, password] =
309
+ colonIndex === -1
310
+ ? [decoded, '']
311
+ : [decoded.slice(0, colonIndex), decoded.slice(colonIndex + 1)]
300
312
  if (username !== allowed.username || password !== allowed.password) {
301
- throw new LexError('AuthenticationRequired', 'Invalid credentials')
313
+ throw new LexServerAuthError(
314
+ 'AuthenticationRequired',
315
+ 'Invalid credentials',
316
+ )
302
317
  }
303
318
  return { username, original }
304
319
  }
@@ -337,7 +352,7 @@ describe('Authentication', () => {
337
352
  )
338
353
  const response = await router.fetch(request)
339
354
 
340
- expect(response.status).toBe(400)
355
+ expect(response.status).toBe(401)
341
356
  const data = await response.json()
342
357
  expect(data.error).toBe('AuthenticationRequired')
343
358
  })
@@ -407,7 +422,7 @@ describe('Authentication', () => {
407
422
  )
408
423
  const response = await router.fetch(request)
409
424
 
410
- expect(response.status).toBe(400)
425
+ expect(response.status).toBe(401)
411
426
  const data = await response.json()
412
427
  expect(data.error).toBe('AuthenticationRequired')
413
428
  })
@@ -451,7 +466,10 @@ describe('Error Handling', () => {
451
466
  params,
452
467
  }) => {
453
468
  if (params.which === 'foo') {
454
- throw new LexError('Foo', 'It was this one!')
469
+ throw new LexServerError(400, {
470
+ error: 'Foo',
471
+ message: 'It was this one!',
472
+ })
455
473
  }
456
474
  return {}
457
475
  }
@@ -495,7 +513,7 @@ describe('Error Handling', () => {
495
513
  expect(data.message).toBe('It was that one!')
496
514
  })
497
515
 
498
- it('handles falsy values thrown as InternalError', async () => {
516
+ it('handles falsy values thrown as InternalServerError', async () => {
499
517
  const handler: LexRouterMethodHandler<
500
518
  typeof io.example.throwFalsyValue
501
519
  > = async () => {
@@ -511,7 +529,7 @@ describe('Error Handling', () => {
511
529
 
512
530
  expect(response.status).toBe(500)
513
531
  const data = await response.json()
514
- expect(data.error).toBe('InternalError')
532
+ expect(data.error).toBe('InternalServerError')
515
533
  })
516
534
  })
517
535
 
@@ -577,7 +595,7 @@ describe('Error Handling', () => {
577
595
 
578
596
  describe('Custom Error Handlers', () => {
579
597
  it('allows custom onHandlerError handler', async () => {
580
- const onHandlerError = vi.fn()
598
+ const onHandlerError = vi.fn<HandlerErrorHook>()
581
599
  const customRouter = new LexRouter({
582
600
  onHandlerError,
583
601
  })
@@ -599,6 +617,342 @@ describe('Error Handling', () => {
599
617
  })
600
618
  })
601
619
 
620
// ============================================================================
// Routing Tests
// ============================================================================

/**
 * Routing behavior of `LexRouter.fetch`: non-/xrpc/ paths and the fallback
 * handler, the built-in `/xrpc/_health` endpoint, NSID validation and
 * normalization, `atproto-proxy` header validation, and handler error / abort
 * reporting.
 */
describe('Routing', () => {
  describe('non-/xrpc/ paths', () => {
    it('returns 404 for non-xrpc paths without fallback', async () => {
      const router = new LexRouter()
      const request = new Request('https://example.com/health')
      const response = await router.fetch(request)

      expect(response.status).toBe(404)
      expect(await response.text()).toBe('Not Found')
    })

    it('delegates to fallback handler for non-xrpc paths', async () => {
      const fallback = vi.fn(async () => new Response('OK from fallback'))
      const router = new LexRouter({ fallback })

      const request = new Request('https://example.com/health')
      const connection: ConnectionInfo = {
        completed: Promise.resolve(),
        remoteAddr: { hostname: '127.0.0.1', port: 3000, transport: 'tcp' },
      }
      const response = await router.fetch(request, connection)

      expect(fallback).toHaveBeenCalledWith(request, connection)
      expect(response.status).toBe(200)
      expect(await response.text()).toBe('OK from fallback')
    })
  })

  describe('/xrpc/_health endpoint', () => {
    it('returns default health check response', async () => {
      const router = new LexRouter()
      const request = new Request('https://example.com/xrpc/_health')
      const response = await router.fetch(request)

      expect(response.status).toBe(200)
      expect(await response.json()).toEqual({ status: 'ok' })
    })

    it('calls custom healthCheck handler', async () => {
      const healthCheck = vi.fn<HealthCheckHandler>(async () => ({
        status: 'ok',
        version: '1.0.0',
      }))
      const router = new LexRouter({ healthCheck })

      const request = new Request('https://example.com/xrpc/_health')
      const response = await router.fetch(request)

      expect(healthCheck).toHaveBeenCalledWith(request)
      expect(response.status).toBe(200)
      expect(await response.json()).toEqual({ status: 'ok', version: '1.0.0' })
    })

    it('returns 405 for non-GET requests', async () => {
      const router = new LexRouter()
      const request = new Request('https://example.com/xrpc/_health', {
        method: 'POST',
      })
      const response = await router.fetch(request)

      expect(response.status).toBe(405)
      const data = await response.json()
      expect(data.error).toBe('InvalidRequest')
      expect(data.message).toBe('Method not allowed')
    })

    it('returns 400 when atproto-proxy header is set', async () => {
      const router = new LexRouter()
      const request = new Request('https://example.com/xrpc/_health', {
        headers: { 'atproto-proxy': 'did:plc:example#atproto_labeler' },
      })
      const response = await router.fetch(request)

      expect(response.status).toBe(400)
      const data = await response.json()
      expect(data.error).toBe('InvalidRequest')
      expect(data.message).toContain('atproto-proxy')
    })

    it('does not call healthCheck when atproto-proxy is set', async () => {
      const healthCheck = vi.fn<HealthCheckHandler>(async () => ({
        status: 'ok',
      }))
      const router = new LexRouter({ healthCheck })
      const request = new Request('https://example.com/xrpc/_health', {
        headers: { 'atproto-proxy': 'did:plc:example#atproto_labeler' },
      })
      const response = await router.fetch(request)

      expect(healthCheck).not.toHaveBeenCalled()
      expect(response.status).toBe(400)
    })
  })

  describe('invalid NSID', () => {
    it('returns 400 for invalid NSID format', async () => {
      const router = new LexRouter()
      const request = new Request('https://example.com/xrpc/not-an-nsid!!')
      const response = await router.fetch(request)

      expect(response.status).toBe(400)
      const data = await response.json()
      expect(data.error).toBe('InvalidRequest')
      expect(data.message).toContain('Invalid NSID')
    })

    it('returns 400 for empty NSID', async () => {
      const router = new LexRouter()
      const request = new Request('https://example.com/xrpc/')
      const response = await router.fetch(request)

      expect(response.status).toBe(400)
      const data = await response.json()
      expect(data.error).toBe('InvalidRequest')
    })
  })

  describe('atproto-proxy header', () => {
    it('bypasses local handler when atproto-proxy header is set', async () => {
      const router = new LexRouter().add(io.example.status, handlers.status)

      const request = new Request(
        'https://example.com/xrpc/io.example.status',
        { headers: { 'atproto-proxy': 'did:plc:example#atproto_labeler' } },
      )
      const response = await router.fetch(request)

      // The local handler must NOT run. Proxying is not implemented yet, so
      // the router currently answers 501 (MethodNotImplemented).
      expect(response.status).toBe(501)
    })

    it('returns 400 for invalid atproto-proxy header format', async () => {
      const router = new LexRouter()

      const request = new Request(
        'https://example.com/xrpc/io.example.status',
        { headers: { 'atproto-proxy': 'not-a-valid-proxy' } },
      )
      const response = await router.fetch(request)

      expect(response.status).toBe(400)
      const data = await response.json()
      expect(data.error).toBe('InvalidRequest')
      expect(data.message).toContain('atproto-proxy')
    })

    it('returns 400 for atproto-proxy without fragment', async () => {
      const router = new LexRouter()

      const request = new Request(
        'https://example.com/xrpc/io.example.status',
        { headers: { 'atproto-proxy': 'did:plc:example' } },
      )
      const response = await router.fetch(request)

      expect(response.status).toBe(400)
      const data = await response.json()
      expect(data.error).toBe('InvalidRequest')
    })

    it('returns 400 for atproto-proxy with empty fragment', async () => {
      const router = new LexRouter()

      const request = new Request(
        'https://example.com/xrpc/io.example.status',
        { headers: { 'atproto-proxy': 'did:plc:example#' } },
      )
      const response = await router.fetch(request)

      expect(response.status).toBe(400)
    })

    it('returns 400 for atproto-proxy with spaces', async () => {
      const router = new LexRouter()

      const request = new Request(
        'https://example.com/xrpc/io.example.status',
        { headers: { 'atproto-proxy': 'did:plc:example #service' } },
      )
      const response = await router.fetch(request)

      expect(response.status).toBe(400)
    })

    it('returns 400 for atproto-proxy with multiple fragments', async () => {
      const router = new LexRouter()

      const request = new Request(
        'https://example.com/xrpc/io.example.status',
        { headers: { 'atproto-proxy': 'did:plc:example#svc#extra' } },
      )
      const response = await router.fetch(request)

      expect(response.status).toBe(400)
    })

    it('returns 400 for atproto-proxy with space in fragment', async () => {
      const router = new LexRouter()

      const request = new Request(
        'https://example.com/xrpc/io.example.status',
        { headers: { 'atproto-proxy': 'did:plc:example#service id' } },
      )
      const response = await router.fetch(request)

      expect(response.status).toBe(400)
    })
  })

  describe('NSID normalization', () => {
    it('matches handler when URL has uppercase domain segments', async () => {
      const router = new LexRouter().add(io.example.status, handlers.status)

      const request = new Request('https://example.com/xrpc/IO.Example.status')
      const response = await router.fetch(request)

      expect(response.status).toBe(200)
      expect(await response.json()).toEqual({ status: 'ok' })
    })

    it('matches handler when URL has mixed-case domain segments', async () => {
      const router = new LexRouter().add(io.example.status, handlers.status)

      const request = new Request('https://example.com/xrpc/IO.EXAMPLE.status')
      const response = await router.fetch(request)

      expect(response.status).toBe(200)
      expect(await response.json()).toEqual({ status: 'ok' })
    })

    it('preserves case sensitivity of method name (last segment)', async () => {
      const router = new LexRouter().add(io.example.status, handlers.status)

      // "Status" (uppercase S) should not match "status"
      const request = new Request('https://example.com/xrpc/io.example.Status')
      const response = await router.fetch(request)

      expect(response.status).toBe(501)
      expect(await response.json()).toMatchObject({
        error: 'MethodNotImplemented',
      })
    })

    it('prevents duplicate registration with different domain casing', async () => {
      const router = new LexRouter().add(io.example.status, handlers.status)

      expect(() => {
        // Same NSID with different domain casing should be detected as duplicate
        const statusUpperCase = l.query(
          'IO.Example.status' as 'io.example.status',
          l.params(),
          l.payload('application/json', l.object({ status: l.string() })),
        )
        router.add(statusUpperCase, handlers.status)
      }).toThrow(/already registered/)
    })
  })

  describe('error handling', () => {
    it('onHandlerError receives LexServerError', async () => {
      const onHandlerError = vi.fn<HandlerErrorHook>()
      const router = new LexRouter({ onHandlerError })

      router.add(io.example.status, async () => {
        throw new Error('Unexpected error')
      })

      const request = new Request('https://example.com/xrpc/io.example.status')
      await router.fetch(request)

      expect(onHandlerError).toHaveBeenCalledTimes(1)
      const ctx = onHandlerError.mock.calls[0][0]
      expect(ctx.error).toBeInstanceOf(LexServerError)
      expect(ctx.error.status).toBe(500)
      expect(ctx.method).toBeDefined()
      expect(ctx.request).toBe(request)
    })

    it('does not call onHandlerError for aborted requests', async () => {
      const onHandlerError = vi.fn<HandlerErrorHook>()
      const router = new LexRouter({ onHandlerError })

      // The handler throws an error whose `cause` is the request's abort
      // reason; this test expects the router to answer 499 and not report it
      // through onHandlerError.
      router.add(io.example.status, async ({ signal }) => {
        throw new Error('handler error', { cause: signal.reason })
      })

      const controller = new AbortController()
      controller.abort(new Error('aborted'))

      const request = new Request(
        'https://example.com/xrpc/io.example.status',
        { signal: controller.signal },
      )
      const response = await router.fetch(request)

      expect(response.status).toBe(499)
      expect(onHandlerError).not.toHaveBeenCalled()
    })

    it('returns 499 for aborted requests', async () => {
      const controller = new AbortController()
      const reason = new Error('Client disconnected')
      controller.abort(reason)

      const router = new LexRouter()
      router.add(io.example.status, async () => {
        throw new Error('after abort', { cause: reason })
      })

      const request = new Request(
        'https://example.com/xrpc/io.example.status',
        { signal: controller.signal },
      )
      const response = await router.fetch(request)

      expect(response.status).toBe(499)
      const data = await response.json()
      expect(data.error).toBe('RequestAborted')
    })
  })
})
955
+
602
956
  // ============================================================================
603
957
  // Parameter Tests (ported from xrpc-server/tests/parameters.test.ts)
604
958
  // ============================================================================
@@ -1605,18 +1959,231 @@ describe('Subscription', () => {
1605
1959
  'XRPC subscriptions are only available over WebSocket',
1606
1960
  )
1607
1961
  })
1962
+
1963
it('closes with 1003 when client sends a message to the subscription', async () => {
  // Endless subscription that pushes a frame every 50ms until aborted. The
  // client then sends a message of its own, which the server is expected to
  // reject by closing the socket with code 1003.
  const router = new LexRouter({ upgradeWebSocket }).add(
    io.example.subscribe,
    async function* ({ signal }) {
      for (;;) {
        await scheduler.wait(50, { signal })
        yield { message: 'ping', count: 1 }
      }
    },
  )

  await using server = await serve(router)
  const address = server.address() as AddressInfo
  const url = `ws://localhost:${address.port}/xrpc/io.example.subscribe?message=ping`

  const closed = timeoutDeferred<{ code: number }>(5000)

  const ws = new WebSocket(url)
  ws.addEventListener('close', closed.resolve)
  ws.addEventListener('error', closed.reject)
  ws.addEventListener('open', () => {
    ws.send('unexpected message from client')
  })

  const closeEvent = await closed.promise

  expect(closeEvent.code).toBe(1003)
})
1992
+
1993
describe('error close codes', () => {
  const subscribeWithErrors = l.subscription(
    'io.example.subscribeWithErrors',
    l.params(),
    l.object({ message: l.string() }),
    ['FutureCursor', 'ConsumerTooSlow'],
  )

  /**
   * Builds a router whose `subscribeWithErrors` handler fails with `error`
   * before yielding any message.
   */
  const failingRouter = (error: unknown) =>
    new LexRouter({ upgradeWebSocket }).add(
      subscribeWithErrors,
      async function* () {
        yield await Promise.reject(error)
      },
    )

  /**
   * Serves `router`, connects a WebSocket client to the
   * `io.example.subscribeWithErrors` endpoint and waits (up to 5s) for the
   * connection to close. The server is disposed before returning.
   *
   * @returns the close code, plus every received binary message decoded with
   *   `decodeAll` into its sequence of CBOR values.
   */
  async function runUntilClose(router: LexRouter) {
    await using server = await serve(router)
    const { port } = server.address() as AddressInfo

    const { resolve, reject, promise } = timeoutDeferred<{ code: number }>(
      5000,
    )
    const frames: unknown[][] = []

    const ws = new WebSocket(
      `ws://localhost:${port}/xrpc/io.example.subscribeWithErrors`,
    )
    ws.binaryType = 'arraybuffer'
    ws.addEventListener('message', (event) => {
      const bytes = new Uint8Array(event.data as ArrayBuffer)
      frames.push([...decodeAll(bytes)])
    })
    ws.addEventListener('close', resolve)
    ws.addEventListener('error', reject)

    const { code } = await promise
    return { code, frames }
  }

  it('closes with 1008 and sends error frame for known LexError', async () => {
    const { code, frames } = await runUntilClose(
      failingRouter(new LexError('FutureCursor', 'Too far in the future')),
    )

    expect(code).toBe(1008)
    expect(frames).toHaveLength(1)
    const [header, body] = frames[0]
    expect(header).toEqual({ op: -1 })
    expect(body).toMatchObject({ error: 'FutureCursor' })
  })

  it('closes with 1011 and sends InternalServerError frame for unknown error', async () => {
    const { code, frames } = await runUntilClose(
      failingRouter(new Error('unexpected failure')),
    )

    expect(code).toBe(1011)
    expect(frames).toHaveLength(1)
    const [header, body] = frames[0]
    expect(header).toEqual({ op: -1 })
    expect(body).toMatchObject({ error: 'InternalServerError' })
  })

  it('closes with 1011 for a LexError not listed in method.errors', async () => {
    const { code } = await runUntilClose(
      failingRouter(new LexError('SomeOtherError', 'Not a declared error')),
    )

    expect(code).toBe(1011)
  })
})
2103
+
2104
describe('onSocketError hook', () => {
  it('calls onSocketError when the generator throws a non-abort error', async () => {
    const onSocketError = vi.fn<SocketErrorHook>()
    const router = new LexRouter({ upgradeWebSocket, onSocketError }).add(
      io.example.subscribe,
      async function* () {
        yield await Promise.reject(new Error('generator failure'))
      },
    )

    await using server = await serve(router)
    const { port } = server.address() as AddressInfo

    const { resolve, reject, promise } = timeoutDeferred<{ code: number }>(
      5000,
    )
    const ws = new WebSocket(
      `ws://localhost:${port}/xrpc/io.example.subscribe?message=ping`,
    )
    ws.addEventListener('close', resolve)
    ws.addEventListener('error', reject)

    await promise

    expect(onSocketError).toHaveBeenCalledTimes(1)
    const ctx = onSocketError.mock.calls[0][0]
    expect(ctx.error).toBeInstanceOf(Error)
    expect(ctx.method).toBeDefined()
    expect(ctx.request).toBeDefined()
  })

  it('does not call onSocketError when the error matches the abort reason', async () => {
    const onSocketError = vi.fn<SocketErrorHook>()
    const router = new LexRouter({ upgradeWebSocket, onSocketError }).add(
      io.example.subscribe,
      async function* ({ signal }) {
        // Wait for abort, then throw with the abort reason as cause. Check
        // `signal.aborted` first: if the abort already happened, an 'abort'
        // listener would never fire and the generator would hang forever.
        await new Promise<void>((_, reject) => {
          const onAbort = () => {
            reject(new Error('aborted', { cause: signal.reason }))
          }
          if (signal.aborted) onAbort()
          else signal.addEventListener('abort', onAbort, { once: true })
        })
        yield { message: 'never', count: 0 }
      },
    )

    await using server = await serve(router)
    const { port } = server.address() as AddressInfo

    const { resolve, reject, promise } = timeoutDeferred<{ code: number }>(
      5000,
    )
    const ws = new WebSocket(
      `ws://localhost:${port}/xrpc/io.example.subscribe?message=ping`,
    )
    // Close from the client side to trigger the abort
    ws.addEventListener('open', () => ws.close())
    ws.addEventListener('close', resolve)
    ws.addEventListener('error', reject)

    await promise

    expect(onSocketError).not.toHaveBeenCalled()
  })
})
1608
2169
  })
1609
2170
 
1610
/**
 * Creates a promise and exposes its `resolve` / `reject` functions, so it can
 * be settled from outside the executor (e.g. from event listeners).
 */
function defer<T = void>() {
  let settle!: {
    resolve: (value: T | PromiseLike<T>) => void
    reject: (err: unknown) => void
  }
  const promise = new Promise<T>((resolve, reject) => {
    settle = { resolve, reject }
  })
  return { ...settle, promise }
}

/**
 * Like `defer`, but the promise rejects with `Error('Timed out')` if it is not
 * settled within `ms` milliseconds. The timer is unref'd so it does not keep
 * the process alive, and it is cleared as soon as the promise settles.
 */
function timeoutDeferred<T = void>(ms: number) {
  const deferred = defer<T>()
  const timer = setTimeout(() => deferred.reject(new Error('Timed out')), ms)
  timer.unref()
  const guarded = deferred.promise.finally(() => clearTimeout(timer))
  return {
    resolve: deferred.resolve,
    reject: deferred.reject,
    promise: guarded,
  }
}