dns2 2.1.0 → 2.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/packet.js CHANGED
@@ -1,13 +1,33 @@
1
- const { debuglog } = require('util');
1
+ const { debuglog } = require('node:util');
2
+ const { randomInt } = require('node:crypto');
2
3
  const BufferReader = require('./lib/reader');
3
4
  const BufferWriter = require('./lib/writer');
4
5
 
5
6
  const debug = debuglog('dns2');
6
7
 
7
- const toIPv6 = buffer => buffer
8
- .map(part => (part > 0 ? part.toString(16) : '0'))
9
- .join(':')
10
- .replace(/\b(?:0+:){1,}/, ':');
8
// Format eight 16-bit groups as canonical IPv6 text (RFC 5952):
//  - lower-case hex without leading zeros in each group
//  - the longest run of two or more zero groups collapses to "::"
//  - when two runs tie, the leftmost one is collapsed
//  - a lone zero group is left as "0"
const toIPv6 = buffer => {
  const groups = buffer.map(part => (part > 0 ? part.toString(16) : '0'));

  // Scan run by run; the strict '>' comparison keeps the first of equal runs.
  let runStart = -1;
  let runLength = 0;
  for (let start = 0; start < groups.length;) {
    if (groups[start] !== '0') {
      start += 1;
      continue;
    }
    let end = start;
    while (end < groups.length && groups[end] === '0') {
      end += 1;
    }
    if (end - start > runLength) {
      runStart = start;
      runLength = end - start;
    }
    start = end;
  }

  if (runLength < 2) {
    return groups.join(':');
  }
  const head = groups.slice(0, runStart).join(':');
  const tail = groups.slice(runStart + runLength).join(':');
  return `${head}::${tail}`;
};
11
31
 
12
32
  const fromIPv6 = (address) => {
13
33
  const digits = address.split(':');
@@ -19,6 +39,13 @@ const fromIPv6 = (address) => {
19
39
  if (digits[digits.length - 1] === '') {
20
40
  digits.pop();
21
41
  }
42
+ // node js 10 does not support Array.prototype.flatMap
43
+ if (!Array.prototype.flatMap) {
44
+ Array.prototype.flatMap = function(f, ctx) {
45
+ return this.reduce((r, x, i, a) => r.concat(f.call(ctx, x, i, a)), []);
46
+ };
47
+ }
48
+
22
49
  // CAVEAT we have to take into account
23
50
  // the extra space used by the empty string
24
51
  const missingFields = 8 - digits.length + 1;
@@ -76,31 +103,32 @@ function Packet(data) {
76
103
  * @docs https://tools.ietf.org/html/rfc1035#section-3.2.2
77
104
  */
78
105
  Packet.TYPE = {
79
- A : 0x01,
80
- NS : 0x02,
81
- MD : 0x03,
82
- MF : 0x04,
83
- CNAME : 0x05,
84
- SOA : 0x06,
85
- MB : 0x07,
86
- MG : 0x08,
87
- MR : 0x09,
88
- NULL : 0x0A,
89
- WKS : 0x0B,
90
- PTR : 0x0C,
91
- HINFO : 0x0D,
92
- MINFO : 0x0E,
93
- MX : 0x0F,
94
- TXT : 0x10,
95
- AAAA : 0x1C,
96
- SRV : 0x21,
97
- EDNS : 0x29,
98
- SPF : 0x63,
99
- AXFR : 0xFC,
100
- MAILB : 0xFD,
101
- MAILA : 0xFE,
102
- ANY : 0xFF,
103
- CAA : 0x101,
106
+ A : 0x01,
107
+ NS : 0x02,
108
+ MD : 0x03,
109
+ MF : 0x04,
110
+ CNAME : 0x05,
111
+ SOA : 0x06,
112
+ MB : 0x07,
113
+ MG : 0x08,
114
+ MR : 0x09,
115
+ NULL : 0x0A,
116
+ WKS : 0x0B,
117
+ PTR : 0x0C,
118
+ HINFO : 0x0D,
119
+ MINFO : 0x0E,
120
+ MX : 0x0F,
121
+ TXT : 0x10,
122
+ AAAA : 0x1C,
123
+ SRV : 0x21,
124
+ EDNS : 0x29,
125
+ SPF : 0x63,
126
+ AXFR : 0xFC,
127
+ MAILB : 0xFD,
128
+ MAILA : 0xFE,
129
+ ANY : 0xFF,
130
+ CAA : 0x101,
131
+ DNSKEY : 0x30,
104
132
  };
105
133
  /**
106
134
  * [QUERY_CLASS description]
@@ -114,6 +142,19 @@ Packet.CLASS = {
114
142
  HS : 0x04,
115
143
  ANY : 0xFF,
116
144
  };
145
/**
 * DNS response codes carried in the 4-bit RCODE header field.
 * 0-5 are defined by RFC 1035 §4.1.1; 6-10 were added by RFC 2136 §2.2
 * (DNS UPDATE). Larger values need the EDNS extended RCODE.
 * @type {Object}
 * @docs https://tools.ietf.org/html/rfc1035#section-4.1.1
 * @docs https://tools.ietf.org/html/rfc2136#section-2.2
 */
Packet.RCODE = {
  NOERROR  : 0,
  FORMERR  : 1,
  SERVFAIL : 2,
  NXDOMAIN : 3,
  NOTIMP   : 4,
  REFUSED  : 5,
  YXDOMAIN : 6,
  YXRRSET  : 7,
  NXRRSET  : 8,
  NOTAUTH  : 9,
  NOTZONE  : 10,
};
117
158
  /**
118
159
  * [EDNS_OPTION_CODE description]
119
160
  * @type {Object}
@@ -124,11 +165,13 @@ Packet.EDNS_OPTION_CODE = {
124
165
  };
125
166
 
126
167
  /**
127
- * [uuid description]
128
- * @return {[type]} [description]
168
+ * Generate a cryptographically random 16-bit DNS transaction ID.
169
+ * RFC 5452 §3 — the full 16-bit space must be used from a CSPRNG to make
170
+ * response forgery / cache poisoning impractical.
171
+ * @return {number} integer in [0, 0xFFFF]
129
172
  */
130
173
  Packet.uuid = function() {
131
- return Math.floor(Math.random() * 1e5);
174
+ return randomInt(0x10000);
132
175
  };
133
176
 
134
177
  /**
@@ -172,7 +215,6 @@ Object.defineProperty(Packet.prototype, 'recursive', {
172
215
  },
173
216
  set(yn) {
174
217
  this.header.rd = +yn;
175
- return this.header.rd;
176
218
  },
177
219
  });
178
220
 
@@ -217,8 +259,11 @@ Packet.Header = function(header) {
217
259
  this.rd = 0;
218
260
  this.ra = 0;
219
261
  this.z = 0;
262
+ this.ad = 0;
263
+ this.cd = 0;
220
264
  this.rcode = 0;
221
265
  this.qdcount = 0;
266
+ this.ancount = 0;
222
267
  this.nscount = 0;
223
268
  this.arcount = 0;
224
269
  for (const k in header) {
@@ -244,7 +289,10 @@ Packet.Header.parse = function(reader) {
244
289
  header.tc = reader.read(1);
245
290
  header.rd = reader.read(1);
246
291
  header.ra = reader.read(1);
247
- header.z = reader.read(3);
292
+ // RFC 4035 §3.2.3 repurposed the second and third Z bits as AD and CD.
293
+ header.z = reader.read(1);
294
+ header.ad = reader.read(1);
295
+ header.cd = reader.read(1);
248
296
  header.rcode = reader.read(4);
249
297
  header.qdcount = reader.read(16);
250
298
  header.ancount = reader.read(16);
@@ -266,7 +314,9 @@ Packet.Header.prototype.toBuffer = function(writer) {
266
314
  writer.write(this.tc, 1);
267
315
  writer.write(this.rd, 1);
268
316
  writer.write(this.ra, 1);
269
- writer.write(this.z, 3);
317
+ writer.write(this.z, 1);
318
+ writer.write(this.ad, 1);
319
+ writer.write(this.cd, 1);
270
320
  writer.write(this.rcode, 4);
271
321
  writer.write(this.qdcount, 16);
272
322
  writer.write(this.ancount, 16);
@@ -379,9 +429,18 @@ Packet.Resource.encode = function(resource, writer) {
379
429
  })[0];
380
430
  if (encoder in Packet.Resource && Packet.Resource[encoder].encode) {
381
431
  return Packet.Resource[encoder].encode(resource, writer);
382
- } else {
383
- debug('node-dns > unknown encoder %s(%j)', encoder, resource.type);
384
432
  }
433
+ debug('node-dns > unknown encoder %s(%j)', encoder, resource.type);
434
+ // Fallback for unknown / decoder-only types: round-trip the raw RDATA the
435
+ // decoder preserved as `resource.data`. Without this, RDLENGTH and RDATA
436
+ // would be omitted entirely, truncating the wire format and corrupting any
437
+ // records that follow.
438
+ const data = Buffer.isBuffer(resource.data) ? resource.data : Buffer.alloc(0);
439
+ writer.write(data.length, 16);
440
+ for (const byte of data) {
441
+ writer.write(byte, 8);
442
+ }
443
+ return writer.toBuffer();
385
444
  };
386
445
  /**
387
446
  * [parse description]
@@ -425,11 +484,18 @@ Packet.Name = {
425
484
  reader = new Packet.Reader(reader);
426
485
  }
427
486
  const name = []; let o; let len = reader.read(8);
487
+ // Track each pointer target we follow. A crafted packet can chain
488
+ // pointers in a cycle; without this guard, decode would loop forever.
489
+ const visited = new Set();
428
490
  while (len) {
429
491
  if ((len & Packet.Name.COPY) === Packet.Name.COPY) {
430
492
  len -= Packet.Name.COPY;
431
493
  len = len << 8;
432
494
  const pos = len + reader.read(8);
495
+ if (visited.has(pos)) {
496
+ throw new Error('Name decode: pointer cycle detected');
497
+ }
498
+ visited.add(pos);
433
499
  if (!o) o = reader.offset;
434
500
  reader.offset = pos * 8;
435
501
  len = reader.read(8);
@@ -597,7 +663,7 @@ Packet.Resource.SPF =
597
663
  Packet.Resource.TXT = {
598
664
  decode: function(reader, length) {
599
665
  const parts = [];
600
- let bytesRead = 0; let chunkLength = 0;
666
+ let bytesRead = 0; let chunkLength;
601
667
 
602
668
  while (bytesRead < length) {
603
669
  chunkLength = reader.read(8); // text length
@@ -708,19 +774,42 @@ Packet.Resource.SRV = {
708
774
  },
709
775
  };
710
776
 
711
- Packet.Resource.EDNS = function(rdata) {
777
// Pack the OPT record's TTL field (RFC 6891 §6.1.3), numbering bits from the
// most significant end:
//   bits  0- 7: extended RCODE (upper 8 bits of a 12-bit RCODE)
//   bits  8-15: EDNS version
//   bit     16: DO (DNSSEC OK)
//   bits 17-31: reserved Z, sent as zero
// The final `>>> 0` applies to the whole expression. Bitwise `|` yields a
// signed int32, so without it the result would be negative whenever the
// extended RCODE is >= 0x80.
// @param {number} extendedRcode
// @param {number} version
// @param {boolean} doFlag
// @returns {number} unsigned 32-bit TTL value
const ednsTtl = (extendedRcode, version, doFlag) => (
  ((extendedRcode & 0xff) << 24)
  | ((version & 0xff) << 16)
  | (doFlag ? 0x8000 : 0)
) >>> 0;
786
+
787
/**
 * Build an EDNS(0) OPT pseudo-record.
 *
 * CLASS carries the advertised UDP payload size. TTL packs the extended
 * RCODE, EDNS version and DO flag via ednsTtl(). The unpacked values are also
 * exposed as separate properties.
 *
 * @param {Array} rdata - option objects (Packet.Resource.EDNS.*)
 * @param {Object} [opts]
 * @param {number} [opts.extendedRcode=0] - falsy values become 0
 * @param {number} [opts.version=0] - falsy values become 0
 * @param {boolean} [opts.doFlag=false] - coerced to a boolean
 * @param {number} [opts.udpPayloadSize=512] - falsy values become 512
 * @returns {Object} resource object for the additional section
 */
Packet.Resource.EDNS = function(rdata, opts = {}) {
  const settings = {
    extendedRcode  : opts.extendedRcode || 0,
    version        : opts.version || 0,
    doFlag         : Boolean(opts.doFlag),
    udpPayloadSize : opts.udpPayloadSize || 512,
  };
  const packedTtl = ednsTtl(settings.extendedRcode, settings.version, settings.doFlag);
  return {
    type          : Packet.TYPE.EDNS,
    class         : settings.udpPayloadSize,
    ttl           : packedTtl,
    extendedRcode : settings.extendedRcode,
    version       : settings.version,
    doFlag        : settings.doFlag,
    rdata,
  };
};
719
802
 
720
803
  Packet.Resource.EDNS.decode = function(reader, length) {
721
- this.type = Packet.TYPE.EDNS;
722
- this.class = 512;
723
- this.ttl = 0;
804
+ // When invoked through Resource.parse, this.type/class/ttl are already set
805
+ // from the wire. Direct callers (e.g. unit tests) hit defaults instead.
806
+ this.type = this.type ?? Packet.TYPE.EDNS;
807
+ this.class = this.class ?? 512;
808
+ const ttl = this.ttl ?? 0;
809
+ this.ttl = ttl;
810
+ this.extendedRcode = (ttl >>> 24) & 0xff;
811
+ this.version = (ttl >>> 16) & 0xff;
812
+ this.doFlag = !!(ttl & 0x8000);
724
813
  this.rdata = [];
725
814
 
726
815
  while (length) {
@@ -785,35 +874,75 @@ Packet.Resource.EDNS.ECS.decode = function(reader, length) {
785
874
  rdata.scopePrefixLength = reader.read(8);
786
875
  length -= 4;
787
876
 
788
- if (rdata.family !== 1) {
789
- debug('node-dns > unimplemented address family');
790
- reader.read(length * 8); // Ignore data that doesn't understand
791
- return rdata;
877
+ if (rdata.family === 1) {
878
+ const ipv4Octets = [];
879
+ while (length--) {
880
+ const octet = reader.read(8);
881
+ ipv4Octets.push(octet);
882
+ }
883
+ while (ipv4Octets.length < 4) {
884
+ ipv4Octets.push(0);
885
+ }
886
+ rdata.ip = ipv4Octets.join('.');
792
887
  }
793
888
 
794
- const ipv4Octets = [];
795
- while (length--) {
796
- const octet = reader.read(8);
797
- ipv4Octets.push(octet);
798
- }
799
- while (ipv4Octets.length < 4) {
800
- ipv4Octets.push(0);
889
+ if (rdata.family === 2) {
890
+ const ipv6Segments = [];
891
+ for (; length; length -= 2) {
892
+ const segment = reader.read(16).toString(16);
893
+ ipv6Segments.push(segment);
894
+ }
895
+ while (ipv6Segments.length < 8) {
896
+ ipv6Segments.push('0');
897
+ }
898
+ rdata.ip = ipv6Segments.join(':');
801
899
  }
802
- rdata.ip = ipv4Octets.join('.');
900
+
803
901
  return rdata;
804
902
  };
805
903
 
806
904
  Packet.Resource.EDNS.ECS.encode = function(record, writer) {
807
- const ip = record.ip.split('.').map(s => parseInt(s));
905
+ // RFC 7871 §6: the ADDRESS field carries only the leftmost
906
+ // ceil(sourcePrefixLength / 8) octets.
907
+ const octets = Math.ceil(record.sourcePrefixLength / 8);
808
908
  writer.write(record.family, 16);
809
909
  writer.write(record.sourcePrefixLength, 8);
810
910
  writer.write(record.scopePrefixLength, 8);
811
- writer.write(ip[0], 8);
812
- writer.write(ip[1], 8);
813
- writer.write(ip[2], 8);
814
- writer.write(ip[3], 8);
911
+ let bytes;
912
+ if (record.family === 1) {
913
+ bytes = record.ip.split('.').map(s => parseInt(s, 10) || 0);
914
+ } else if (record.family === 2) {
915
+ bytes = expandIPv6ToBytes(record.ip);
916
+ } else {
917
+ throw new Error(`EDNS.ECS encode: unsupported family ${record.family}`);
918
+ }
919
+ for (let i = 0; i < octets; i++) {
920
+ writer.write(bytes[i] || 0, 8);
921
+ }
815
922
  };
816
923
 
924
// Expand an IPv6 text address into a 16-element array of byte values in
// network order.
//
// Accepts the RFC 4291 §2.2 text forms:
//  - eight colon-separated hex groups
//  - "::"-compressed addresses ("2001:db8::1", "::")
//  - a trailing embedded IPv4 dotted quad ("::ffff:192.0.2.1")
// Groups that do not parse as hex become 0. An address without "::" that has
// fewer than eight groups is zero-padded on the right.
//
// @param {string} address
// @returns {number[]} 16 values in [0, 255]
// @throws {Error} if the address has more than eight groups
function expandIPv6ToBytes(address) {
  let head;
  let tail;
  const idx = address.indexOf('::');
  if (idx === -1) {
    head = address.split(':');
    tail = [];
  } else {
    head = address.slice(0, idx).split(':').filter(Boolean);
    tail = address.slice(idx + 2).split(':').filter(Boolean);
  }

  // A dotted-quad suffix stands for the last two 16-bit groups. Without this
  // conversion "192.0.2.1" would be read as the hex group 0x192.
  const lastPart = tail.length ? tail : head;
  const lastGroup = lastPart[lastPart.length - 1];
  if (lastGroup !== undefined && lastGroup.includes('.')) {
    const [a, b, c, d] = lastGroup.split('.').map(s => Number.parseInt(s, 10) & 0xff);
    lastPart.splice(-1, 1, ((a << 8) | b).toString(16), ((c << 8) | d).toString(16));
  }

  const missing = 8 - head.length - tail.length;
  if (missing < 0) {
    throw new Error(`Invalid IPv6 address: ${address}`);
  }
  const groups = [ ...head, ...new Array(missing).fill('0'), ...tail ];
  const out = new Array(16).fill(0);
  for (let g = 0; g < 8; g++) {
    const n = parseInt(groups[g], 16) || 0;
    out[g * 2] = (n >> 8) & 0xff;
    out[g * 2 + 1] = n & 0xff;
  }
  return out;
}
945
+
817
946
  Packet.Resource.CAA = {
818
947
  encode: function(record, writer) {
819
948
  writer = writer || new Packet.Writer();
@@ -828,15 +957,135 @@ Packet.Resource.CAA = {
828
957
  });
829
958
  return writer.toBuffer();
830
959
  },
960
+ decode: function(reader, length) {
961
+ this.flags = reader.read(8);
962
+ const tagLength = reader.read(8);
963
+ const bytes = [];
964
+ let remaining = length - 2;
965
+ while (remaining--) bytes.push(reader.read(8));
966
+ const buffer = Buffer.from(bytes);
967
+ this.tag = buffer.slice(0, tagLength).toString('utf8');
968
+ this.value = buffer.slice(tagLength).toString('utf8');
969
+ return this;
970
+ },
971
+ };
972
+
973
/**
 * DNSKEY resource record (RFC 4034 §2).
 *
 * RDATA: Flags (16 bits) | Protocol (8) | Algorithm (8) | Public Key.
 *
 * decode() sets:
 *  - flags, protocol, algorithm
 *  - keyTag, computed over the whole RDATA (RFC 4034 Appendix B)
 *  - zoneKey (flag bit 7, mask 0x0100) and zoneSep (flag bit 15, mask 0x0001)
 *  - key, the public key in base64
 *
 * encode() writes RDLENGTH followed by the RDATA. If `flags` is absent it is
 * derived from `zoneKey`/`zoneSep`. If `protocol` is absent it defaults to 3,
 * the only value RFC 4034 §2.1.2 allows.
 *
 * NOTE(review): RFC 4034 Appendix B.1 defines a different key tag for
 * algorithm 1 (RSA/MD5); that case is not special-cased here.
 *
 * @type {{decode: Function, encode: Function}}
 * @link https://tools.ietf.org/html/rfc4034
 * @link https://www.iana.org/assignments/dns-sec-alg-numbers/dns-sec-alg-numbers.xhtml#table-dns-sec-alg-numbers-1
 */
Packet.Resource.DNSKEY = {
  decode: function(reader, length) {
    const rdata = [];
    while (rdata.length < length) {
      rdata.push(reader.read(8));
    }
    this.flags = (rdata[0] << 8) | rdata[1];
    this.protocol = rdata[2];
    this.algorithm = rdata[3];

    // Key tag: 16-bit sum over the RDATA. Even-indexed octets are the high
    // byte; the carry is then folded back in once.
    let ac = 0;
    for (let i = 0; i < length; ++i) {
      ac += (i & 1) ? rdata[i] : rdata[i] << 8;
    }
    ac += (ac >> 16) & 0xFFFF;
    this.keyTag = ac & 0xFFFF;

    // RFC 4034 numbers flag bits from the most significant end:
    // bit 7 = Zone Key (0x0100), bit 15 = Secure Entry Point (0x0001).
    this.zoneKey = (this.flags & 0x0100) !== 0;
    this.zoneSep = (this.flags & 0x0001) !== 0;
    this.key = Buffer.from(rdata.slice(4)).toString('base64');
    return this;
  },
  encode: function(record, writer) {
    writer = writer || new Packet.Writer();
    const key = Buffer.from(record.key, 'base64');
    const flags = record.flags ??
      ((record.zoneKey ? 0x0100 : 0) | (record.zoneSep ? 0x0001 : 0));
    const protocol = record.protocol ?? 3;
    writer.write(4 + key.length, 16); // RDLENGTH
    writer.write(flags, 16);
    writer.write(protocol, 8);
    writer.write(record.algorithm, 8);
    for (const octet of key) {
      writer.write(octet, 8);
    }
    return writer.toBuffer();
  },
};
1020
+
1021
/**
 * RRSIG resource record (RFC 4034 §3). Only decoding is supported.
 * Original note: test with dns.resolveRRSIG('example.com').
 *
 * decode() sets:
 *  - sigType (type covered), algorithm, labels, originalTtl, keyTag
 *  - expiration and inception as "YYYYMMDDHHmmSS" strings in UTC
 *  - signer (domain name)
 *  - signature (base64)
 *
 * @type {{decode: Function}}
 */
Packet.Resource.RRSIG = {
  decode: function(reader, length) {
    // Seconds since the Unix epoch -> "YYYYMMDDHHmmSS". Every field is taken
    // in UTC, including the year, so the result does not depend on the local
    // time zone.
    const dateForSig = seconds => {
      const date = new Date(seconds * 1000);
      const pad2 = n => String(n).padStart(2, '0');
      return String(date.getUTCFullYear()) +
        pad2(date.getUTCMonth() + 1) +
        pad2(date.getUTCDate()) +
        pad2(date.getUTCHours()) +
        pad2(date.getUTCMinutes()) +
        pad2(date.getUTCSeconds());
    };

    // reader.offset counts bits; remember where this RDATA ends.
    const maxOffset = reader.offset + (length * 8);
    // Fixed-size fields: 18 octets in total.
    this.sigType = reader.read(16); // type covered, 2 octets
    this.algorithm = reader.read(8); // 1
    this.labels = reader.read(8); // 1
    this.originalTtl = reader.read(32); // 4
    this.expiration = dateForSig(reader.read(32)); // 4
    this.inception = dateForSig(reader.read(32)); // 4
    this.keyTag = reader.read(16); // 2
    this.signer = Packet.Name.decode(reader);
    // Whatever remains of the RDATA is the signature.
    const signatureLength = (maxOffset - reader.offset) / 8;
    const signature = [];
    while (signature.length < signatureLength) {
      signature.push(reader.read(8));
    }
    this.signature = Buffer.from(signature).toString('base64');
    return this;
  },
};
832
1076
 
833
1077
  Packet.Reader = BufferReader;
834
1078
  Packet.Writer = BufferWriter;
835
1079
 
836
1080
  Packet.createResponseFromRequest = function(request) {
837
- const response = new Packet(request);
838
- response.header.qr = 1;
839
- response.additionals = [];
1081
+ const response = new Packet();
1082
+ response.header = new Packet.Header({
1083
+ id : request.header.id,
1084
+ opcode : request.header.opcode,
1085
+ rd : request.header.rd,
1086
+ qr : 1,
1087
+ });
1088
+ response.questions = request.questions.slice();
840
1089
  return response;
841
1090
  };
842
1091
 
package/server/dns.js CHANGED
@@ -1,4 +1,5 @@
1
- const EventEmitter = require('events');
1
+ const EventEmitter = require('node:events');
2
+ const Packet = require('../packet');
2
3
  const DOHServer = require('./doh');
3
4
  const TCPServer = require('./tcp');
4
5
  const UDPServer = require('./udp');
@@ -34,7 +35,23 @@ class DNSServer extends EventEmitter {
34
35
  return addresses;
35
36
  });
36
37
 
37
- const emitRequest = (request, send, client) => this.emit('request', request, send, client);
38
+ const maxConcurrent = options.maxConcurrent > 0 ? options.maxConcurrent : 0;
39
+ let active = 0;
40
+
41
/**
 * Forward a parsed request to 'request' listeners, enforcing
 * options.maxConcurrent when it is set.
 *
 * At the limit, the request is answered with SERVFAIL right away and is not
 * emitted. Otherwise it takes one slot. The slot is released the first time
 * the handler calls `send`; later `send` calls pass straight through, so a
 * handler that responds twice cannot push `active` below the real number of
 * outstanding requests. If a listener throws synchronously, the slot is
 * released before the error propagates.
 *
 * NOTE(review): a handler that never calls `send` still holds its slot
 * indefinitely; consider a timeout if handlers may drop requests.
 */
const emitRequest = (request, send, client) => {
  if (maxConcurrent && active >= maxConcurrent) {
    const response = Packet.createResponseFromRequest(request);
    response.header.rcode = Packet.RCODE.SERVFAIL;
    send(response);
    return;
  }
  active++;
  let released = false;
  const release = () => {
    if (released) return;
    released = true;
    active--;
  };
  const wrappedSend = (...args) => {
    release();
    return send(...args);
  };
  try {
    this.emit('request', request, wrappedSend, client);
  } catch (err) {
    release();
    throw err;
  }
};
38
55
  const emitRequestError = (error) => this.emit('requestError', error);
39
56
  for (const server of servers) {
40
57
  server.on('request', emitRequest);
package/server/doh.js CHANGED
@@ -1,9 +1,9 @@
1
- const http = require('http');
2
- const https = require('https');
3
- const { URL } = require('url');
1
+ const http = require('node:http');
2
+ const https = require('node:https');
3
+ const { URL } = require('node:url');
4
4
  const Packet = require('../packet');
5
- const EventEmitter = require('events');
6
- const { debuglog } = require('util');
5
+ const EventEmitter = require('node:events');
6
+ const { debuglog } = require('node:util');
7
7
 
8
8
  const debug = debuglog('dns2-server');
9
9
 
@@ -20,11 +20,11 @@ const decodeBase64URL = str => {
20
20
  };
21
21
 
22
22
  const readStream = stream => new Promise((resolve, reject) => {
23
- let buffer = '';
23
+ const chunks = [];
24
24
  stream
25
25
  .on('error', reject)
26
- .on('data', chunk => { buffer += chunk; })
27
- .on('end', () => resolve(buffer));
26
+ .on('data', chunk => chunks.push(chunk))
27
+ .on('end', () => resolve(Buffer.concat(chunks)));
28
28
  });
29
29
 
30
30
  class Server extends EventEmitter {
@@ -139,6 +139,7 @@ class Server extends EventEmitter {
139
139
  }
140
140
 
141
141
  close() {
142
+ this.server.closeIdleConnections();
142
143
  return this.server.close();
143
144
  }
144
145
  }
package/server/tcp.js CHANGED
@@ -1,17 +1,31 @@
1
- const tcp = require('net');
1
+ const tcp = require('node:net');
2
2
  const Packet = require('../packet');
3
+ const proxyProtocol = require('../lib/proxy-protocol');
3
4
 
4
5
  class Server extends tcp.Server {
5
6
  constructor(options) {
6
7
  super();
8
+ let proxyProtocolEnabled = false;
9
+ if (typeof options === 'object' && options !== null) {
10
+ proxyProtocolEnabled = options.proxyProtocol ?? false;
11
+ }
7
12
  if (typeof options === 'function') {
8
13
  this.on('request', options);
9
14
  }
15
+ this.proxyProtocol = proxyProtocolEnabled;
10
16
  this.on('connection', this.handle.bind(this));
11
17
  }
12
18
 
13
19
  async handle(client) {
14
20
  try {
21
+ if (this.proxyProtocol) {
22
+ const header = await consumeProxyHeader(client);
23
+ client.proxy = header;
24
+ if (header.command === 'PROXY') {
25
+ client.proxyAddress = header.sourceAddress;
26
+ client.proxyPort = header.sourcePort;
27
+ }
28
+ }
15
29
  const data = await Packet.readStream(client);
16
30
  const message = Packet.parse(data);
17
31
  this.emit('request', message, this.response.bind(this, client), client);
@@ -31,4 +45,61 @@ class Server extends tcp.Server {
31
45
  }
32
46
  }
33
47
 
48
/**
 * Consume a PROXY protocol header from the front of `socket`.
 *
 * On every 'readable' event, newly buffered data is added to what was
 * already received and the whole buffer is passed to proxyProtocol.parse(),
 * until that returns a result. Bytes past `parsed.headerLength` are pushed
 * back with socket.unshift(), so the next reader of the socket
 * (Packet.readStream in handle()) still sees them.
 *
 * @param {net.Socket} socket
 * @returns {Promise<Object>} resolves with `parsed.header`. Rejects if the
 *   socket errors, ends before a complete header arrives, or if
 *   proxyProtocol.parse() throws.
 */
function consumeProxyHeader(socket) {
  return new Promise((resolve, reject) => {
    const received = [];
    let receivedBytes = 0;
    let settled = false;

    function detach() {
      socket.removeListener('readable', handleReadable);
      socket.removeListener('end', handleEnd);
      socket.removeListener('error', handleError);
    }

    function fail(err) {
      if (settled) return;
      settled = true;
      detach();
      reject(err);
    }

    function handleError(err) {
      fail(err);
    }

    function handleEnd() {
      fail(new Error('PROXY protocol: stream ended before header complete'));
    }

    function handleReadable() {
      if (settled) return;
      for (let chunk = socket.read(); chunk !== null; chunk = socket.read()) {
        received.push(chunk);
        receivedBytes += chunk.length;
      }
      if (receivedBytes === 0) return;

      const buffer = Buffer.concat(received, receivedBytes);
      let parsed;
      try {
        parsed = proxyProtocol.parse(buffer);
      } catch (err) {
        fail(err);
        return;
      }
      // A falsy result means the header is not complete yet; wait for more data.
      if (!parsed) return;

      settled = true;
      detach();
      const rest = buffer.slice(parsed.headerLength);
      if (rest.length > 0) socket.unshift(rest);
      resolve(parsed.header);
    }

    socket.on('readable', handleReadable);
    socket.on('end', handleEnd);
    socket.on('error', handleError);
    // Data may already be buffered from before these listeners were attached.
    handleReadable();
  });
}
104
+
34
105
  module.exports = Server;