node-llama-cpp 3.17.1 → 3.18.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. package/dist/bindings/AddonTypes.d.ts +13 -0
  2. package/dist/bindings/Llama.js +20 -2
  3. package/dist/bindings/Llama.js.map +1 -1
  4. package/dist/bindings/getLlama.d.ts +1 -1
  5. package/dist/bindings/getLlama.js +19 -8
  6. package/dist/bindings/getLlama.js.map +1 -1
  7. package/dist/bindings/utils/compileLLamaCpp.d.ts +2 -1
  8. package/dist/bindings/utils/compileLLamaCpp.js +8 -0
  9. package/dist/bindings/utils/compileLLamaCpp.js.map +1 -1
  10. package/dist/bindings/utils/getGpuTypesToUseForOption.d.ts +1 -1
  11. package/dist/bindings/utils/getLlamaGpuTypes.js +2 -0
  12. package/dist/bindings/utils/getLlamaGpuTypes.js.map +1 -1
  13. package/dist/chatWrappers/QwenChatWrapper.d.ts +7 -0
  14. package/dist/chatWrappers/QwenChatWrapper.js +176 -56
  15. package/dist/chatWrappers/QwenChatWrapper.js.map +1 -1
  16. package/dist/chatWrappers/generic/JinjaTemplateChatWrapper.js +127 -88
  17. package/dist/chatWrappers/generic/JinjaTemplateChatWrapper.js.map +1 -1
  18. package/dist/chatWrappers/generic/utils/extractFunctionCallSettingsFromJinjaTemplate.d.ts +16 -10
  19. package/dist/chatWrappers/generic/utils/extractFunctionCallSettingsFromJinjaTemplate.js +115 -5
  20. package/dist/chatWrappers/generic/utils/extractFunctionCallSettingsFromJinjaTemplate.js.map +1 -1
  21. package/dist/chatWrappers/generic/utils/extractSegmentSettingsFromTokenizerAndChatTemplate.js +1 -0
  22. package/dist/chatWrappers/generic/utils/extractSegmentSettingsFromTokenizerAndChatTemplate.js.map +1 -1
  23. package/dist/cli/commands/ChatCommand.d.ts +3 -0
  24. package/dist/cli/commands/ChatCommand.js +34 -5
  25. package/dist/cli/commands/ChatCommand.js.map +1 -1
  26. package/dist/cli/commands/CompleteCommand.d.ts +3 -0
  27. package/dist/cli/commands/CompleteCommand.js +34 -4
  28. package/dist/cli/commands/CompleteCommand.js.map +1 -1
  29. package/dist/cli/commands/InfillCommand.d.ts +3 -0
  30. package/dist/cli/commands/InfillCommand.js +34 -4
  31. package/dist/cli/commands/InfillCommand.js.map +1 -1
  32. package/dist/cli/commands/OnPostInstallCommand.js +31 -4
  33. package/dist/cli/commands/OnPostInstallCommand.js.map +1 -1
  34. package/dist/cli/commands/inspect/commands/InspectEstimateCommand.d.ts +3 -0
  35. package/dist/cli/commands/inspect/commands/InspectEstimateCommand.js +28 -1
  36. package/dist/cli/commands/inspect/commands/InspectEstimateCommand.js.map +1 -1
  37. package/dist/cli/commands/inspect/commands/InspectGgufCommand.js +5 -0
  38. package/dist/cli/commands/inspect/commands/InspectGgufCommand.js.map +1 -1
  39. package/dist/cli/commands/inspect/commands/InspectGpuCommand.js +51 -4
  40. package/dist/cli/commands/inspect/commands/InspectGpuCommand.js.map +1 -1
  41. package/dist/cli/commands/inspect/commands/InspectMeasureCommand.d.ts +3 -0
  42. package/dist/cli/commands/inspect/commands/InspectMeasureCommand.js +46 -5
  43. package/dist/cli/commands/inspect/commands/InspectMeasureCommand.js.map +1 -1
  44. package/dist/cli/utils/interactivelyAskForModel.d.ts +4 -1
  45. package/dist/cli/utils/interactivelyAskForModel.js +21 -7
  46. package/dist/cli/utils/interactivelyAskForModel.js.map +1 -1
  47. package/dist/cli/utils/packageJsonConfig.d.ts +6 -0
  48. package/dist/cli/utils/packageJsonConfig.js +51 -0
  49. package/dist/cli/utils/packageJsonConfig.js.map +1 -0
  50. package/dist/cli/utils/packageManager.d.ts +1 -0
  51. package/dist/cli/utils/packageManager.js +15 -0
  52. package/dist/cli/utils/packageManager.js.map +1 -0
  53. package/dist/cli/utils/printCommonInfoLines.js +9 -0
  54. package/dist/cli/utils/printCommonInfoLines.js.map +1 -1
  55. package/dist/cli/utils/resolveCommandGgufPath.d.ts +4 -1
  56. package/dist/cli/utils/resolveCommandGgufPath.js +9 -2
  57. package/dist/cli/utils/resolveCommandGgufPath.js.map +1 -1
  58. package/dist/cli/utils/resolveNpmrcConfig.d.ts +18 -0
  59. package/dist/cli/utils/resolveNpmrcConfig.js +129 -0
  60. package/dist/cli/utils/resolveNpmrcConfig.js.map +1 -0
  61. package/dist/config.d.ts +6 -1
  62. package/dist/config.js +12 -2
  63. package/dist/config.js.map +1 -1
  64. package/dist/evaluator/LlamaChat/LlamaChat.d.ts +8 -2
  65. package/dist/evaluator/LlamaChat/LlamaChat.js +99 -6
  66. package/dist/evaluator/LlamaChat/LlamaChat.js.map +1 -1
  67. package/dist/evaluator/LlamaChat/utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js +8 -2
  68. package/dist/evaluator/LlamaChat/utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js.map +1 -1
  69. package/dist/evaluator/LlamaChatSession/LlamaChatSession.d.ts +8 -2
  70. package/dist/evaluator/LlamaChatSession/LlamaChatSession.js.map +1 -1
  71. package/dist/evaluator/LlamaContext/LlamaContext.d.ts +91 -0
  72. package/dist/evaluator/LlamaContext/LlamaContext.js +215 -19
  73. package/dist/evaluator/LlamaContext/LlamaContext.js.map +1 -1
  74. package/dist/evaluator/LlamaContext/LlamaContextSequenceCheckpoints.d.ts +27 -0
  75. package/dist/evaluator/LlamaContext/LlamaContextSequenceCheckpoints.js +130 -0
  76. package/dist/evaluator/LlamaContext/LlamaContextSequenceCheckpoints.js.map +1 -0
  77. package/dist/evaluator/LlamaContext/types.d.ts +32 -1
  78. package/dist/evaluator/LlamaModel/LlamaModel.d.ts +33 -0
  79. package/dist/evaluator/LlamaModel/LlamaModel.js +24 -3
  80. package/dist/evaluator/LlamaModel/LlamaModel.js.map +1 -1
  81. package/dist/gguf/insights/GgufInsights.d.ts +12 -1
  82. package/dist/gguf/insights/GgufInsights.js +246 -49
  83. package/dist/gguf/insights/GgufInsights.js.map +1 -1
  84. package/dist/gguf/insights/GgufInsightsConfigurationResolver.d.ts +13 -4
  85. package/dist/gguf/insights/GgufInsightsConfigurationResolver.js +17 -5
  86. package/dist/gguf/insights/GgufInsightsConfigurationResolver.js.map +1 -1
  87. package/dist/gguf/insights/utils/resolveContextContextSizeOption.d.ts +4 -1
  88. package/dist/gguf/insights/utils/resolveContextContextSizeOption.js +7 -1
  89. package/dist/gguf/insights/utils/resolveContextContextSizeOption.js.map +1 -1
  90. package/dist/gguf/insights/utils/resolveModelGpuLayersOption.d.ts +4 -1
  91. package/dist/gguf/insights/utils/resolveModelGpuLayersOption.js +16 -4
  92. package/dist/gguf/insights/utils/resolveModelGpuLayersOption.js.map +1 -1
  93. package/dist/gguf/types/GgufMetadataTypes.d.ts +18 -2
  94. package/dist/gguf/types/GgufMetadataTypes.js +6 -0
  95. package/dist/gguf/types/GgufMetadataTypes.js.map +1 -1
  96. package/dist/gguf/types/GgufTensorInfoTypes.d.ts +4 -2
  97. package/dist/gguf/types/GgufTensorInfoTypes.js +11 -1
  98. package/dist/gguf/types/GgufTensorInfoTypes.js.map +1 -1
  99. package/dist/gguf/utils/getGgufFileTypeName.d.ts +1 -1
  100. package/dist/gguf/utils/ggufQuantNames.js +1 -0
  101. package/dist/gguf/utils/ggufQuantNames.js.map +1 -1
  102. package/dist/tsconfig.tsbuildinfo +1 -1
  103. package/dist/types.d.ts +1 -0
  104. package/dist/utils/getBuildDefaults.d.ts +1 -1
  105. package/dist/utils/getFirstWritableDir.d.ts +8 -0
  106. package/dist/utils/getFirstWritableDir.js +60 -0
  107. package/dist/utils/getFirstWritableDir.js.map +1 -0
  108. package/dist/utils/getTempDir.d.ts +10 -0
  109. package/dist/utils/getTempDir.js +121 -0
  110. package/dist/utils/getTempDir.js.map +1 -0
  111. package/dist/utils/prettyPrintObject.js +1 -1
  112. package/dist/utils/prettyPrintObject.js.map +1 -1
  113. package/dist/utils/resolveModelFile.js +19 -8
  114. package/dist/utils/resolveModelFile.js.map +1 -1
  115. package/llama/addon/AddonContext.cpp +182 -0
  116. package/llama/addon/AddonContext.h +27 -0
  117. package/llama/addon/addon.cpp +1 -0
  118. package/llama/binariesGithubRelease.json +1 -1
  119. package/llama/gitRelease.bundle +0 -0
  120. package/llama/llama.cpp.info.json +1 -1
  121. package/package.json +24 -24
@@ -8,9 +8,12 @@ import { UnsupportedError } from "../../utils/UnsupportedError.js";
8
8
  import { pushAll } from "../../utils/pushAll.js";
9
9
  import { safeEventCallback } from "../../utils/safeEventCallback.js";
10
10
  import { GgufArchitectureType } from "../../gguf/types/GgufMetadataTypes.js";
11
+ import { LlamaLogLevel } from "../../bindings/types.js";
12
+ import { resolveGgmlTypeOption } from "../../gguf/types/GgufTensorInfoTypes.js";
11
13
  import { resolveBatchItemsPrioritizationStrategy } from "./utils/resolveBatchItemsPrioritizationStrategy.js";
12
14
  import { LlamaSampler } from "./LlamaSampler.js";
13
15
  import { padSafeContextSize } from "./utils/padSafeContextSize.js";
16
+ import { LlamaContextSequenceCheckpoints } from "./LlamaContextSequenceCheckpoints.js";
14
17
  const defaultLoraScale = 1;
15
18
  const shrinkRetriesMinContextSize = 4096;
16
19
  const defaultMaxPunishTokens = 64;
@@ -20,6 +23,25 @@ const defaultFailedCreationRemedy = {
20
23
  };
21
24
  const defaultEvaluationPriority = 5;
22
25
  const defaultDryRepeatPenalitySequenceBreakers = ["\n", ":", '"', "*"];
26
+ const defaultCheckpointOptions = {
27
+ max: 32,
28
+ interval: 8192,
29
+ maxMemory: null
30
+ };
31
+ export const internalCheckpoints = {
32
+ speculative: {
33
+ name: "speculative",
34
+ maxCheckpoints: 2
35
+ },
36
+ chatSequenceStart: {
37
+ name: "sequenceStart",
38
+ maxCheckpoints: 1
39
+ },
40
+ chatGrammarEnd: {
41
+ name: "grammarEnd",
42
+ maxCheckpoints: 1
43
+ }
44
+ };
23
45
  const decodeSyncWorkaround = {
24
46
  vulkanLock: {}
25
47
  };
@@ -35,6 +57,8 @@ export class LlamaContext {
35
57
  /** @internal */ _idealThreads;
36
58
  /** @internal */ _minThreads;
37
59
  /** @internal */ _performanceTracking;
60
+ /** @internal */ _kvCacheKeyType;
61
+ /** @internal */ _kvCacheValueType;
38
62
  /** @internal */ _totalSequences;
39
63
  /** @internal */ _unusedSequenceIds = [];
40
64
  /** @internal */ _batchingOptions;
@@ -53,7 +77,7 @@ export class LlamaContext {
53
77
  /** @internal */ _allocatedContextSize;
54
78
  /** @internal */ _disposed = false;
55
79
  onDispose = new EventRelay();
56
- constructor({ _model }, { sequences, contextSize, batchSize, flashAttention = _model.defaultContextFlashAttention, threads, batching: { dispatchSchedule: batchingDispatchSchedule = "nextCycle", itemPrioritizationStrategy: batchingItemsPrioritizationStrategy = "maximumParallelism" } = {}, swaFullCache = _model.defaultContextSwaFullCache, performanceTracking = false, _embeddings, _ranking }) {
80
+ constructor({ _model }, { sequences, contextSize, batchSize, flashAttention = _model.defaultContextFlashAttention, threads, batching: { dispatchSchedule: batchingDispatchSchedule = "nextCycle", itemPrioritizationStrategy: batchingItemsPrioritizationStrategy = "maximumParallelism" } = {}, swaFullCache = _model.defaultContextSwaFullCache, performanceTracking = false, experimentalKvCacheKeyType, experimentalKvCacheValueType, _embeddings, _ranking }) {
57
81
  if (_model.disposed)
58
82
  throw new DisposedError();
59
83
  this._llama = _model._llama;
@@ -73,6 +97,8 @@ export class LlamaContext {
73
97
  ? 1
74
98
  : this._llama._threadsSplitter.normalizeThreadsValue(threads?.min ?? 1));
75
99
  this._performanceTracking = !!performanceTracking;
100
+ this._kvCacheKeyType = experimentalKvCacheKeyType;
101
+ this._kvCacheValueType = experimentalKvCacheValueType;
76
102
  this._swaFullCache = !!swaFullCache;
77
103
  this._ctx = new this._llama._bindings.AddonContext(this._model._model, removeNullFields({
78
104
  contextSize: padSafeContextSize(this._contextSize * this._totalSequences, "up"), // each sequence needs its own <contextSize> of cells
@@ -85,6 +111,8 @@ export class LlamaContext {
85
111
  embeddings: _embeddings,
86
112
  ranking: _ranking,
87
113
  performanceTracking: this._performanceTracking,
114
+ kvCacheKeyType: this._kvCacheKeyType,
115
+ kvCacheValueType: this._kvCacheValueType,
88
116
  swaFullCache: this._swaFullCache
89
117
  }));
90
118
  this._batchingOptions = {
@@ -130,6 +158,12 @@ export class LlamaContext {
130
158
  get flashAttention() {
131
159
  return this._flashAttention;
132
160
  }
161
+ get kvCacheKeyType() {
162
+ return this._kvCacheKeyType;
163
+ }
164
+ get kvCacheValueType() {
165
+ return this._kvCacheValueType;
166
+ }
133
167
  /**
134
168
  * The actual size of the state in the memory in bytes.
135
169
  * This value is provided by `llama.cpp` and doesn't include all the memory overhead of the context.
@@ -168,7 +202,7 @@ export class LlamaContext {
168
202
  * When there are no sequences left, this method will throw an error.
169
203
  */
170
204
  getSequence(options = {}) {
171
- const { contextShift: { size: contextShiftSize = Math.min(100, Math.ceil(this.contextSize / 2)), strategy: contextShiftStrategy = "eraseBeginning" } = {}, tokenPredictor, _tokenMeter } = options;
205
+ const { contextShift: { size: contextShiftSize = Math.min(100, Math.ceil(this.contextSize / 2)), strategy: contextShiftStrategy = "eraseBeginning" } = {}, tokenPredictor, checkpoints, _tokenMeter } = options;
172
206
  this._ensureNotDisposed();
173
207
  const nextSequenceId = this._popSequenceId();
174
208
  if (nextSequenceId == null)
@@ -181,7 +215,8 @@ export class LlamaContext {
181
215
  size: contextShiftSize,
182
216
  strategy: contextShiftStrategy
183
217
  },
184
- tokenPredictor
218
+ tokenPredictor,
219
+ checkpoints
185
220
  });
186
221
  }
187
222
  dispatchPendingBatch() {
@@ -293,16 +328,18 @@ export class LlamaContext {
293
328
  batchLogitIndexes,
294
329
  batchLogitTokenIndexes: tokenIndexesWithLogitsToProcess,
295
330
  firstTokenIndex: queuedDecode.firstTokenSequenceIndex,
331
+ sequenceStateLength: queuedDecode.firstTokenSequenceIndex + processAmount + 1,
296
332
  returnResults: true
297
333
  });
298
334
  }
299
335
  else {
300
- if (batchLogitIndexes.length > 0)
336
+ if (batchLogitIndexes.length > 0 || queuedDecode.afterBatchAction != null)
301
337
  afterDecodeActions.push({
302
338
  queuedDecode,
303
339
  batchLogitIndexes,
304
340
  batchLogitTokenIndexes: tokenIndexesWithLogitsToProcess,
305
- firstTokenIndex: queuedDecode.firstTokenSequenceIndex
341
+ firstTokenIndex: queuedDecode.firstTokenSequenceIndex,
342
+ sequenceStateLength: queuedDecode.firstTokenSequenceIndex + processAmount + 1
306
343
  });
307
344
  queuedDecode.tokens = queuedDecode.tokens.slice(processAmount);
308
345
  queuedDecode.logits = queuedDecode.logits.slice(processAmount);
@@ -378,6 +415,11 @@ export class LlamaContext {
378
415
  return undefined;
379
416
  });
380
417
  await Promise.all(afterDecodeActionResults);
418
+ for (const action of afterDecodeActions) {
419
+ const resPromise = action.queuedDecode.afterBatchAction?.(action.sequenceStateLength);
420
+ if (resPromise instanceof Promise)
421
+ await resPromise;
422
+ }
381
423
  };
382
424
  const prioritizationStrategy = resolvePrioritizationStrategy();
383
425
  if (prioritizationStrategy == null)
@@ -432,7 +474,7 @@ export class LlamaContext {
432
474
  await new Promise((accept) => setTimeout(accept, 0)); // wait for the logs to finish printing
433
475
  }
434
476
  /** @internal */
435
- async _decodeTokens({ sequenceId, firstTokenSequenceIndex, tokens, logits, evaluationPriority = defaultEvaluationPriority, tokenMeter }, logitDataMapper) {
477
+ async _decodeTokens({ sequenceId, firstTokenSequenceIndex, tokens, logits, evaluationPriority = defaultEvaluationPriority, tokenMeter, afterBatchAction }, logitDataMapper) {
436
478
  return await new Promise((accept, reject) => {
437
479
  this._queuedDecodes.push({
438
480
  sequenceId,
@@ -442,7 +484,8 @@ export class LlamaContext {
442
484
  evaluationPriority,
443
485
  tokenMeter,
444
486
  response: [accept, reject],
445
- logitDataMapper
487
+ logitDataMapper,
488
+ afterBatchAction
446
489
  });
447
490
  this._queuedDecodeSequenceIds.add(sequenceId);
448
491
  this._scheduleDecode();
@@ -568,6 +611,12 @@ export class LlamaContext {
568
611
  const flashAttention = _model.flashAttentionSupported
569
612
  ? Boolean(options.flashAttention ?? _model.defaultContextFlashAttention)
570
613
  : false;
614
+ const kvCacheKeyType = options.experimentalKvCacheKeyType === "currentQuant"
615
+ ? _model.fileInsights.dominantTensorType ?? _model.defaultContextKvCacheKeyType
616
+ : resolveGgmlTypeOption(options.experimentalKvCacheKeyType) ?? _model.defaultContextKvCacheKeyType;
617
+ const kvCacheValueType = options.experimentalKvCacheValueType === "currentQuant"
618
+ ? _model.fileInsights.dominantTensorType ?? _model.defaultContextKvCacheValueType
619
+ : resolveGgmlTypeOption(options.experimentalKvCacheValueType) ?? _model.defaultContextKvCacheValueType;
571
620
  const swaFullCache = options.swaFullCache ?? _model.defaultContextSwaFullCache;
572
621
  const loraOptions = typeof options.lora === "string"
573
622
  ? { adapters: [{ filePath: options.lora }] }
@@ -584,6 +633,8 @@ export class LlamaContext {
584
633
  modelGpuLayers: _model.gpuLayers,
585
634
  modelTrainContextSize: _model.trainContextSize,
586
635
  flashAttention,
636
+ kvCacheKeyType,
637
+ kvCacheValueType,
587
638
  swaFullCache,
588
639
  getVramState: () => _model._llama._vramOrchestrator.getMemoryState(),
589
640
  llamaGpu: _model._llama.gpu,
@@ -612,9 +663,20 @@ export class LlamaContext {
612
663
  modelGpuLayers: _model.gpuLayers,
613
664
  batchSize,
614
665
  flashAttention,
666
+ kvCacheKeyType,
667
+ kvCacheValueType,
668
+ swaFullCache
669
+ });
670
+ const context = new LlamaContext({ _model }, {
671
+ ...options,
672
+ contextSize,
673
+ batchSize,
674
+ sequences,
675
+ flashAttention,
676
+ experimentalKvCacheKeyType: kvCacheKeyType,
677
+ experimentalKvCacheValueType: kvCacheValueType,
615
678
  swaFullCache
616
679
  });
617
- const context = new LlamaContext({ _model }, { ...options, contextSize, batchSize, sequences, flashAttention, swaFullCache });
618
680
  const contextCreationVramReservation = options.ignoreMemorySafetyChecks
619
681
  ? null
620
682
  : _model._llama._vramOrchestrator.reserveMemory(resourceRequirementsEstimation.gpuVram);
@@ -697,6 +759,8 @@ export class LlamaContextSequence {
697
759
  /** @internal */ _context;
698
760
  /** @internal */ _contextShift;
699
761
  /** @internal */ _tokenPredictor;
762
+ /** @internal */ _checkpoints = new LlamaContextSequenceCheckpoints();
763
+ /** @internal */ _checkpointOptions;
700
764
  /** @internal */ _tokenMeter;
701
765
  /** @internal */ _disposeAggregator = new DisposeAggregator();
702
766
  /** @internal */ _lock = {};
@@ -711,22 +775,29 @@ export class LlamaContextSequence {
711
775
  /** @internal */ _refutedTokenPredictions = 0;
712
776
  /** @internal */ _disposed = false;
713
777
  onDispose = new EventRelay();
714
- constructor({ sequenceId, context, tokenMeter, contextShift, tokenPredictor }) {
778
+ constructor({ sequenceId, context, tokenMeter, contextShift, tokenPredictor, checkpoints }) {
715
779
  this._sequenceId = sequenceId;
716
780
  this._context = context;
717
781
  this._tokenMeter = tokenMeter ?? new TokenMeter();
718
782
  this._contextShift = contextShift;
719
783
  this._tokenPredictor = tokenPredictor;
784
+ this._checkpointOptions = {
785
+ max: checkpoints?.max ?? defaultCheckpointOptions.max,
786
+ interval: checkpoints?.interval ?? defaultCheckpointOptions.interval,
787
+ maxMemory: checkpoints?.maxMemory ?? defaultCheckpointOptions.maxMemory
788
+ };
720
789
  this._gcRegistry = new FinalizationRegistry(this._context._reclaimUnusedSequenceId);
721
790
  this._gcRegistry.register(this, sequenceId);
722
791
  this._disposeAggregator.add(() => this._gcRegistry.unregister(this));
723
792
  this._disposeAggregator.add(this.onDispose.dispatchEvent);
724
793
  this._disposeAggregator.add(this.model.onDispose.createListener(disposeContextSequenceIfReferenced.bind(null, new WeakRef(this))));
725
794
  this._disposeAggregator.add(() => {
795
+ this._checkpoints.clearAllCheckpoints();
726
796
  this._context._reclaimUnusedSequenceId(this._sequenceId);
727
797
  });
728
798
  if (this._tokenPredictor != null)
729
799
  this._disposeAggregator.add(this._tokenPredictor);
800
+ this._takeIntervalCheckpointIfNeededAfterBatch = this._takeIntervalCheckpointIfNeededAfterBatch.bind(this);
730
801
  }
731
802
  dispose() {
732
803
  if (this._disposed)
@@ -892,7 +963,7 @@ export class LlamaContextSequence {
892
963
  /** @internal */
893
964
  async _eraseContextTokenRanges(ranges, { canResetTokenPredictor = true, canRemovePredictionTokens = true, skipLock = false } = {}) {
894
965
  this._ensureNotDisposed();
895
- let awaitPromise;
966
+ let awaitEvaluationPromise;
896
967
  await withLock([this._context, "context"], async () => {
897
968
  this._ensureNotDisposed();
898
969
  if (ranges.length === 0)
@@ -968,16 +1039,39 @@ export class LlamaContextSequence {
968
1039
  this._nextTokenIndex -= removedTokens;
969
1040
  if (canResetTokenPredictor && removedTokens > 0)
970
1041
  await this._abortTokenPredictor(true);
1042
+ this._checkpoints.pruneFromEndToIndex(this._contextTokens.length - 1);
971
1043
  if (deletionSuccessful)
972
1044
  return;
1045
+ let restoreCheckpointIndex = this._contextTokens.length - 1;
1046
+ const existingCheckpoint = this._checkpoints.getLastCheckpoint(restoreCheckpointIndex, this.contextSize);
1047
+ if (existingCheckpoint != null &&
1048
+ restoreCheckpointIndex >= existingCheckpoint.minPos &&
1049
+ existingCheckpoint.maxPos <= this.contextSize) {
1050
+ restoreCheckpointIndex = Math.min(restoreCheckpointIndex, existingCheckpoint.maxPos);
1051
+ const restoredSuccessfully = await this._context._ctx.restoreCheckpoint(existingCheckpoint, restoreCheckpointIndex);
1052
+ if (restoredSuccessfully) {
1053
+ const tokensToEvaluate = this._contextTokens.slice(restoreCheckpointIndex + 1);
1054
+ this._contextTokens = this._contextTokens.slice(0, restoreCheckpointIndex + 1);
1055
+ this._nextTokenIndex = restoreCheckpointIndex + 1;
1056
+ // wait for the evaluation outside the "context" lock to avoid deadlocks
1057
+ if (tokensToEvaluate.length > 0)
1058
+ awaitEvaluationPromise = this.evaluateWithoutGeneratingNewTokens(tokensToEvaluate, { _skipLock: skipLock });
1059
+ return;
1060
+ }
1061
+ }
973
1062
  const newSequenceTokens = this._contextTokens.slice();
974
1063
  this._nextTokenIndex = 0;
975
1064
  this._context._ctx.disposeSequence(this._sequenceId);
1065
+ this._contextTokens = [];
976
1066
  // wait for the evaluation outside the "context" lock to avoid deadlocks
977
- awaitPromise = this.evaluateWithoutGeneratingNewTokens(newSequenceTokens, { _skipLock: skipLock });
1067
+ if (newSequenceTokens.length > 0)
1068
+ awaitEvaluationPromise = this.evaluateWithoutGeneratingNewTokens(newSequenceTokens, { _skipLock: skipLock });
978
1069
  });
979
- if (awaitPromise != null)
980
- await awaitPromise;
1070
+ if (awaitEvaluationPromise != null) {
1071
+ await awaitEvaluationPromise;
1072
+ if (this.needsCheckpoints && this._checkpoints.lastCheckpointIndex !== this._nextTokenIndex - 1)
1073
+ await this.takeCheckpoint();
1074
+ }
981
1075
  }
982
1076
  /**
983
1077
  * Evaluate the provided tokens into the context sequence, and continue generating new tokens on iterator iterations.
@@ -1168,7 +1262,7 @@ export class LlamaContextSequence {
1168
1262
  onTokenResult?.(tokenIndex, output);
1169
1263
  return output;
1170
1264
  });
1171
- });
1265
+ }, this._takeIntervalCheckpointIfNeededAfterBatch);
1172
1266
  }
1173
1267
  finally {
1174
1268
  evaluatorLock.dispose();
@@ -1188,6 +1282,7 @@ export class LlamaContextSequence {
1188
1282
  const contextLock = await acquireLock([this._context, "context"]);
1189
1283
  try {
1190
1284
  this._ensureNotDisposed();
1285
+ // TODO: save checkpoints to disk
1191
1286
  const fileSize = await this._context._ctx.saveSequenceStateToFile(resolvedPath, this._sequenceId, Uint32Array.from(this.contextTokens));
1192
1287
  return { fileSize };
1193
1288
  }
@@ -1235,6 +1330,97 @@ export class LlamaContextSequence {
1235
1330
  evaluatorLock.dispose();
1236
1331
  }
1237
1332
  }
1333
+ /**
1334
+ * When reusing a prefix evaluation state is not possible for the current context sequence
1335
+ * (like in contexts from recurrent and hybrid models,
1336
+ * or with models that use SWA (Sliding Window Attention) when the `swaFullCache` option is not enabled on the context),
1337
+ * you can use this method to checkpoint the current context sequence state.
1338
+ * Those checkpoints will automatically be used when trying to erase parts of the context state that come after a checkpointed state,
1339
+ * and be freed from memory when no longer relevant.
1340
+ *
1341
+ * Those checkpoints are relatively lightweight compared to saving the entire state,
1342
+ * but taking too many checkpoints can increase memory usage.
1343
+ * Checkpoints are stored in the RAM (not VRAM).
1344
+ *
1345
+ * Calling this method on a context sequence from a model that natively supports prefix evaluation state reuse will have no effect.
1346
+ *
1347
+ * > **Note:** to check whether the current context sequence needs taking checkpoints,
1348
+ * > you can use the {@link needsCheckpoints `.needsCheckpoints`} property.
1349
+ */
1350
+ async takeCheckpoint() {
1351
+ if (!this.needsCheckpoints)
1352
+ return;
1353
+ return await withLock([this._context, "context"], () => {
1354
+ return this._takeCheckpoint(undefined, this._checkpointOptions.max);
1355
+ });
1356
+ }
1357
+ /** @internal */
1358
+ async _takeNamedCheckpoint(name, maxNamedCheckpoints) {
1359
+ if (!this.needsCheckpoints)
1360
+ return;
1361
+ return await withLock([this._context, "context"], () => {
1362
+ return this._takeCheckpoint(name, maxNamedCheckpoints);
1363
+ });
1364
+ }
1365
+ /**
1366
+ * Whether the current context sequence needs taking checkpoints of the context state to be able to reuse
1367
+ * it as a prefix evaluation state in the future.
1368
+ *
1369
+ * See {@link takeCheckpoint `.takeCheckpoint()`} for more details.
1370
+ */
1371
+ get needsCheckpoints() {
1372
+ if (this.model.fileInsights.isHybrid || this.model.fileInsights.isRecurrent)
1373
+ return true;
1374
+ else if (this.model.fileInsights.swaSize != null && !this._context._swaFullCache)
1375
+ return true;
1376
+ return false;
1377
+ }
1378
+ /**
1379
+ * The index of the last taken checkpoint that's available for prefix reuse
1380
+ */
1381
+ get lastCheckpointIndex() {
1382
+ return Math.max(0, Math.min(this._checkpoints.lastCheckpointIndex, this.nextTokenIndex - 1));
1383
+ }
1384
+ /**
1385
+ * The total memory usage in bytes of all the checkpoints currently held for this context sequence
1386
+ */
1387
+ get checkpointsMemoryUsage() {
1388
+ return this._checkpoints.memoryUsage;
1389
+ }
1390
+ /** @internal */
1391
+ async _takeCheckpoint(name, maxNamedCheckpoints) {
1392
+ if (!this.needsCheckpoints || this._nextTokenIndex === 0 || this._checkpoints.hasCheckpoint(name, this._nextTokenIndex - 1))
1393
+ return;
1394
+ if (this._checkpointOptions.maxMemory != null)
1395
+ this._checkpoints.prepareMemoryForIncomingCheckpoint(this._checkpointOptions.maxMemory);
1396
+ const checkpoint = new this.model._llama._bindings.AddonContextSequenceCheckpoint();
1397
+ await checkpoint.init(this._context._ctx, this._sequenceId);
1398
+ if (this._nextTokenIndex - 1 !== checkpoint.maxPos)
1399
+ this.model._llama._log(LlamaLogLevel.warn, `Checkpoint max position mismatch: expected ${this._nextTokenIndex - 1}, got ${checkpoint.maxPos}`);
1400
+ this._checkpoints.storeCheckpoint({
1401
+ name,
1402
+ maxNamedCheckpoints,
1403
+ checkpoint,
1404
+ currentMaxPos: checkpoint.maxPos
1405
+ });
1406
+ if (this._checkpointOptions.maxMemory != null)
1407
+ this._checkpoints.pruneToKeepUnderMemoryUsage(this._checkpointOptions.maxMemory);
1408
+ }
1409
+ /** @internal */
1410
+ _takeIntervalCheckpointIfNeeded(currentIndex = this._nextTokenIndex - 1) {
1411
+ if (!this.needsCheckpoints)
1412
+ return;
1413
+ const lastCheckpointIndex = this._checkpoints.getLastNamedCheckpointIndex(undefined);
1414
+ if (this._checkpointOptions.interval === false || currentIndex - lastCheckpointIndex < this._checkpointOptions.interval)
1415
+ return;
1416
+ return this._takeCheckpoint(undefined, this._checkpointOptions.max);
1417
+ }
1418
+ /** @internal */
1419
+ _takeIntervalCheckpointIfNeededAfterBatch(sequenceStateLength) {
1420
+ if (sequenceStateLength === 0)
1421
+ return;
1422
+ return this._takeIntervalCheckpointIfNeeded(sequenceStateLength - 1);
1423
+ }
1238
1424
  /** @internal */
1239
1425
  async *_evaluate(tokens, metadata, { temperature, minP, topK, topP, seed, xtc, grammarEvaluationState, repeatPenalty, dryRepeatPenalty, tokenBias, evaluationPriority = defaultEvaluationPriority, generateNewTokens = true, contextShiftOptions, yieldEogToken = false, _noSampling = false, _skipLock = false }) {
1240
1426
  this._ensureNotDisposed();
@@ -1282,7 +1468,7 @@ export class LlamaContextSequence {
1282
1468
  else
1283
1469
  return this._context._ctx.sampleToken(batchLogitIndex, sampler._sampler);
1284
1470
  });
1285
- });
1471
+ }, this._takeIntervalCheckpointIfNeededAfterBatch);
1286
1472
  const lastDecodeResult = decodeResult[evalTokens.length - 1];
1287
1473
  if (lastDecodeResult instanceof Array) {
1288
1474
  const [token, probabilities, confidence] = lastDecodeResult;
@@ -1366,6 +1552,14 @@ export class LlamaContextSequence {
1366
1552
  const deleteStartIndex = Math.max(0, this._nextTokenIndex - this._loadedTokenPredictions.length);
1367
1553
  await this._eraseContextTokenRanges([{ start: deleteStartIndex, end: this._nextTokenIndex }], { canResetTokenPredictor: true, canRemovePredictionTokens: true, skipLock: true });
1368
1554
  this._loadedTokenPredictions.length = 0;
1555
+ if (this.needsCheckpoints) {
1556
+ await this._takeCheckpoint(internalCheckpoints.speculative.name, internalCheckpoints.speculative.maxCheckpoints);
1557
+ await this._takeIntervalCheckpointIfNeeded();
1558
+ }
1559
+ }
1560
+ else if (this._tokenPredictorOwner === tokenPredictorOwner && this.needsCheckpoints) {
1561
+ await this._takeCheckpoint(internalCheckpoints.speculative.name, internalCheckpoints.speculative.maxCheckpoints);
1562
+ await this._takeIntervalCheckpointIfNeeded();
1369
1563
  }
1370
1564
  if (this._resetTokenPredictor) {
1371
1565
  await tokenPredictor.reset({
@@ -1578,7 +1772,7 @@ export class LlamaContextSequence {
1578
1772
  * The caller of this function has to wrap it with a lock to ensure this function doesn't run concurrently.
1579
1773
  * @internal
1580
1774
  */
1581
- async _decodeTokens(tokens, logits, evaluationPriority, tokenMeter, contextShiftOptions, logitDataMapper) {
1775
+ async _decodeTokens(tokens, logits, evaluationPriority, tokenMeter, contextShiftOptions, logitDataMapper, afterBatchAction) {
1582
1776
  this._ensureNotDisposed();
1583
1777
  const tokensLeftToDecode = tokens.slice();
1584
1778
  const tokenLogitsLeftToDecode = logits.slice();
@@ -1604,7 +1798,8 @@ export class LlamaContextSequence {
1604
1798
  firstTokenSequenceIndex: this._nextTokenIndex,
1605
1799
  logits: tokensLogits,
1606
1800
  evaluationPriority,
1607
- tokenMeter
1801
+ tokenMeter,
1802
+ afterBatchAction
1608
1803
  }, normalizedLogitDataMapper);
1609
1804
  for (const [index, value] of generatedLogits)
1610
1805
  res[currentTokenIndex + (index - this._nextTokenIndex)] = value;
@@ -1648,7 +1843,7 @@ export class LlamaContextSequence {
1648
1843
  * We need this to make it impossible to manually create instances of this class outside the code of this library
1649
1844
  * @internal
1650
1845
  */
1651
- static _create({ sequenceId, context, tokenMeter, contextShift: { size: contextShiftSize = Math.min(100, Math.ceil(context.contextSize / 2)), strategy: contextShiftStrategy = "eraseBeginning" } = {}, tokenPredictor }) {
1846
+ static _create({ sequenceId, context, tokenMeter, contextShift: { size: contextShiftSize = Math.min(100, Math.ceil(context.contextSize / 2)), strategy: contextShiftStrategy = "eraseBeginning" } = {}, tokenPredictor, checkpoints }) {
1652
1847
  return new LlamaContextSequence({
1653
1848
  sequenceId,
1654
1849
  context,
@@ -1657,7 +1852,8 @@ export class LlamaContextSequence {
1657
1852
  size: contextShiftSize,
1658
1853
  strategy: contextShiftStrategy
1659
1854
  },
1660
- tokenPredictor
1855
+ tokenPredictor,
1856
+ checkpoints
1661
1857
  });
1662
1858
  }
1663
1859
  }