@upstash/ratelimit 1.0.1 → 1.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.mjs CHANGED
@@ -102,7 +102,7 @@ function ms(d) {
102
102
  if (!match) {
103
103
  throw new Error(`Unable to parse window size: ${d}`);
104
104
  }
105
- const time = parseInt(match[1]);
105
+ const time = Number.parseInt(match[1]);
106
106
  const unit = match[2];
107
107
  switch (unit) {
108
108
  case "ms":
@@ -120,6 +120,59 @@ function ms(d) {
120
120
  }
121
121
  }
122
122
 
123
+ // src/lua-scripts/multi.ts
124
+ var fixedWindowScript = `
125
+ local key = KEYS[1]
126
+ local id = ARGV[1]
127
+ local window = ARGV[2]
128
+ local incrementBy = tonumber(ARGV[3])
129
+
130
+ redis.call("HSET", key, id, incrementBy)
131
+ local fields = redis.call("HGETALL", key)
132
+ if #fields == 1 and tonumber(fields[1])==incrementBy then
133
+ -- The first time this key is set, and the value will be equal to incrementBy.
134
+ -- So we only need the expire command once
135
+ redis.call("PEXPIRE", key, window)
136
+ end
137
+
138
+ return fields
139
+ `;
140
+ var slidingWindowScript = `
141
+ local currentKey = KEYS[1] -- identifier including prefixes
142
+ local previousKey = KEYS[2] -- key of the previous bucket
143
+ local tokens = tonumber(ARGV[1]) -- tokens per window
144
+ local now = ARGV[2] -- current timestamp in milliseconds
145
+ local window = ARGV[3] -- interval in milliseconds
146
+ local requestId = ARGV[4] -- uuid for this request
147
+ local incrementBy = tonumber(ARGV[5]) -- custom rate, default is 1
148
+
149
+ local currentFields = redis.call("HGETALL", currentKey)
150
+ local requestsInCurrentWindow = 0
151
+ for i = 2, #currentFields, 2 do
152
+ requestsInCurrentWindow = requestsInCurrentWindow + tonumber(currentFields[i])
153
+ end
154
+
155
+ local previousFields = redis.call("HGETALL", previousKey)
156
+ local requestsInPreviousWindow = 0
157
+ for i = 2, #previousFields, 2 do
158
+ requestsInPreviousWindow = requestsInPreviousWindow + tonumber(previousFields[i])
159
+ end
160
+
161
+ local percentageInCurrent = ( now % window) / window
162
+ if requestsInPreviousWindow * (1 - percentageInCurrent ) + requestsInCurrentWindow >= tokens then
163
+ return {currentFields, previousFields, false}
164
+ end
165
+
166
+ redis.call("HSET", currentKey, requestId, incrementBy)
167
+
168
+ if requestsInCurrentWindow == 0 then
169
+ -- The first time this key is set, the value will be equal to incrementBy.
170
+ -- So we only need the expire command once
171
+ redis.call("PEXPIRE", currentKey, window * 2 + 1000) -- Enough time to overlap with a new window + 1 second
172
+ end
173
+ return {currentFields, previousFields, true}
174
+ `;
175
+
123
176
  // src/ratelimit.ts
124
177
  var Ratelimit = class {
125
178
  limiter;
@@ -160,12 +213,29 @@ var Ratelimit = class {
160
213
  * }
161
214
  * return "Yes"
162
215
  * ```
216
+ *
217
+ * @param req.rate - The rate at which tokens will be added or consumed from the token bucket. A higher rate allows for more requests to be processed. Defaults to 1 token per interval if not specified.
218
+ *
219
+ * Usage with `req.rate`
220
+ * @example
221
+ * ```ts
222
+ * const ratelimit = new Ratelimit({
223
+ * redis: Redis.fromEnv(),
224
+ * limiter: Ratelimit.slidingWindow(100, "10 s")
225
+ * })
226
+ *
227
+ * const { success } = await ratelimit.limit(id, {rate: 10})
228
+ * if (!success){
229
+ * return "Nope"
230
+ * }
231
+ * return "Yes"
232
+ * ```
163
233
  */
164
234
  limit = async (identifier, req) => {
165
235
  const key = [this.prefix, identifier].join(":");
166
236
  let timeoutId = null;
167
237
  try {
168
- const arr = [this.limiter(this.ctx, key)];
238
+ const arr = [this.limiter(this.ctx, key, req?.rate)];
169
239
  if (this.timeout > 0) {
170
240
  arr.push(
171
241
  new Promise((resolve) => {
@@ -297,22 +367,7 @@ var MultiRegionRatelimit = class extends Ratelimit {
297
367
  */
298
368
  static fixedWindow(tokens, window) {
299
369
  const windowDuration = ms(window);
300
- const script = `
301
- local key = KEYS[1]
302
- local id = ARGV[1]
303
- local window = ARGV[2]
304
-
305
- redis.call("SADD", key, id)
306
- local members = redis.call("SMEMBERS", key)
307
- if #members == 1 then
308
- -- The first time this key is set, the value will be 1.
309
- -- So we only need the expire command once
310
- redis.call("PEXPIRE", key, window)
311
- end
312
-
313
- return members
314
- `;
315
- return async function(ctx, identifier) {
370
+ return async (ctx, identifier, rate) => {
316
371
  if (ctx.cache) {
317
372
  const { blocked, reset: reset2 } = ctx.cache.isBlocked(identifier);
318
373
  if (blocked) {
@@ -328,26 +383,60 @@ var MultiRegionRatelimit = class extends Ratelimit {
328
383
  const requestId = randomId();
329
384
  const bucket = Math.floor(Date.now() / windowDuration);
330
385
  const key = [identifier, bucket].join(":");
386
+ const incrementBy = rate ? Math.max(1, rate) : 1;
331
387
  const dbs = ctx.redis.map((redis) => ({
332
388
  redis,
333
- request: redis.eval(script, [key], [requestId, windowDuration])
389
+ request: redis.eval(
390
+ fixedWindowScript,
391
+ [key],
392
+ [requestId, windowDuration, incrementBy]
393
+ )
334
394
  }));
335
395
  const firstResponse = await Promise.any(dbs.map((s) => s.request));
336
- const usedTokens = firstResponse.length;
337
- const remaining = tokens - usedTokens - 1;
396
+ const usedTokens = firstResponse.reduce((accTokens, usedToken, index) => {
397
+ let parsedToken = 0;
398
+ if (index % 2) {
399
+ parsedToken = Number.parseInt(usedToken);
400
+ }
401
+ return accTokens + parsedToken;
402
+ }, 0);
403
+ const remaining = tokens - usedTokens;
338
404
  async function sync() {
339
405
  const individualIDs = await Promise.all(dbs.map((s) => s.request));
340
- const allIDs = Array.from(new Set(individualIDs.flatMap((_) => _)).values());
406
+ const allIDs = Array.from(
407
+ new Set(
408
+ individualIDs.flatMap((_) => _).reduce((acc, curr, index) => {
409
+ if (index % 2 === 0) {
410
+ acc.push(curr);
411
+ }
412
+ return acc;
413
+ }, [])
414
+ ).values()
415
+ );
341
416
  for (const db of dbs) {
342
- const ids = await db.request;
343
- if (ids.length >= tokens) {
417
+ const usedDbTokens = (await db.request).reduce((accTokens, usedToken, index) => {
418
+ let parsedToken = 0;
419
+ if (index % 2) {
420
+ parsedToken = Number.parseInt(usedToken);
421
+ }
422
+ return accTokens + parsedToken;
423
+ }, 0);
424
+ const dbIds = (await db.request).reduce((ids, currentId, index) => {
425
+ if (index % 2 === 0) {
426
+ ids.push(currentId);
427
+ }
428
+ return ids;
429
+ }, []);
430
+ if (usedDbTokens >= tokens) {
344
431
  continue;
345
432
  }
346
- const diff = allIDs.filter((id) => !ids.includes(id));
433
+ const diff = allIDs.filter((id) => !dbIds.includes(id));
347
434
  if (diff.length === 0) {
348
435
  continue;
349
436
  }
350
- await db.redis.sadd(key, ...allIDs);
437
+ for (const requestId2 of diff) {
438
+ await db.redis.hset(key, { [requestId2]: incrementBy });
439
+ }
351
440
  }
352
441
  }
353
442
  const success = remaining > 0;
@@ -382,69 +471,76 @@ var MultiRegionRatelimit = class extends Ratelimit {
382
471
  */
383
472
  static slidingWindow(tokens, window) {
384
473
  const windowSize = ms(window);
385
- const script = `
386
- local currentKey = KEYS[1] -- identifier including prefixes
387
- local previousKey = KEYS[2] -- key of the previous bucket
388
- local tokens = tonumber(ARGV[1]) -- tokens per window
389
- local now = ARGV[2] -- current timestamp in milliseconds
390
- local window = ARGV[3] -- interval in milliseconds
391
- local requestId = ARGV[4] -- uuid for this request
392
-
393
-
394
- local currentMembers = redis.call("SMEMBERS", currentKey)
395
- local requestsInCurrentWindow = #currentMembers
396
- local previousMembers = redis.call("SMEMBERS", previousKey)
397
- local requestsInPreviousWindow = #previousMembers
398
-
399
- local percentageInCurrent = ( now % window) / window
400
- if requestsInPreviousWindow * ( 1 - percentageInCurrent ) + requestsInCurrentWindow >= tokens then
401
- return {currentMembers, previousMembers, false}
402
- end
403
-
404
- redis.call("SADD", currentKey, requestId)
405
- table.insert(currentMembers, requestId)
406
- if requestsInCurrentWindow == 0 then
407
- -- The first time this key is set, the value will be 1.
408
- -- So we only need the expire command once
409
- redis.call("PEXPIRE", currentKey, window * 2 + 1000) -- Enough time to overlap with a new window + 1 second
410
- end
411
- return {currentMembers, previousMembers, true}
412
- `;
413
474
  const windowDuration = ms(window);
414
- return async function(ctx, identifier) {
475
+ return async (ctx, identifier, rate) => {
415
476
  const requestId = randomId();
416
477
  const now = Date.now();
417
478
  const currentWindow = Math.floor(now / windowSize);
418
479
  const currentKey = [identifier, currentWindow].join(":");
419
480
  const previousWindow = currentWindow - 1;
420
481
  const previousKey = [identifier, previousWindow].join(":");
482
+ const incrementBy = rate ? Math.max(1, rate) : 1;
421
483
  const dbs = ctx.redis.map((redis) => ({
422
484
  redis,
423
485
  request: redis.eval(
424
- script,
486
+ slidingWindowScript,
425
487
  [currentKey, previousKey],
426
- [tokens, now, windowDuration, requestId]
488
+ [tokens, now, windowDuration, requestId, incrementBy]
427
489
  // lua seems to return `1` for true and `null` for false
428
490
  )
429
491
  }));
430
492
  const percentageInCurrent = now % windowDuration / windowDuration;
431
493
  const [current, previous, success] = await Promise.any(dbs.map((s) => s.request));
432
- const previousPartialUsed = previous.length * (1 - percentageInCurrent);
433
- const usedTokens = previousPartialUsed + current.length;
494
+ const previousUsedTokens = previous.reduce((accTokens, usedToken, index) => {
495
+ let parsedToken = 0;
496
+ if (index % 2) {
497
+ parsedToken = Number.parseInt(usedToken);
498
+ }
499
+ return accTokens + parsedToken;
500
+ }, 0);
501
+ const currentUsedTokens = current.reduce((accTokens, usedToken, index) => {
502
+ let parsedToken = 0;
503
+ if (index % 2) {
504
+ parsedToken = Number.parseInt(usedToken);
505
+ }
506
+ return accTokens + parsedToken;
507
+ }, 0);
508
+ const previousPartialUsed = previousUsedTokens * (1 - percentageInCurrent);
509
+ const usedTokens = previousPartialUsed + currentUsedTokens;
434
510
  const remaining = tokens - usedTokens;
435
511
  async function sync() {
436
512
  const res = await Promise.all(dbs.map((s) => s.request));
437
- const allCurrentIds = res.flatMap(([current2]) => current2);
513
+ const allCurrentIds = res.flatMap(([current2]) => current2).reduce((accCurrentIds, curr, index) => {
514
+ if (index % 2 === 0) {
515
+ accCurrentIds.push(curr);
516
+ }
517
+ return accCurrentIds;
518
+ }, []);
438
519
  for (const db of dbs) {
439
- const [ids] = await db.request;
440
- if (ids.length >= tokens) {
520
+ const [_current, previous2, _success] = await db.request;
521
+ const dbIds = previous2.reduce((ids, currentId, index) => {
522
+ if (index % 2 === 0) {
523
+ ids.push(currentId);
524
+ }
525
+ return ids;
526
+ }, []);
527
+ const usedDbTokens = previous2.reduce((accTokens, usedToken, index) => {
528
+ let parsedToken = 0;
529
+ if (index % 2) {
530
+ parsedToken = Number.parseInt(usedToken);
531
+ }
532
+ return accTokens + parsedToken;
533
+ }, 0);
534
+ if (usedDbTokens >= tokens) {
441
535
  continue;
442
536
  }
443
- const diff = allCurrentIds.filter((id) => !ids.includes(id));
537
+ const diff = allCurrentIds.filter((id) => !dbIds.includes(id));
444
538
  if (diff.length === 0) {
445
539
  continue;
446
540
  }
447
- await db.redis.sadd(currentKey, ...diff);
541
+ for (const requestId2 of diff) {
542
+ await db.redis.hset(currentKey, { [requestId2]: incrementBy });
543
+ }
448
544
  }
449
545
  }
450
546
  const reset = (currentWindow + 1) * windowDuration;
@@ -454,7 +550,7 @@ var MultiRegionRatelimit = class extends Ratelimit {
454
550
  return {
455
551
  success: Boolean(success),
456
552
  limit: tokens,
457
- remaining,
553
+ remaining: Math.max(0, remaining),
458
554
  reset,
459
555
  pending: sync()
460
556
  };
@@ -462,6 +558,107 @@ var MultiRegionRatelimit = class extends Ratelimit {
462
558
  }
463
559
  };
464
560
 
561
+ // src/lua-scripts/single.ts
562
+ var fixedWindowScript2 = `
563
+ local key = KEYS[1]
564
+ local window = ARGV[1]
565
+ local incrementBy = ARGV[2] -- increment rate per request at a given value, default is 1
566
+
567
+ local r = redis.call("INCRBY", key, incrementBy)
568
+ if r == incrementBy then
569
+ -- The first time this key is set, the value will be equal to incrementBy.
570
+ -- So we only need the expire command once
571
+ redis.call("PEXPIRE", key, window)
572
+ end
573
+
574
+ return r
575
+ `;
576
+ var slidingWindowScript2 = `
577
+ local currentKey = KEYS[1] -- identifier including prefixes
578
+ local previousKey = KEYS[2] -- key of the previous bucket
579
+ local tokens = tonumber(ARGV[1]) -- tokens per window
580
+ local now = ARGV[2] -- current timestamp in milliseconds
581
+ local window = ARGV[3] -- interval in milliseconds
582
+ local incrementBy = ARGV[4] -- increment rate per request at a given value, default is 1
583
+
584
+ local requestsInCurrentWindow = redis.call("GET", currentKey)
585
+ if requestsInCurrentWindow == false then
586
+ requestsInCurrentWindow = 0
587
+ end
588
+
589
+ local requestsInPreviousWindow = redis.call("GET", previousKey)
590
+ if requestsInPreviousWindow == false then
591
+ requestsInPreviousWindow = 0
592
+ end
593
+ local percentageInCurrent = ( now % window ) / window
594
+ -- weighted requests to consider from the previous window
595
+ requestsInPreviousWindow = math.floor(( 1 - percentageInCurrent ) * requestsInPreviousWindow)
596
+ if requestsInPreviousWindow + requestsInCurrentWindow >= tokens then
597
+ return -1
598
+ end
599
+
600
+ local newValue = redis.call("INCRBY", currentKey, incrementBy)
601
+ if newValue == incrementBy then
602
+ -- The first time this key is set, the value will be equal to incrementBy.
603
+ -- So we only need the expire command once
604
+ redis.call("PEXPIRE", currentKey, window * 2 + 1000) -- Enough time to overlap with a new window + 1 second
605
+ end
606
+ return tokens - ( newValue + requestsInPreviousWindow )
607
+ `;
608
+ var tokenBucketScript = `
609
+ local key = KEYS[1] -- identifier including prefixes
610
+ local maxTokens = tonumber(ARGV[1]) -- maximum number of tokens
611
+ local interval = tonumber(ARGV[2]) -- size of the window in milliseconds
612
+ local refillRate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval
613
+ local now = tonumber(ARGV[4]) -- current timestamp in milliseconds
614
+ local incrementBy = tonumber(ARGV[5]) -- how many tokens to consume, default is 1
615
+
616
+ local bucket = redis.call("HMGET", key, "refilledAt", "tokens")
617
+
618
+ local refilledAt
619
+ local tokens
620
+
621
+ if bucket[1] == false then
622
+ refilledAt = now
623
+ tokens = maxTokens
624
+ else
625
+ refilledAt = tonumber(bucket[1])
626
+ tokens = tonumber(bucket[2])
627
+ end
628
+
629
+ if now >= refilledAt + interval then
630
+ local numRefills = math.floor((now - refilledAt) / interval)
631
+ tokens = math.min(maxTokens, tokens + numRefills * refillRate)
632
+
633
+ refilledAt = refilledAt + numRefills * interval
634
+ end
635
+
636
+ if tokens == 0 then
637
+ return {-1, refilledAt + interval}
638
+ end
639
+
640
+ local remaining = tokens - incrementBy
641
+ local expireAt = math.ceil(((maxTokens - remaining) / refillRate)) * interval
642
+
643
+ redis.call("HSET", key, "refilledAt", refilledAt, "tokens", remaining)
644
+ redis.call("PEXPIRE", key, expireAt)
645
+ return {remaining, refilledAt + interval}
646
+ `;
647
+ var cachedFixedWindowScript = `
648
+ local key = KEYS[1]
649
+ local window = ARGV[1]
650
+ local incrementBy = ARGV[2] -- increment rate per request at a given value, default is 1
651
+
652
+ local r = redis.call("INCRBY", key, incrementBy)
653
+ if r == incrementBy then
654
+ -- The first time this key is set, the value will be equal to incrementBy.
655
+ -- So we only need the expire command once
656
+ redis.call("PEXPIRE", key, window)
657
+ end
658
+
659
+ return r
660
+ `;
661
+
465
662
  // src/single.ts
466
663
  var RegionRatelimit = class extends Ratelimit {
467
664
  /**
@@ -499,20 +696,7 @@ var RegionRatelimit = class extends Ratelimit {
499
696
  */
500
697
  static fixedWindow(tokens, window) {
501
698
  const windowDuration = ms(window);
502
- const script = `
503
- local key = KEYS[1]
504
- local window = ARGV[1]
505
-
506
- local r = redis.call("INCR", key)
507
- if r == 1 then
508
- -- The first time this key is set, the value will be 1.
509
- -- So we only need the expire command once
510
- redis.call("PEXPIRE", key, window)
511
- end
512
-
513
- return r
514
- `;
515
- return async function(ctx, identifier) {
699
+ return async (ctx, identifier, rate) => {
516
700
  const bucket = Math.floor(Date.now() / windowDuration);
517
701
  const key = [identifier, bucket].join(":");
518
702
  if (ctx.cache) {
@@ -527,12 +711,14 @@ var RegionRatelimit = class extends Ratelimit {
527
711
  };
528
712
  }
529
713
  }
714
+ const incrementBy = rate ? Math.max(1, rate) : 1;
530
715
  const usedTokensAfterUpdate = await ctx.redis.eval(
531
- script,
716
+ fixedWindowScript2,
532
717
  [key],
533
- [windowDuration]
718
+ [windowDuration, incrementBy]
534
719
  );
535
720
  const success = usedTokensAfterUpdate <= tokens;
721
+ const remainingTokens = Math.max(0, tokens - usedTokensAfterUpdate);
536
722
  const reset = (bucket + 1) * windowDuration;
537
723
  if (ctx.cache && !success) {
538
724
  ctx.cache.blockUntil(identifier, reset);
@@ -540,7 +726,7 @@ var RegionRatelimit = class extends Ratelimit {
540
726
  return {
541
727
  success,
542
728
  limit: tokens,
543
- remaining: Math.max(0, tokens - usedTokensAfterUpdate),
729
+ remaining: remainingTokens,
544
730
  reset,
545
731
  pending: Promise.resolve()
546
732
  };
@@ -563,39 +749,8 @@ var RegionRatelimit = class extends Ratelimit {
563
749
  * @param window - The duration in which the user can max X requests.
564
750
  */
565
751
  static slidingWindow(tokens, window) {
566
- const script = `
567
- local currentKey = KEYS[1] -- identifier including prefixes
568
- local previousKey = KEYS[2] -- key of the previous bucket
569
- local tokens = tonumber(ARGV[1]) -- tokens per window
570
- local now = ARGV[2] -- current timestamp in milliseconds
571
- local window = ARGV[3] -- interval in milliseconds
572
-
573
- local requestsInCurrentWindow = redis.call("GET", currentKey)
574
- if requestsInCurrentWindow == false then
575
- requestsInCurrentWindow = 0
576
- end
577
-
578
- local requestsInPreviousWindow = redis.call("GET", previousKey)
579
- if requestsInPreviousWindow == false then
580
- requestsInPreviousWindow = 0
581
- end
582
- local percentageInCurrent = ( now % window ) / window
583
- -- weighted requests to consider from the previous window
584
- requestsInPreviousWindow = math.floor(( 1 - percentageInCurrent ) * requestsInPreviousWindow)
585
- if requestsInPreviousWindow + requestsInCurrentWindow >= tokens then
586
- return -1
587
- end
588
-
589
- local newValue = redis.call("INCR", currentKey)
590
- if newValue == 1 then
591
- -- The first time this key is set, the value will be 1.
592
- -- So we only need the expire command once
593
- redis.call("PEXPIRE", currentKey, window * 2 + 1000) -- Enough time to overlap with a new window + 1 second
594
- end
595
- return tokens - ( newValue + requestsInPreviousWindow )
596
- `;
597
752
  const windowSize = ms(window);
598
- return async function(ctx, identifier) {
753
+ return async (ctx, identifier, rate) => {
599
754
  const now = Date.now();
600
755
  const currentWindow = Math.floor(now / windowSize);
601
756
  const currentKey = [identifier, currentWindow].join(":");
@@ -613,12 +768,13 @@ var RegionRatelimit = class extends Ratelimit {
613
768
  };
614
769
  }
615
770
  }
616
- const remaining = await ctx.redis.eval(
617
- script,
771
+ const incrementBy = rate ? Math.max(1, rate) : 1;
772
+ const remainingTokens = await ctx.redis.eval(
773
+ slidingWindowScript2,
618
774
  [currentKey, previousKey],
619
- [tokens, now, windowSize]
775
+ [tokens, now, windowSize, incrementBy]
620
776
  );
621
- const success = remaining >= 0;
777
+ const success = remainingTokens >= 0;
622
778
  const reset = (currentWindow + 1) * windowSize;
623
779
  if (ctx.cache && !success) {
624
780
  ctx.cache.blockUntil(identifier, reset);
@@ -626,7 +782,7 @@ var RegionRatelimit = class extends Ratelimit {
626
782
  return {
627
783
  success,
628
784
  limit: tokens,
629
- remaining: Math.max(0, remaining),
785
+ remaining: Math.max(0, remainingTokens),
630
786
  reset,
631
787
  pending: Promise.resolve()
632
788
  };
@@ -646,46 +802,8 @@ var RegionRatelimit = class extends Ratelimit {
646
802
  * than `refillRate`
647
803
  */
648
804
  static tokenBucket(refillRate, interval, maxTokens) {
649
- const script = `
650
- local key = KEYS[1] -- identifier including prefixes
651
- local maxTokens = tonumber(ARGV[1]) -- maximum number of tokens
652
- local interval = tonumber(ARGV[2]) -- size of the window in milliseconds
653
- local refillRate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval
654
- local now = tonumber(ARGV[4]) -- current timestamp in milliseconds
655
-
656
- local bucket = redis.call("HMGET", key, "refilledAt", "tokens")
657
-
658
- local refilledAt
659
- local tokens
660
-
661
- if bucket[1] == false then
662
- refilledAt = now
663
- tokens = maxTokens
664
- else
665
- refilledAt = tonumber(bucket[1])
666
- tokens = tonumber(bucket[2])
667
- end
668
-
669
- if now >= refilledAt + interval then
670
- local numRefills = math.floor((now - refilledAt) / interval)
671
- tokens = math.min(maxTokens, tokens + numRefills * refillRate)
672
-
673
- refilledAt = refilledAt + numRefills * interval
674
- end
675
-
676
- if tokens == 0 then
677
- return {-1, refilledAt + interval}
678
- end
679
-
680
- local remaining = tokens - 1
681
- local expireAt = math.ceil(((maxTokens - remaining) / refillRate)) * interval
682
-
683
- redis.call("HSET", key, "refilledAt", refilledAt, "tokens", remaining)
684
- redis.call("PEXPIRE", key, expireAt)
685
- return {remaining, refilledAt + interval}
686
- `;
687
805
  const intervalDuration = ms(interval);
688
- return async function(ctx, identifier) {
806
+ return async (ctx, identifier, rate) => {
689
807
  if (ctx.cache) {
690
808
  const { blocked, reset: reset2 } = ctx.cache.isBlocked(identifier);
691
809
  if (blocked) {
@@ -699,10 +817,11 @@ var RegionRatelimit = class extends Ratelimit {
699
817
  }
700
818
  }
701
819
  const now = Date.now();
820
+ const incrementBy = rate ? Math.max(1, rate) : 1;
702
821
  const [remaining, reset] = await ctx.redis.eval(
703
- script,
822
+ tokenBucketScript,
704
823
  [identifier],
705
- [maxTokens, intervalDuration, refillRate, now]
824
+ [maxTokens, intervalDuration, refillRate, now, incrementBy]
706
825
  );
707
826
  const success = remaining >= 0;
708
827
  if (ctx.cache && !success) {
@@ -743,31 +862,19 @@ var RegionRatelimit = class extends Ratelimit {
743
862
  */
744
863
  static cachedFixedWindow(tokens, window) {
745
864
  const windowDuration = ms(window);
746
- const script = `
747
- local key = KEYS[1]
748
- local window = ARGV[1]
749
-
750
- local r = redis.call("INCR", key)
751
- if r == 1 then
752
- -- The first time this key is set, the value will be 1.
753
- -- So we only need the expire command once
754
- redis.call("PEXPIRE", key, window)
755
- end
756
-
757
- return r
758
- `;
759
- return async function(ctx, identifier) {
865
+ return async (ctx, identifier, rate) => {
760
866
  if (!ctx.cache) {
761
867
  throw new Error("This algorithm requires a cache");
762
868
  }
763
869
  const bucket = Math.floor(Date.now() / windowDuration);
764
870
  const key = [identifier, bucket].join(":");
765
871
  const reset = (bucket + 1) * windowDuration;
872
+ const incrementBy = rate ? Math.max(1, rate) : 1;
766
873
  const hit = typeof ctx.cache.get(key) === "number";
767
874
  if (hit) {
768
875
  const cachedTokensAfterUpdate = ctx.cache.incr(key);
769
876
  const success = cachedTokensAfterUpdate < tokens;
770
- const pending = success ? ctx.redis.eval(script, [key], [windowDuration]).then((t) => {
877
+ const pending = success ? ctx.redis.eval(cachedFixedWindowScript, [key], [windowDuration, incrementBy]).then((t) => {
771
878
  ctx.cache.set(key, t);
772
879
  }) : Promise.resolve();
773
880
  return {
@@ -779,9 +886,9 @@ var RegionRatelimit = class extends Ratelimit {
779
886
  };
780
887
  }
781
888
  const usedTokensAfterUpdate = await ctx.redis.eval(
782
- script,
889
+ cachedFixedWindowScript,
783
890
  [key],
784
- [windowDuration]
891
+ [windowDuration, incrementBy]
785
892
  );
786
893
  ctx.cache.set(key, usedTokensAfterUpdate);
787
894
  const remaining = tokens - usedTokensAfterUpdate;