opal-up 0.0.4 → 0.0.5

Sign up to get free protection for your applications and to get access to all the features.
@@ -18,489 +18,532 @@
18
18
  #ifndef UWS_WEBSOCKETPROTOCOL_H
19
19
  #define UWS_WEBSOCKETPROTOCOL_H
20
20
 
21
- #include "libusockets.h"
21
+ #include <libusockets.h>
22
22
 
23
23
  #include <cstdint>
24
- #include <cstring>
25
24
  #include <cstdlib>
25
+ #include <cstring>
26
26
  #include <string_view>
27
27
 
28
28
  namespace uWS {
29
29
 
30
30
  /* We should not overcomplicate these */
31
31
  const std::string_view ERR_TOO_BIG_MESSAGE("Received too big message");
32
- const std::string_view ERR_WEBSOCKET_TIMEOUT("WebSocket timed out from inactivity");
32
+ const std::string_view
33
+ ERR_WEBSOCKET_TIMEOUT("WebSocket timed out from inactivity");
33
34
  const std::string_view ERR_INVALID_TEXT("Received invalid UTF-8");
34
- const std::string_view ERR_TOO_BIG_MESSAGE_INFLATION("Received too big message, or other inflation error");
35
- const std::string_view ERR_INVALID_CLOSE_PAYLOAD("Received invalid close payload");
35
+ const std::string_view ERR_TOO_BIG_MESSAGE_INFLATION(
36
+ "Received too big message, or other inflation error");
37
+ const std::string_view
38
+ ERR_INVALID_CLOSE_PAYLOAD("Received invalid close payload");
36
39
 
37
40
  enum OpCode : unsigned char {
38
- CONTINUATION = 0,
39
- TEXT = 1,
40
- BINARY = 2,
41
- CLOSE = 8,
42
- PING = 9,
43
- PONG = 10
41
+ CONTINUATION = 0,
42
+ TEXT = 1,
43
+ BINARY = 2,
44
+ CLOSE = 8,
45
+ PING = 9,
46
+ PONG = 10
44
47
  };
45
48
 
46
- enum {
47
- CLIENT,
48
- SERVER
49
- };
49
+ enum { CLIENT, SERVER };
50
50
 
51
51
  // 24 bytes perfectly
52
- template <bool isServer>
53
- struct WebSocketState {
52
+ template <bool isServer> struct WebSocketState {
54
53
  public:
55
- static const unsigned int SHORT_MESSAGE_HEADER = isServer ? 6 : 2;
56
- static const unsigned int MEDIUM_MESSAGE_HEADER = isServer ? 8 : 4;
57
- static const unsigned int LONG_MESSAGE_HEADER = isServer ? 14 : 10;
58
-
59
- // 16 bytes
60
- struct State {
61
- unsigned int wantsHead : 1;
62
- unsigned int spillLength : 4;
63
- signed int opStack : 2; // -1, 0, 1
64
- unsigned int lastFin : 1;
65
-
66
- // 15 bytes
67
- unsigned char spill[LONG_MESSAGE_HEADER - 1];
68
- OpCode opCode[2];
69
-
70
- State() {
71
- wantsHead = true;
72
- spillLength = 0;
73
- opStack = -1;
74
- lastFin = true;
75
- }
54
+ static const unsigned int SHORT_MESSAGE_HEADER = isServer ? 6 : 2;
55
+ static const unsigned int MEDIUM_MESSAGE_HEADER = isServer ? 8 : 4;
56
+ static const unsigned int LONG_MESSAGE_HEADER = isServer ? 14 : 10;
57
+
58
+ // 16 bytes
59
+ struct State {
60
+ unsigned int wantsHead : 1;
61
+ unsigned int spillLength : 4;
62
+ signed int opStack : 2; // -1, 0, 1
63
+ unsigned int lastFin : 1;
64
+
65
+ // 15 bytes
66
+ unsigned char spill[LONG_MESSAGE_HEADER - 1];
67
+ OpCode opCode[2];
68
+
69
+ State() {
70
+ wantsHead = true;
71
+ spillLength = 0;
72
+ opStack = -1;
73
+ lastFin = true;
74
+ }
76
75
 
77
- } state;
76
+ } state;
78
77
 
79
- // 8 bytes
80
- unsigned int remainingBytes = 0;
81
- char mask[isServer ? 4 : 1];
78
+ // 8 bytes
79
+ unsigned int remainingBytes = 0;
80
+ char mask[isServer ? 4 : 1];
82
81
  };
83
82
 
84
83
  namespace protocol {
85
84
 
86
- template <typename T>
87
- T bit_cast(char *c) {
88
- T val;
89
- memcpy(&val, c, sizeof(T));
90
- return val;
85
+ template <typename T> T bit_cast(char *c) {
86
+ T val;
87
+ memcpy(&val, c, sizeof(T));
88
+ return val;
91
89
  }
92
90
 
93
91
  /* Byte swap for little-endian systems */
94
- template <typename T>
95
- T cond_byte_swap(T value) {
96
- uint32_t endian_test = 1;
97
- if (*((char *)&endian_test)) {
98
- union {
99
- T i;
100
- uint8_t b[sizeof(T)];
101
- } src = { value }, dst;
102
-
103
- for (unsigned int i = 0; i < sizeof(value); i++) {
104
- dst.b[i] = src.b[sizeof(value) - 1 - i];
105
- }
106
-
107
- return dst.i;
92
+ template <typename T> T cond_byte_swap(T value) {
93
+ uint32_t endian_test = 1;
94
+ if (*((char *)&endian_test)) {
95
+ union {
96
+ T i;
97
+ uint8_t b[sizeof(T)];
98
+ } src = {value}, dst;
99
+
100
+ for (unsigned int i = 0; i < sizeof(value); i++) {
101
+ dst.b[i] = src.b[sizeof(value) - 1 - i];
108
102
  }
109
- return value;
103
+
104
+ return dst.i;
105
+ }
106
+ return value;
110
107
  }
111
108
 
112
109
  // Based on utf8_check.c by Markus Kuhn, 2005
113
110
  // https://www.cl.cam.ac.uk/~mgk25/ucs/utf8_check.c
114
111
  // Optimized for predominantly 7-bit content by Alex Hultman, 2016
115
112
  // Licensed as Zlib, like the rest of this project
116
- static bool isValidUtf8(unsigned char *s, size_t length)
117
- {
118
- for (unsigned char *e = s + length; s != e; ) {
119
- if (s + 4 <= e) {
120
- uint32_t tmp;
121
- memcpy(&tmp, s, 4);
122
- if ((tmp & 0x80808080) == 0) {
123
- s += 4;
124
- continue;
125
- }
126
- }
113
+ static bool isValidUtf8(unsigned char *s, size_t length) {
114
+ for (unsigned char *e = s + length; s != e;) {
115
+ if (s + 4 <= e) {
116
+ uint32_t tmp;
117
+ memcpy(&tmp, s, 4);
118
+ if ((tmp & 0x80808080) == 0) {
119
+ s += 4;
120
+ continue;
121
+ }
122
+ }
127
123
 
128
- while (!(*s & 0x80)) {
129
- if (++s == e) {
130
- return true;
131
- }
132
- }
124
+ while (!(*s & 0x80)) {
125
+ if (++s == e) {
126
+ return true;
127
+ }
128
+ }
133
129
 
134
- if ((s[0] & 0x60) == 0x40) {
135
- if (s + 1 >= e || (s[1] & 0xc0) != 0x80 || (s[0] & 0xfe) == 0xc0) {
136
- return false;
137
- }
138
- s += 2;
139
- } else if ((s[0] & 0xf0) == 0xe0) {
140
- if (s + 2 >= e || (s[1] & 0xc0) != 0x80 || (s[2] & 0xc0) != 0x80 ||
141
- (s[0] == 0xe0 && (s[1] & 0xe0) == 0x80) || (s[0] == 0xed && (s[1] & 0xe0) == 0xa0)) {
142
- return false;
143
- }
144
- s += 3;
145
- } else if ((s[0] & 0xf8) == 0xf0) {
146
- if (s + 3 >= e || (s[1] & 0xc0) != 0x80 || (s[2] & 0xc0) != 0x80 || (s[3] & 0xc0) != 0x80 ||
147
- (s[0] == 0xf0 && (s[1] & 0xf0) == 0x80) || (s[0] == 0xf4 && s[1] > 0x8f) || s[0] > 0xf4) {
148
- return false;
149
- }
150
- s += 4;
151
- } else {
152
- return false;
153
- }
130
+ if ((s[0] & 0x60) == 0x40) {
131
+ if (s + 1 >= e || (s[1] & 0xc0) != 0x80 || (s[0] & 0xfe) == 0xc0) {
132
+ return false;
133
+ }
134
+ s += 2;
135
+ } else if ((s[0] & 0xf0) == 0xe0) {
136
+ if (s + 2 >= e || (s[1] & 0xc0) != 0x80 || (s[2] & 0xc0) != 0x80 ||
137
+ (s[0] == 0xe0 && (s[1] & 0xe0) == 0x80) ||
138
+ (s[0] == 0xed && (s[1] & 0xe0) == 0xa0)) {
139
+ return false;
140
+ }
141
+ s += 3;
142
+ } else if ((s[0] & 0xf8) == 0xf0) {
143
+ if (s + 3 >= e || (s[1] & 0xc0) != 0x80 || (s[2] & 0xc0) != 0x80 ||
144
+ (s[3] & 0xc0) != 0x80 || (s[0] == 0xf0 && (s[1] & 0xf0) == 0x80) ||
145
+ (s[0] == 0xf4 && s[1] > 0x8f) || s[0] > 0xf4) {
146
+ return false;
147
+ }
148
+ s += 4;
149
+ } else {
150
+ return false;
154
151
  }
155
- return true;
152
+ }
153
+ return true;
156
154
  }
157
155
 
158
156
  struct CloseFrame {
159
- uint16_t code;
160
- char *message;
161
- size_t length;
157
+ uint16_t code;
158
+ char *message;
159
+ size_t length;
162
160
  };
163
161
 
164
162
  static inline CloseFrame parseClosePayload(char *src, size_t length) {
165
- /* If we get no code or message, default to reporting 1005 no status code present */
166
- CloseFrame cf = {1005, nullptr, 0};
167
- if (length >= 2) {
168
- memcpy(&cf.code, src, 2);
169
- cf = {cond_byte_swap<uint16_t>(cf.code), src + 2, length - 2};
170
- if (cf.code < 1000 || cf.code > 4999 || (cf.code > 1011 && cf.code < 4000) ||
171
- (cf.code >= 1004 && cf.code <= 1006) || !isValidUtf8((unsigned char *) cf.message, cf.length)) {
172
- /* Even though we got a WebSocket close frame, it in itself is abnormal */
173
- return {1006, nullptr, 0};
174
- }
163
+ /* If we get no code or message, default to reporting 1005 no status code
164
+ * present */
165
+ CloseFrame cf = {1005, nullptr, 0};
166
+ if (length >= 2) {
167
+ memcpy(&cf.code, src, 2);
168
+ cf = {cond_byte_swap<uint16_t>(cf.code), src + 2, length - 2};
169
+ if (cf.code < 1000 || cf.code > 4999 ||
170
+ (cf.code > 1011 && cf.code < 4000) ||
171
+ (cf.code >= 1004 && cf.code <= 1006) ||
172
+ !isValidUtf8((unsigned char *)cf.message, cf.length)) {
173
+ /* Even though we got a WebSocket close frame, it in itself is abnormal */
174
+ return {1006, nullptr, 0};
175
175
  }
176
- return cf;
176
+ }
177
+ return cf;
177
178
  }
178
179
 
179
- static inline size_t formatClosePayload(char *dst, uint16_t code, const char *message, size_t length) {
180
- /* We could have more strict checks here, but never append code 0 or 1005 or 1006 */
181
- if (code && code != 1005 && code != 1006) {
182
- code = cond_byte_swap<uint16_t>(code);
183
- memcpy(dst, &code, 2);
184
- /* It is invalid to pass nullptr to memcpy, even though length is 0 */
185
- if (message) {
186
- memcpy(dst + 2, message, length);
187
- }
188
- return length + 2;
180
+ static inline size_t formatClosePayload(char *dst, uint16_t code,
181
+ const char *message, size_t length) {
182
+ /* We could have more strict checks here, but never append code 0 or 1005 or
183
+ * 1006 */
184
+ if (code && code != 1005 && code != 1006) {
185
+ code = cond_byte_swap<uint16_t>(code);
186
+ memcpy(dst, &code, 2);
187
+ /* It is invalid to pass nullptr to memcpy, even though length is 0 */
188
+ if (message) {
189
+ memcpy(dst + 2, message, length);
189
190
  }
190
- return 0;
191
+ return length + 2;
192
+ }
193
+ return 0;
191
194
  }
192
195
 
193
196
  static inline size_t messageFrameSize(size_t messageSize) {
194
- if (messageSize < 126) {
195
- return 2 + messageSize;
196
- } else if (messageSize <= UINT16_MAX) {
197
- return 4 + messageSize;
198
- }
199
- return 10 + messageSize;
197
+ if (messageSize < 126) {
198
+ return 2 + messageSize;
199
+ } else if (messageSize <= UINT16_MAX) {
200
+ return 4 + messageSize;
201
+ }
202
+ return 10 + messageSize;
200
203
  }
201
204
 
202
- enum {
203
- SND_CONTINUATION = 1,
204
- SND_NO_FIN = 2,
205
- SND_COMPRESSED = 64
206
- };
205
+ enum { SND_CONTINUATION = 1, SND_NO_FIN = 2, SND_COMPRESSED = 64 };
207
206
 
208
207
  template <bool isServer>
209
- static inline size_t formatMessage(char *dst, const char *src, size_t length, OpCode opCode, size_t reportedLength, bool compressed, bool fin) {
210
- size_t messageLength;
211
- size_t headerLength;
212
- if (reportedLength < 126) {
213
- headerLength = 2;
214
- dst[1] = (char) reportedLength;
215
- } else if (reportedLength <= UINT16_MAX) {
216
- headerLength = 4;
217
- dst[1] = 126;
218
- uint16_t tmp = cond_byte_swap<uint16_t>((uint16_t) reportedLength);
219
- memcpy(&dst[2], &tmp, sizeof(uint16_t));
220
- } else {
221
- headerLength = 10;
222
- dst[1] = 127;
223
- uint64_t tmp = cond_byte_swap<uint64_t>((uint64_t) reportedLength);
224
- memcpy(&dst[2], &tmp, sizeof(uint64_t));
225
- }
226
-
227
- dst[0] = (char) ((fin ? 128 : 0) | ((compressed && opCode) ? SND_COMPRESSED : 0) | (char) opCode);
228
-
229
- //printf("%d\n", (int)dst[0]);
230
-
231
- char mask[4];
232
- if (!isServer) {
233
- dst[1] |= 0x80;
234
- uint32_t random = (uint32_t) rand();
235
- memcpy(mask, &random, 4);
236
- memcpy(dst + headerLength, &random, 4);
237
- headerLength += 4;
238
- }
239
-
240
- messageLength = headerLength + length;
241
- memcpy(dst + headerLength, src, length);
242
-
243
- if (!isServer) {
244
-
245
- // overwrites up to 3 bytes outside of the given buffer!
246
- //WebSocketProtocol<isServer>::unmaskInplace(dst + headerLength, dst + headerLength + length, mask);
247
-
248
- // this is not optimal
249
- char *start = dst + headerLength;
250
- char *stop = start + length;
251
- int i = 0;
252
- while (start != stop) {
253
- (*start++) ^= mask[i++ % 4];
254
- }
208
+ static inline size_t formatMessage(char *dst, const char *src, size_t length,
209
+ OpCode opCode, size_t reportedLength,
210
+ bool compressed, bool fin) {
211
+ size_t messageLength;
212
+ size_t headerLength;
213
+ if (reportedLength < 126) {
214
+ headerLength = 2;
215
+ dst[1] = (char)reportedLength;
216
+ } else if (reportedLength <= UINT16_MAX) {
217
+ headerLength = 4;
218
+ dst[1] = 126;
219
+ uint16_t tmp = cond_byte_swap<uint16_t>((uint16_t)reportedLength);
220
+ memcpy(&dst[2], &tmp, sizeof(uint16_t));
221
+ } else {
222
+ headerLength = 10;
223
+ dst[1] = 127;
224
+ uint64_t tmp = cond_byte_swap<uint64_t>((uint64_t)reportedLength);
225
+ memcpy(&dst[2], &tmp, sizeof(uint64_t));
226
+ }
227
+
228
+ dst[0] = (char)((fin ? 128 : 0) |
229
+ ((compressed && opCode) ? SND_COMPRESSED : 0) | (char)opCode);
230
+
231
+ // printf("%d\n", (int)dst[0]);
232
+
233
+ char mask[4];
234
+ if (!isServer) {
235
+ dst[1] |= 0x80;
236
+ uint32_t random = (uint32_t)rand();
237
+ memcpy(mask, &random, 4);
238
+ memcpy(dst + headerLength, &random, 4);
239
+ headerLength += 4;
240
+ }
241
+
242
+ messageLength = headerLength + length;
243
+ memcpy(dst + headerLength, src, length);
244
+
245
+ if (!isServer) {
246
+
247
+ // overwrites up to 3 bytes outside of the given buffer!
248
+ // WebSocketProtocol<isServer>::unmaskInplace(dst + headerLength, dst +
249
+ // headerLength + length, mask);
250
+
251
+ // this is not optimal
252
+ char *start = dst + headerLength;
253
+ char *stop = start + length;
254
+ int i = 0;
255
+ while (start != stop) {
256
+ (*start++) ^= mask[i++ % 4];
255
257
  }
256
- return messageLength;
258
+ }
259
+ return messageLength;
257
260
  }
258
261
 
259
- }
262
+ } // namespace protocol
260
263
 
261
264
  // essentially this is only a parser
262
- template <const bool isServer, typename Impl>
263
- struct WebSocketProtocol {
265
+ template <const bool isServer, typename Impl> struct WebSocketProtocol {
264
266
  public:
265
- static const unsigned int SHORT_MESSAGE_HEADER = isServer ? 6 : 2;
266
- static const unsigned int MEDIUM_MESSAGE_HEADER = isServer ? 8 : 4;
267
- static const unsigned int LONG_MESSAGE_HEADER = isServer ? 14 : 10;
267
+ static const unsigned int SHORT_MESSAGE_HEADER = isServer ? 6 : 2;
268
+ static const unsigned int MEDIUM_MESSAGE_HEADER = isServer ? 8 : 4;
269
+ static const unsigned int LONG_MESSAGE_HEADER = isServer ? 14 : 10;
268
270
 
269
271
  protected:
270
- static inline bool isFin(char *frame) {return *((unsigned char *) frame) & 128;}
271
- static inline unsigned char getOpCode(char *frame) {return *((unsigned char *) frame) & 15;}
272
- static inline unsigned char payloadLength(char *frame) {return ((unsigned char *) frame)[1] & 127;}
273
- static inline bool rsv23(char *frame) {return *((unsigned char *) frame) & 48;}
274
- static inline bool rsv1(char *frame) {return *((unsigned char *) frame) & 64;}
275
-
276
- template <int N>
277
- static inline void UnrolledXor(char * __restrict data, char * __restrict mask) {
278
- if constexpr (N != 1) {
279
- UnrolledXor<N - 1>(data, mask);
280
- }
281
- data[N - 1] ^= mask[(N - 1) % 4];
272
+ static inline bool isFin(char *frame) {
273
+ return *((unsigned char *)frame) & 128;
274
+ }
275
+ static inline unsigned char getOpCode(char *frame) {
276
+ return *((unsigned char *)frame) & 15;
277
+ }
278
+ static inline unsigned char payloadLength(char *frame) {
279
+ return ((unsigned char *)frame)[1] & 127;
280
+ }
281
+ static inline bool rsv23(char *frame) {
282
+ return *((unsigned char *)frame) & 48;
283
+ }
284
+ static inline bool rsv1(char *frame) {
285
+ return *((unsigned char *)frame) & 64;
286
+ }
287
+
288
+ template <int N>
289
+ static inline void UnrolledXor(char *__restrict data, char *__restrict mask) {
290
+ if constexpr (N != 1) {
291
+ UnrolledXor<N - 1>(data, mask);
282
292
  }
283
-
284
- template <int DESTINATION>
285
- static inline void unmaskImprecise8(char *src, uint64_t mask, unsigned int length) {
286
- for (unsigned int n = (length >> 3) + 1; n; n--) {
287
- uint64_t loaded;
288
- memcpy(&loaded, src, 8);
289
- loaded ^= mask;
290
- memcpy(src - DESTINATION, &loaded, 8);
291
- src += 8;
292
- }
293
+ data[N - 1] ^= mask[(N - 1) % 4];
294
+ }
295
+
296
+ template <int DESTINATION>
297
+ static inline void unmaskImprecise8(char *src, uint64_t mask,
298
+ unsigned int length) {
299
+ for (unsigned int n = (length >> 3) + 1; n; n--) {
300
+ uint64_t loaded;
301
+ memcpy(&loaded, src, 8);
302
+ loaded ^= mask;
303
+ memcpy(src - DESTINATION, &loaded, 8);
304
+ src += 8;
293
305
  }
294
-
295
- /* DESTINATION = 6 makes this not SIMD, DESTINATION = 4 is with SIMD but we don't want that for short messages */
296
- template <int DESTINATION>
297
- static inline void unmaskImprecise4(char *src, uint32_t mask, unsigned int length) {
298
- for (unsigned int n = (length >> 2) + 1; n; n--) {
299
- uint32_t loaded;
300
- memcpy(&loaded, src, 4);
301
- loaded ^= mask;
302
- memcpy(src - DESTINATION, &loaded, 4);
303
- src += 4;
304
- }
306
+ }
307
+
308
+ /* DESTINATION = 6 makes this not SIMD, DESTINATION = 4 is with SIMD but we
309
+ * don't want that for short messages */
310
+ template <int DESTINATION>
311
+ static inline void unmaskImprecise4(char *src, uint32_t mask,
312
+ unsigned int length) {
313
+ for (unsigned int n = (length >> 2) + 1; n; n--) {
314
+ uint32_t loaded;
315
+ memcpy(&loaded, src, 4);
316
+ loaded ^= mask;
317
+ memcpy(src - DESTINATION, &loaded, 4);
318
+ src += 4;
305
319
  }
306
-
307
- template <int HEADER_SIZE>
308
- static inline void unmaskImpreciseCopyMask(char *src, unsigned int length) {
309
- if constexpr (HEADER_SIZE != 6) {
310
- char mask[8] = {src[-4], src[-3], src[-2], src[-1], src[-4], src[-3], src[-2], src[-1]};
311
- uint64_t maskInt;
312
- memcpy(&maskInt, mask, 8);
313
- unmaskImprecise8<HEADER_SIZE>(src, maskInt, length);
314
- } else {
315
- char mask[4] = {src[-4], src[-3], src[-2], src[-1]};
316
- uint32_t maskInt;
317
- memcpy(&maskInt, mask, 4);
318
- unmaskImprecise4<HEADER_SIZE>(src, maskInt, length);
319
- }
320
+ }
321
+
322
+ template <int HEADER_SIZE>
323
+ static inline void unmaskImpreciseCopyMask(char *src, unsigned int length) {
324
+ if constexpr (HEADER_SIZE != 6) {
325
+ char mask[8] = {src[-4], src[-3], src[-2], src[-1],
326
+ src[-4], src[-3], src[-2], src[-1]};
327
+ uint64_t maskInt;
328
+ memcpy(&maskInt, mask, 8);
329
+ unmaskImprecise8<HEADER_SIZE>(src, maskInt, length);
330
+ } else {
331
+ char mask[4] = {src[-4], src[-3], src[-2], src[-1]};
332
+ uint32_t maskInt;
333
+ memcpy(&maskInt, mask, 4);
334
+ unmaskImprecise4<HEADER_SIZE>(src, maskInt, length);
320
335
  }
321
-
322
- static inline void rotateMask(unsigned int offset, char *mask) {
323
- char originalMask[4] = {mask[0], mask[1], mask[2], mask[3]};
324
- mask[(0 + offset) % 4] = originalMask[0];
325
- mask[(1 + offset) % 4] = originalMask[1];
326
- mask[(2 + offset) % 4] = originalMask[2];
327
- mask[(3 + offset) % 4] = originalMask[3];
336
+ }
337
+
338
+ static inline void rotateMask(unsigned int offset, char *mask) {
339
+ char originalMask[4] = {mask[0], mask[1], mask[2], mask[3]};
340
+ mask[(0 + offset) % 4] = originalMask[0];
341
+ mask[(1 + offset) % 4] = originalMask[1];
342
+ mask[(2 + offset) % 4] = originalMask[2];
343
+ mask[(3 + offset) % 4] = originalMask[3];
344
+ }
345
+
346
+ static inline void unmaskInplace(char *data, char *stop, char *mask) {
347
+ while (data < stop) {
348
+ *(data++) ^= mask[0];
349
+ *(data++) ^= mask[1];
350
+ *(data++) ^= mask[2];
351
+ *(data++) ^= mask[3];
352
+ }
353
+ }
354
+
355
+ template <unsigned int MESSAGE_HEADER, typename T>
356
+ static inline bool
357
+ consumeMessage(T payLength, char *&src, unsigned int &length,
358
+ WebSocketState<isServer> *wState, void *user) {
359
+ if (getOpCode(src)) {
360
+ if (wState->state.opStack == 1 ||
361
+ (!wState->state.lastFin && getOpCode(src) < 2)) {
362
+ Impl::forceClose(wState, user);
363
+ return true;
364
+ }
365
+ wState->state.opCode[++wState->state.opStack] = (OpCode)getOpCode(src);
366
+ } else if (wState->state.opStack == -1) {
367
+ Impl::forceClose(wState, user);
368
+ return true;
328
369
  }
370
+ wState->state.lastFin = isFin(src);
329
371
 
330
- static inline void unmaskInplace(char *data, char *stop, char *mask) {
331
- while (data < stop) {
332
- *(data++) ^= mask[0];
333
- *(data++) ^= mask[1];
334
- *(data++) ^= mask[2];
335
- *(data++) ^= mask[3];
336
- }
372
+ if (Impl::refusePayloadLength(payLength, wState, user)) {
373
+ Impl::forceClose(wState, user, ERR_TOO_BIG_MESSAGE);
374
+ return true;
337
375
  }
338
376
 
339
- template <unsigned int MESSAGE_HEADER, typename T>
340
- static inline bool consumeMessage(T payLength, char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
341
- if (getOpCode(src)) {
342
- if (wState->state.opStack == 1 || (!wState->state.lastFin && getOpCode(src) < 2)) {
343
- Impl::forceClose(wState, user);
344
- return true;
345
- }
346
- wState->state.opCode[++wState->state.opStack] = (OpCode) getOpCode(src);
347
- } else if (wState->state.opStack == -1) {
348
- Impl::forceClose(wState, user);
349
- return true;
377
+ if (payLength + MESSAGE_HEADER <= length) {
378
+ bool fin = isFin(src);
379
+ if (isServer) {
380
+ /* This guy can never be assumed to be perfectly aligned since we can
381
+ * get multiple messages in one read */
382
+ unmaskImpreciseCopyMask<MESSAGE_HEADER>(src + MESSAGE_HEADER,
383
+ (unsigned int)payLength);
384
+ if (Impl::handleFragment(src, payLength, 0,
385
+ wState->state.opCode[wState->state.opStack],
386
+ fin, wState, user)) {
387
+ return true;
350
388
  }
351
- wState->state.lastFin = isFin(src);
352
-
353
- if (Impl::refusePayloadLength(payLength, wState, user)) {
354
- Impl::forceClose(wState, user, ERR_TOO_BIG_MESSAGE);
355
- return true;
389
+ } else {
390
+ if (Impl::handleFragment(src + MESSAGE_HEADER, payLength, 0,
391
+ wState->state.opCode[wState->state.opStack],
392
+ isFin(src), wState, user)) {
393
+ return true;
356
394
  }
395
+ }
357
396
 
358
- if (payLength + MESSAGE_HEADER <= length) {
359
- bool fin = isFin(src);
360
- if (isServer) {
361
- /* This guy can never be assumed to be perfectly aligned since we can get multiple messages in one read */
362
- unmaskImpreciseCopyMask<MESSAGE_HEADER>(src + MESSAGE_HEADER, (unsigned int) payLength);
363
- if (Impl::handleFragment(src, payLength, 0, wState->state.opCode[wState->state.opStack], fin, wState, user)) {
364
- return true;
365
- }
366
- } else {
367
- if (Impl::handleFragment(src + MESSAGE_HEADER, payLength, 0, wState->state.opCode[wState->state.opStack], isFin(src), wState, user)) {
368
- return true;
369
- }
370
- }
371
-
372
- if (fin) {
373
- wState->state.opStack--;
374
- }
375
-
376
- src += payLength + MESSAGE_HEADER;
377
- length -= (unsigned int) (payLength + MESSAGE_HEADER);
378
- wState->state.spillLength = 0;
379
- return false;
380
- } else {
381
- wState->state.spillLength = 0;
382
- wState->state.wantsHead = false;
383
- wState->remainingBytes = (unsigned int) (payLength - length + MESSAGE_HEADER);
384
- bool fin = isFin(src);
385
- if constexpr (isServer) {
386
- memcpy(wState->mask, src + MESSAGE_HEADER - 4, 4);
387
- uint64_t mask;
388
- memcpy(&mask, src + MESSAGE_HEADER - 4, 4);
389
- memcpy(((char *)&mask) + 4, src + MESSAGE_HEADER - 4, 4);
390
- unmaskImprecise8<0>(src + MESSAGE_HEADER, mask, length);
391
- rotateMask(4 - (length - MESSAGE_HEADER) % 4, wState->mask);
392
- }
393
- Impl::handleFragment(src + MESSAGE_HEADER, length - MESSAGE_HEADER, wState->remainingBytes, wState->state.opCode[wState->state.opStack], fin, wState, user);
394
- return true;
395
- }
396
- }
397
+ if (fin) {
398
+ wState->state.opStack--;
399
+ }
397
400
 
398
- /* This one is nicely vectorized on both ARM64 and X64 - especially with -mavx */
399
- static inline void unmaskAll(char * __restrict data, char * __restrict mask) {
400
- for (int i = 0; i < LIBUS_RECV_BUFFER_LENGTH; i += 16) {
401
- UnrolledXor<16>(data + i, mask);
402
- }
401
+ src += payLength + MESSAGE_HEADER;
402
+ length -= (unsigned int)(payLength + MESSAGE_HEADER);
403
+ wState->state.spillLength = 0;
404
+ return false;
405
+ } else {
406
+ wState->state.spillLength = 0;
407
+ wState->state.wantsHead = false;
408
+ wState->remainingBytes =
409
+ (unsigned int)(payLength - length + MESSAGE_HEADER);
410
+ bool fin = isFin(src);
411
+ if constexpr (isServer) {
412
+ memcpy(wState->mask, src + MESSAGE_HEADER - 4, 4);
413
+ uint64_t mask;
414
+ memcpy(&mask, src + MESSAGE_HEADER - 4, 4);
415
+ memcpy(((char *)&mask) + 4, src + MESSAGE_HEADER - 4, 4);
416
+ unmaskImprecise8<0>(src + MESSAGE_HEADER, mask, length);
417
+ rotateMask(4 - (length - MESSAGE_HEADER) % 4, wState->mask);
418
+ }
419
+ Impl::handleFragment(
420
+ src + MESSAGE_HEADER, length - MESSAGE_HEADER, wState->remainingBytes,
421
+ wState->state.opCode[wState->state.opStack], fin, wState, user);
422
+ return true;
403
423
  }
424
+ }
404
425
 
405
- static inline bool consumeContinuation(char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
406
- if (wState->remainingBytes <= length) {
407
- if (isServer) {
408
- unsigned int n = wState->remainingBytes >> 2;
409
- unmaskInplace(src, src + n * 4, wState->mask);
410
- for (unsigned int i = 0, s = wState->remainingBytes % 4; i < s; i++) {
411
- src[n * 4 + i] ^= wState->mask[i];
412
- }
413
- }
414
-
415
- if (Impl::handleFragment(src, wState->remainingBytes, 0, wState->state.opCode[wState->state.opStack], wState->state.lastFin, wState, user)) {
416
- return false;
417
- }
418
-
419
- if (wState->state.lastFin) {
420
- wState->state.opStack--;
421
- }
422
-
423
- src += wState->remainingBytes;
424
- length -= wState->remainingBytes;
425
- wState->state.wantsHead = true;
426
- return true;
427
- } else {
428
- if (isServer) {
429
- /* No need to unmask if mask is 0 */
430
- uint32_t nullmask = 0;
431
- if (memcmp(wState->mask, &nullmask, sizeof(uint32_t))) {
432
- if /*constexpr*/ (LIBUS_RECV_BUFFER_LENGTH == length) {
433
- unmaskAll(src, wState->mask);
434
- } else {
435
- // Slow path
436
- unmaskInplace(src, src + ((length >> 2) + 1) * 4, wState->mask);
437
- }
438
- }
439
- }
440
-
441
- wState->remainingBytes -= length;
442
- if (Impl::handleFragment(src, length, wState->remainingBytes, wState->state.opCode[wState->state.opStack], wState->state.lastFin, wState, user)) {
443
- return false;
444
- }
445
-
446
- if (isServer && length % 4) {
447
- rotateMask(4 - (length % 4), wState->mask);
448
- }
449
- return false;
426
+ /* This one is nicely vectorized on both ARM64 and X64 - especially with -mavx
427
+ */
428
+ static inline void unmaskAll(char *__restrict data, char *__restrict mask) {
429
+ for (int i = 0; i < LIBUS_RECV_BUFFER_LENGTH; i += 16) {
430
+ UnrolledXor<16>(data + i, mask);
431
+ }
432
+ }
433
+
434
+ static inline bool consumeContinuation(char *&src, unsigned int &length,
435
+ WebSocketState<isServer> *wState,
436
+ void *user) {
437
+ if (wState->remainingBytes <= length) {
438
+ if (isServer) {
439
+ unsigned int n = wState->remainingBytes >> 2;
440
+ unmaskInplace(src, src + n * 4, wState->mask);
441
+ for (unsigned int i = 0, s = wState->remainingBytes % 4; i < s; i++) {
442
+ src[n * 4 + i] ^= wState->mask[i];
450
443
  }
444
+ }
445
+
446
+ if (Impl::handleFragment(src, wState->remainingBytes, 0,
447
+ wState->state.opCode[wState->state.opStack],
448
+ wState->state.lastFin, wState, user)) {
449
+ return false;
450
+ }
451
+
452
+ if (wState->state.lastFin) {
453
+ wState->state.opStack--;
454
+ }
455
+
456
+ src += wState->remainingBytes;
457
+ length -= wState->remainingBytes;
458
+ wState->state.wantsHead = true;
459
+ return true;
460
+ } else {
461
+ if (isServer) {
462
+ /* No need to unmask if mask is 0 */
463
+ uint32_t nullmask = 0;
464
+ if (memcmp(wState->mask, &nullmask, sizeof(uint32_t))) {
465
+ if /*constexpr*/ (LIBUS_RECV_BUFFER_LENGTH == length) {
466
+ unmaskAll(src, wState->mask);
467
+ } else {
468
+ // Slow path
469
+ unmaskInplace(src, src + ((length >> 2) + 1) * 4, wState->mask);
470
+ }
471
+ }
472
+ }
473
+
474
+ wState->remainingBytes -= length;
475
+ if (Impl::handleFragment(src, length, wState->remainingBytes,
476
+ wState->state.opCode[wState->state.opStack],
477
+ wState->state.lastFin, wState, user)) {
478
+ return false;
479
+ }
480
+
481
+ if (isServer && length % 4) {
482
+ rotateMask(4 - (length % 4), wState->mask);
483
+ }
484
+ return false;
451
485
  }
486
+ }
452
487
 
453
488
  public:
454
- WebSocketProtocol() {
455
-
489
+ WebSocketProtocol() {}
490
+
491
+ static inline void consume(char *src, unsigned int length,
492
+ WebSocketState<isServer> *wState, void *user) {
493
+ if (wState->state.spillLength) {
494
+ src -= wState->state.spillLength;
495
+ length += wState->state.spillLength;
496
+ memcpy(src, wState->state.spill, wState->state.spillLength);
456
497
  }
457
-
458
- static inline void consume(char *src, unsigned int length, WebSocketState<isServer> *wState, void *user) {
459
- if (wState->state.spillLength) {
460
- src -= wState->state.spillLength;
461
- length += wState->state.spillLength;
462
- memcpy(src, wState->state.spill, wState->state.spillLength);
498
+ if (wState->state.wantsHead) {
499
+ parseNext:
500
+ while (length >= SHORT_MESSAGE_HEADER) {
501
+
502
+ // invalid reserved bits / invalid opcodes / invalid control frames /
503
+ // set compressed frame
504
+ if ((rsv1(src) && !Impl::setCompressed(wState, user)) || rsv23(src) ||
505
+ (getOpCode(src) > 2 && getOpCode(src) < 8) || getOpCode(src) > 10 ||
506
+ (getOpCode(src) > 2 && (!isFin(src) || payloadLength(src) > 125))) {
507
+ Impl::forceClose(wState, user);
508
+ return;
463
509
  }
464
- if (wState->state.wantsHead) {
465
- parseNext:
466
- while (length >= SHORT_MESSAGE_HEADER) {
467
-
468
- // invalid reserved bits / invalid opcodes / invalid control frames / set compressed frame
469
- if ((rsv1(src) && !Impl::setCompressed(wState, user)) || rsv23(src) || (getOpCode(src) > 2 && getOpCode(src) < 8) ||
470
- getOpCode(src) > 10 || (getOpCode(src) > 2 && (!isFin(src) || payloadLength(src) > 125))) {
471
- Impl::forceClose(wState, user);
472
- return;
473
- }
474
-
475
- if (payloadLength(src) < 126) {
476
- if (consumeMessage<SHORT_MESSAGE_HEADER, uint8_t>(payloadLength(src), src, length, wState, user)) {
477
- return;
478
- }
479
- } else if (payloadLength(src) == 126) {
480
- if (length < MEDIUM_MESSAGE_HEADER) {
481
- break;
482
- } else if(consumeMessage<MEDIUM_MESSAGE_HEADER, uint16_t>(protocol::cond_byte_swap<uint16_t>(protocol::bit_cast<uint16_t>(src + 2)), src, length, wState, user)) {
483
- return;
484
- }
485
- } else if (length < LONG_MESSAGE_HEADER) {
486
- break;
487
- } else if (consumeMessage<LONG_MESSAGE_HEADER, uint64_t>(protocol::cond_byte_swap<uint64_t>(protocol::bit_cast<uint64_t>(src + 2)), src, length, wState, user)) {
488
- return;
489
- }
490
- }
491
- if (length) {
492
- memcpy(wState->state.spill, src, length);
493
- wState->state.spillLength = length & 0xf;
494
- }
495
- } else if (consumeContinuation(src, length, wState, user)) {
496
- goto parseNext;
510
+
511
+ if (payloadLength(src) < 126) {
512
+ if (consumeMessage<SHORT_MESSAGE_HEADER, uint8_t>(
513
+ payloadLength(src), src, length, wState, user)) {
514
+ return;
515
+ }
516
+ } else if (payloadLength(src) == 126) {
517
+ if (length < MEDIUM_MESSAGE_HEADER) {
518
+ break;
519
+ } else if (consumeMessage<MEDIUM_MESSAGE_HEADER, uint16_t>(
520
+ protocol::cond_byte_swap<uint16_t>(
521
+ protocol::bit_cast<uint16_t>(src + 2)),
522
+ src, length, wState, user)) {
523
+ return;
524
+ }
525
+ } else if (length < LONG_MESSAGE_HEADER) {
526
+ break;
527
+ } else if (consumeMessage<LONG_MESSAGE_HEADER, uint64_t>(
528
+ protocol::cond_byte_swap<uint64_t>(
529
+ protocol::bit_cast<uint64_t>(src + 2)),
530
+ src, length, wState, user)) {
531
+ return;
497
532
  }
533
+ }
534
+ if (length) {
535
+ memcpy(wState->state.spill, src, length);
536
+ wState->state.spillLength = length & 0xf;
537
+ }
538
+ } else if (consumeContinuation(src, length, wState, user)) {
539
+ goto parseNext;
498
540
  }
541
+ }
499
542
 
500
- static const int CONSUME_POST_PADDING = 4;
501
- static const int CONSUME_PRE_PADDING = LONG_MESSAGE_HEADER - 1;
543
+ static const int CONSUME_POST_PADDING = 4;
544
+ static const int CONSUME_PRE_PADDING = LONG_MESSAGE_HEADER - 1;
502
545
  };
503
546
 
504
- }
547
+ } // namespace uWS
505
548
 
506
549
  #endif // UWS_WEBSOCKETPROTOCOL_H