@seanhogg/builderforce-memory 2026.6.18
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +582 -0
- package/dist/agent/SSMAgent.d.ts +146 -0
- package/dist/agent/SSMAgent.d.ts.map +1 -0
- package/dist/agent/SSMAgent.js +231 -0
- package/dist/agent/SSMAgent.js.map +1 -0
- package/dist/agent/index.d.ts +3 -0
- package/dist/agent/index.d.ts.map +1 -0
- package/dist/agent/index.js +2 -0
- package/dist/agent/index.js.map +1 -0
- package/dist/bridges/AnthropicBridge.d.ts +47 -0
- package/dist/bridges/AnthropicBridge.d.ts.map +1 -0
- package/dist/bridges/AnthropicBridge.js +120 -0
- package/dist/bridges/AnthropicBridge.js.map +1 -0
- package/dist/bridges/CachingBridge.d.ts +44 -0
- package/dist/bridges/CachingBridge.d.ts.map +1 -0
- package/dist/bridges/CachingBridge.js +62 -0
- package/dist/bridges/CachingBridge.js.map +1 -0
- package/dist/bridges/FetchBridge.d.ts +30 -0
- package/dist/bridges/FetchBridge.d.ts.map +1 -0
- package/dist/bridges/FetchBridge.js +24 -0
- package/dist/bridges/FetchBridge.js.map +1 -0
- package/dist/bridges/OpenAIBridge.d.ts +33 -0
- package/dist/bridges/OpenAIBridge.d.ts.map +1 -0
- package/dist/bridges/OpenAIBridge.js +110 -0
- package/dist/bridges/OpenAIBridge.js.map +1 -0
- package/dist/bridges/ResponseCache.d.ts +65 -0
- package/dist/bridges/ResponseCache.d.ts.map +1 -0
- package/dist/bridges/ResponseCache.js +97 -0
- package/dist/bridges/ResponseCache.js.map +1 -0
- package/dist/bridges/SemanticCachingBridge.d.ts +31 -0
- package/dist/bridges/SemanticCachingBridge.d.ts.map +1 -0
- package/dist/bridges/SemanticCachingBridge.js +44 -0
- package/dist/bridges/SemanticCachingBridge.js.map +1 -0
- package/dist/bridges/TransformerBridge.d.ts +35 -0
- package/dist/bridges/TransformerBridge.d.ts.map +1 -0
- package/dist/bridges/TransformerBridge.js +10 -0
- package/dist/bridges/TransformerBridge.js.map +1 -0
- package/dist/bridges/index.d.ts +14 -0
- package/dist/bridges/index.d.ts.map +1 -0
- package/dist/bridges/index.js +7 -0
- package/dist/bridges/index.js.map +1 -0
- package/dist/cache/FetchSemanticCacheBackend.d.ts +40 -0
- package/dist/cache/FetchSemanticCacheBackend.d.ts.map +1 -0
- package/dist/cache/FetchSemanticCacheBackend.js +61 -0
- package/dist/cache/FetchSemanticCacheBackend.js.map +1 -0
- package/dist/cache/SemanticCache.d.ts +105 -0
- package/dist/cache/SemanticCache.d.ts.map +1 -0
- package/dist/cache/SemanticCache.js +130 -0
- package/dist/cache/SemanticCache.js.map +1 -0
- package/dist/cache/index.d.ts +5 -0
- package/dist/cache/index.d.ts.map +1 -0
- package/dist/cache/index.js +3 -0
- package/dist/cache/index.js.map +1 -0
- package/dist/distillation/DistillationEngine.d.ts +107 -0
- package/dist/distillation/DistillationEngine.d.ts.map +1 -0
- package/dist/distillation/DistillationEngine.js +152 -0
- package/dist/distillation/DistillationEngine.js.map +1 -0
- package/dist/distillation/index.d.ts +3 -0
- package/dist/distillation/index.d.ts.map +1 -0
- package/dist/distillation/index.js +2 -0
- package/dist/distillation/index.js.map +1 -0
- package/dist/errors/SSMError.d.ts +14 -0
- package/dist/errors/SSMError.d.ts.map +1 -0
- package/dist/errors/SSMError.js +18 -0
- package/dist/errors/SSMError.js.map +1 -0
- package/dist/errors/index.d.ts +3 -0
- package/dist/errors/index.d.ts.map +1 -0
- package/dist/errors/index.js +2 -0
- package/dist/errors/index.js.map +1 -0
- package/dist/index.d.ts +65 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +59 -0
- package/dist/index.js.map +1 -0
- package/dist/memory/MemoryStore.d.ts +152 -0
- package/dist/memory/MemoryStore.d.ts.map +1 -0
- package/dist/memory/MemoryStore.js +290 -0
- package/dist/memory/MemoryStore.js.map +1 -0
- package/dist/memory/index.d.ts +3 -0
- package/dist/memory/index.d.ts.map +1 -0
- package/dist/memory/index.js +2 -0
- package/dist/memory/index.js.map +1 -0
- package/dist/router/InferenceRouter.d.ts +92 -0
- package/dist/router/InferenceRouter.d.ts.map +1 -0
- package/dist/router/InferenceRouter.js +113 -0
- package/dist/router/InferenceRouter.js.map +1 -0
- package/dist/router/index.d.ts +3 -0
- package/dist/router/index.d.ts.map +1 -0
- package/dist/router/index.js +2 -0
- package/dist/router/index.js.map +1 -0
- package/dist/runtime/SSMRuntime.d.ts +167 -0
- package/dist/runtime/SSMRuntime.d.ts.map +1 -0
- package/dist/runtime/SSMRuntime.js +199 -0
- package/dist/runtime/SSMRuntime.js.map +1 -0
- package/dist/runtime/index.d.ts +3 -0
- package/dist/runtime/index.d.ts.map +1 -0
- package/dist/runtime/index.js +2 -0
- package/dist/runtime/index.js.map +1 -0
- package/dist/session/errors.d.ts +10 -0
- package/dist/session/errors.d.ts.map +1 -0
- package/dist/session/errors.js +14 -0
- package/dist/session/errors.js.map +1 -0
- package/dist/session/index.d.ts +11 -0
- package/dist/session/index.d.ts.map +1 -0
- package/dist/session/index.js +7 -0
- package/dist/session/index.js.map +1 -0
- package/dist/session/persistence.d.ts +14 -0
- package/dist/session/persistence.d.ts.map +1 -0
- package/dist/session/persistence.js +100 -0
- package/dist/session/persistence.js.map +1 -0
- package/dist/session/presets.d.ts +31 -0
- package/dist/session/presets.d.ts.map +1 -0
- package/dist/session/presets.js +91 -0
- package/dist/session/presets.js.map +1 -0
- package/dist/session/session.d.ts +186 -0
- package/dist/session/session.d.ts.map +1 -0
- package/dist/session/session.js +358 -0
- package/dist/session/session.js.map +1 -0
- package/dist/session/streaming.d.ts +13 -0
- package/dist/session/streaming.d.ts.map +1 -0
- package/dist/session/streaming.js +74 -0
- package/dist/session/streaming.js.map +1 -0
- package/dist/session/tokenizer.d.ts +18 -0
- package/dist/session/tokenizer.d.ts.map +1 -0
- package/dist/session/tokenizer.js +11 -0
- package/dist/session/tokenizer.js.map +1 -0
- package/dist/similarity/index.d.ts +19 -0
- package/dist/similarity/index.d.ts.map +1 -0
- package/dist/similarity/index.js +42 -0
- package/dist/similarity/index.js.map +1 -0
- package/package.json +120 -0
- package/src/agent/SSMAgent.ts +327 -0
- package/src/agent/index.ts +2 -0
- package/src/bridges/AnthropicBridge.ts +166 -0
- package/src/bridges/CachingBridge.ts +79 -0
- package/src/bridges/FetchBridge.ts +41 -0
- package/src/bridges/OpenAIBridge.ts +143 -0
- package/src/bridges/ResponseCache.ts +131 -0
- package/src/bridges/SemanticCachingBridge.ts +60 -0
- package/src/bridges/TransformerBridge.ts +38 -0
- package/src/bridges/index.ts +13 -0
- package/src/cache/FetchSemanticCacheBackend.ts +79 -0
- package/src/cache/SemanticCache.ts +196 -0
- package/src/cache/index.ts +9 -0
- package/src/distillation/DistillationEngine.ts +248 -0
- package/src/distillation/index.ts +2 -0
- package/src/errors/SSMError.ts +26 -0
- package/src/errors/index.ts +2 -0
- package/src/index.ts +128 -0
- package/src/memory/MemoryStore.ts +408 -0
- package/src/memory/index.ts +2 -0
- package/src/router/InferenceRouter.ts +201 -0
- package/src/router/index.ts +2 -0
- package/src/runtime/SSMRuntime.ts +309 -0
- package/src/runtime/index.ts +2 -0
- package/src/session/errors.ts +24 -0
- package/src/session/index.ts +25 -0
- package/src/session/persistence.ts +142 -0
- package/src/session/presets.ts +122 -0
- package/src/session/session.ts +657 -0
- package/src/session/streaming.ts +97 -0
- package/src/session/tokenizer.ts +18 -0
- package/src/similarity/index.ts +42 -0
|
@@ -0,0 +1,657 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* session.ts – MambaSession: the single entry point for all session-layer functionality.
|
|
3
|
+
*
|
|
4
|
+
* MambaSession is a facade over MambaCode.js that collapses the 8-step
|
|
5
|
+
* async setup sequence into a single `MambaSession.create()` call.
|
|
6
|
+
*
|
|
7
|
+
* Part of the @seanhogg/ssmjs session layer.
|
|
8
|
+
*/
|
|
9
|
+
|
|
10
|
+
import {
|
|
11
|
+
initWebGPU,
|
|
12
|
+
BPETokenizer,
|
|
13
|
+
HybridMambaModel,
|
|
14
|
+
MambaTrainer,
|
|
15
|
+
type HybridMambaModelConfig,
|
|
16
|
+
type LayerSpec,
|
|
17
|
+
type LayerType,
|
|
18
|
+
} from '@seanhogg/builderforce-memory-engine';
|
|
19
|
+
|
|
20
|
+
import { SessionError } from './errors.js';
|
|
21
|
+
import type { Tokenizer } from './tokenizer.js';
|
|
22
|
+
|
|
23
|
+
// ── Opinionated defaults ───────────────────────────────────────────────────────
|
|
24
|
+
|
|
25
|
+
/**
 * Default tokenizer assets: the Qwen2.5-Coder BPE vocabulary (Apache 2.0).
 * Both files live under the same Hugging Face `resolve/main` directory.
 */
const QWEN_TOKENIZER_BASE_URL = 'https://huggingface.co/Qwen/Qwen2.5-Coder-1.5B-Instruct/resolve/main';
const DEFAULT_VOCAB_URL = `${QWEN_TOKENIZER_BASE_URL}/vocab.json`;
const DEFAULT_MERGES_URL = `${QWEN_TOKENIZER_BASE_URL}/merges.txt`;
|
|
28
|
+
|
|
29
|
+
// Checkpoint provenance
|
|
30
|
+
// ----------------------
|
|
31
|
+
// There is intentionally no hard-coded default checkpoint URL. A model with no
|
|
32
|
+
// checkpoint starts from initialised weights and generates poorly until trained
|
|
33
|
+
// or adapted — callers must opt in to a checkpoint explicitly via one of:
|
|
34
|
+
// • `checkpointUrl` — fetch a hosted .bin (browser or any fetch-capable env)
|
|
35
|
+
// • `checkpointBuffer` — load an already-read ArrayBuffer (Node/local files;
|
|
36
|
+
// fetch() cannot read local paths in Node)
|
|
37
|
+
// To produce a checkpoint without a hosted model:
|
|
38
|
+
// • Untrained (deterministic seed): `node tools/generate-bin.js --size nano`
|
|
39
|
+
// • Trained: open tools/pretrain.html, train on a corpus, download the .bin
|
|
40
|
+
import { resolveModelConfig } from './presets.js';
|
|
41
|
+
import {
|
|
42
|
+
saveToIndexedDB,
|
|
43
|
+
loadFromIndexedDB,
|
|
44
|
+
triggerDownload,
|
|
45
|
+
saveViaFileSystemAPI,
|
|
46
|
+
loadViaFileSystemAPI,
|
|
47
|
+
} from './persistence.js';
|
|
48
|
+
import { tokenStream } from './streaming.js';
|
|
49
|
+
|
|
50
|
+
// ── Public type definitions ────────────────────────────────────────────────────
|
|
51
|
+
|
|
52
|
+
/** Options accepted by `MambaSession.create()`. Every field is optional. */
export interface MambaSessionOptions {
  /**
   * URL of a .bin checkpoint to fetch. Optional — without any checkpoint the
   * model starts from randomly initialised weights.
   */
  checkpointUrl?: string;

  /**
   * Checkpoint bytes that have already been read. Wins over `checkpointUrl`
   * and is handed straight to `model.loadWeights()` without a network fetch.
   * Intended for Node.js, where `fetch()` cannot read local file paths: read
   * the .bin with `fs` and pass the resulting ArrayBuffer here.
   */
  checkpointBuffer?: ArrayBuffer;

  /** URL of vocab.json. Falls back to the Qwen2.5-Coder vocabulary. */
  vocabUrl?: string;

  /** URL of merges.txt. Falls back to the Qwen2.5-Coder merge rules. */
  mergesUrl?: string;

  /** In-memory vocabulary object — alternative to `vocabUrl`. */
  vocabObject?: Record<string, number>;

  /** In-memory merges array — alternative to `mergesUrl`. */
  mergesArray?: string[];

  /**
   * Ready-made tokenizer to use in place of BPETokenizer. When supplied,
   * `vocabUrl`, `mergesUrl`, `vocabObject` and `mergesArray` are all ignored
   * and MambaSession uses this tokenizer directly. It must satisfy the
   * `Tokenizer` interface.
   */
  tokenizer?: Tokenizer;

  /** Unique session name; doubles as the IndexedDB key. Default: 'default'. */
  name?: string;

  /**
   * Model size preset. Overrides individual model config fields.
   *   - 'nano'   : dModel=128, numLayers=4
   *   - 'small'  : dModel=256, numLayers=6
   *   - 'medium' : dModel=512, numLayers=8  (default)
   *   - 'large'  : dModel=768, numLayers=12
   *   - 'custom' : use `modelConfig` directly
   */
  modelSize?: 'nano' | 'small' | 'medium' | 'large' | 'custom';

  /** Fine-grained model configuration. Only consulted when `modelSize` is 'custom'. */
  modelConfig?: Partial<HybridMambaModelConfig>;

  /**
   * SSM variant applied to every layer when no `layerSchedule` is given.
   * Default: 'mamba1' (existing behaviour).
   */
  mambaVersion?: LayerType extends 'attention' ? never : 'mamba1' | 'mamba2' | 'mamba3';

  /**
   * Per-layer type schedule; its length must equal the resolved numLayers.
   * Takes priority over `mambaVersion` when present.
   *
   * Preset strings:
   *   'jamba' — every 4th layer is attention, rest mamba2
   *   'zamba' — every 6th layer is attention, rest mamba3
   */
  layerSchedule?: LayerSpec[] | 'jamba' | 'zamba';

  /**
   * Pre-created GPUAdapter used instead of calling
   * `navigator.gpu.requestAdapter()`. Useful in Node.js with @webgpu/node:
   *   import { create as createGPU } from '@webgpu/node';
   *   const gpuAdapter = await createGPU().requestAdapter();
   * When supplied, `powerPreference` and `allowCpuFallback` are ignored.
   */
  gpuAdapter?: GPUAdapter;

  /**
   * IDBFactory used instead of the global `indexedDB`. Useful in Node.js
   * with fake-indexeddb:
   *   import { IDBFactory } from 'fake-indexeddb';
   *   const idbFactory = new IDBFactory();
   * When supplied, the 'indexedDB' storage target goes through this factory.
   */
  idbFactory?: IDBFactory;

  /** WebGPU power preference. Default: 'high-performance'. */
  powerPreference?: 'high-performance' | 'low-power';

  /**
   * When true, MambaSession tries to obtain a software (CPU) WebGPU adapter
   * if the preferred GPU adapter is unavailable — for example where there is
   * no discrete GPU or the browser has disabled WebGPU. Performance will be
   * significantly degraded, but the session still initialises and produces
   * correct output.
   *
   * Requires a browser/runtime that supports `forceFallbackAdapter: true` on
   * `navigator.gpu.requestAdapter()`. Node.js environments may not have a
   * software adapter available; an error is thrown in that case.
   *
   * Default: false (a missing GPU is a hard error).
   */
  allowCpuFallback?: boolean;

  /** How many times a failed checkpoint fetch is retried. Default: 2. */
  fetchRetries?: number;

  /**
   * Deterministic seed for weight initialisation. When set, a model created
   * without a checkpoint initialises reproducibly — the same seed yields
   * byte-identical weights on any machine. Omit for non-reproducible
   * `Math.random` init (the default).
   */
  seed?: number;
}
|
|
146
|
+
|
|
147
|
+
/** Sampling controls shared by `complete()` and `completeStream()`. */
export interface CompleteOptions {
  /** Maximum number of tokens to generate. Default: 200. */
  maxNewTokens?: number;
  /** Sampling temperature. Default: 0.8. */
  temperature?: number;
  /** Top-k sampling cutoff. Default: 50. */
  topK?: number;
  /** Top-p (nucleus) sampling cutoff. Default: 0.9. */
  topP?: number;
}
|
|
153
|
+
|
|
154
|
+
/** Training controls for `MambaSession.adapt()`. */
export interface AdaptOptions {
  /** Number of training epochs. Default: 3. */
  epochs?: number;
  /** Optimiser learning rate. Default: 1e-4. */
  learningRate?: number;
  /** Training sequence length. Default: 512. */
  seqLen?: number;
  /** Use WSLA fast-adapt mode. Default: true. */
  wsla?: boolean;
  /** Convenience alias: sets wsla=false and epochs=5. */
  fullTrain?: boolean;
  /** Invoked with the epoch index and its loss as training progresses. */
  onProgress?: (epoch: number, loss: number) => void;
}
|
|
162
|
+
|
|
163
|
+
/** Summary returned by `MambaSession.adapt()`. */
export interface AdaptResult {
  /** Loss values returned by the trainer. */
  losses: number[];
  /** Number of entries in `losses`. */
  epochCount: number;
  /** Wall-clock duration of the training call, in milliseconds. */
  durationMs: number;
}
|
|
168
|
+
|
|
169
|
+
/** Places a checkpoint can be written to by `MambaSession.save()`. */
export type StorageTarget = 'indexedDB' | 'download' | 'fileSystem';

/** Options for `MambaSession.save()`. */
export interface SaveOptions {
  /** Where to write the exported weights. Default: 'indexedDB'. */
  storage?: StorageTarget;
  /** File name used by 'download' and 'fileSystem'. Default: '<name>.bin'. */
  filename?: string;
  /** IndexedDB key override. Default: the session name. */
  key?: string;
}

/** Options for `MambaSession.load()`. */
export interface LoadOptions {
  /**
   * Where to read the checkpoint from. Default: 'indexedDB'.
   *
   * 'url' selects a network fetch of `url`. `load()` sends every value other
   * than 'indexedDB' and 'fileSystem' down that same fetch path, so 'download'
   * also behaves like 'url' there.
   */
  storage?: StorageTarget | 'url';
  /** Checkpoint URL; required whenever `storage` is neither 'indexedDB' nor 'fileSystem'. */
  url?: string;
  /** IndexedDB key override. Default: the session name. */
  key?: string;
}
|
|
182
|
+
|
|
183
|
+
/** Setup phases reported by `MambaSession.create()`, in the order they run. */
export type CreateStage = 'gpu' | 'tokenizer' | 'model' | 'weights';

/** One progress notification emitted while a session is being created. */
export interface CreateProgressEvent {
  /** Phase the notification belongs to. */
  stage: CreateStage;
  /** Completion within the current stage, from 0.0 to 1.0. */
  progress: number;
  /** Human-readable status text. */
  message: string;
}
|
|
190
|
+
|
|
191
|
+
/** Low-level objects exposed through the `MambaSession.internals` escape hatch. */
export interface SessionInternals {
  /** GPU device the session runs on. */
  device: GPUDevice;
  /** Underlying model instance. */
  model: HybridMambaModel;
  /** Trainer bound to `model`. */
  trainer: MambaTrainer;
  /** Tokenizer used for encoding and decoding. */
  tokenizer: BPETokenizer;
}
|
|
197
|
+
|
|
198
|
+
/** Optional hooks passed as the second argument of `MambaSession.create()`. */
export interface CreateCallbacks {
  /** Receives a progress event for each setup stage. */
  onProgress?: (event: CreateProgressEvent) => void;
}
|
|
201
|
+
|
|
202
|
+
/**
 * Checkpoint-fetch retry schedule: the wait before retry number `attempt`
 * (0-based) is `RETRY_BASE_DELAY_MS * RETRY_BACKOFF_FACTOR ** attempt`.
 */

/** Delay in milliseconds before the first retry. */
const RETRY_BASE_DELAY_MS = 500;

/** Growth factor applied to the delay for each further retry. */
const RETRY_BACKOFF_FACTOR = 2;
|
|
206
|
+
|
|
207
|
+
// ── MambaSession ───────────────────────────────────────────────────────────────
|
|
208
|
+
|
|
209
|
+
/** How the session's GPU device was obtained. */
export type GpuMode = 'webgpu' | 'cpu-fallback';

/**
 * Facade that owns a GPU device, tokenizer, model and trainer, and exposes
 * generation, embedding, fine-tuning, evaluation and persistence on top of them.
 *
 * Instances are built only through `MambaSession.create()`. After `destroy()`
 * every operation except `destroy()` itself and the `gpuMode` / `internals`
 * getters throws a `SessionError` with code 'SESSION_DESTROYED'.
 */
export class MambaSession {
  /** Retry count used when `fetchRetries` is omitted or not a finite number. */
  private static readonly _DEFAULT_FETCH_RETRIES = 2;

  private readonly _device: GPUDevice;
  private readonly _tokenizer: BPETokenizer;
  private readonly _model: HybridMambaModel;
  private readonly _trainer: MambaTrainer;
  private readonly _name: string;
  private readonly _gpuMode: GpuMode;
  private readonly _idbFactory: IDBFactory | undefined;
  private _destroyed = false;

  /**
   * Whether the session is running on a real GPU ('webgpu') or a software
   * CPU fallback ('cpu-fallback'). Only 'cpu-fallback' when
   * `allowCpuFallback: true` was passed and no GPU was available.
   */
  get gpuMode(): GpuMode {
    return this._gpuMode;
  }

  private constructor(
    device: GPUDevice,
    tokenizer: BPETokenizer,
    model: HybridMambaModel,
    trainer: MambaTrainer,
    name: string,
    gpuMode: GpuMode = 'webgpu',
    idbFactory?: IDBFactory,
  ) {
    this._device = device;
    this._tokenizer = tokenizer;
    this._model = model;
    this._trainer = trainer;
    this._name = name;
    this._gpuMode = gpuMode;
    this._idbFactory = idbFactory;
  }

  // ── Static factory ─────────────────────────────────────────────────────────

  /**
   * Builds a ready-to-use session: acquires a GPU device, prepares the
   * tokenizer, constructs the model and trainer, and optionally loads a
   * checkpoint (`checkpointBuffer` takes precedence over `checkpointUrl`).
   *
   * @param options   Session configuration.
   * @param callbacks Optional progress hook, called at the start and end of each stage.
   * @returns The initialised session.
   * @throws SessionError 'GPU_UNAVAILABLE', 'TOKENIZER_LOAD_FAILED',
   *         'CHECKPOINT_FETCH_FAILED' or 'CHECKPOINT_INVALID'.
   */
  static async create(
    options: MambaSessionOptions,
    callbacks: CreateCallbacks = {},
  ): Promise<MambaSession> {
    const { onProgress } = callbacks;
    const name = options.name ?? 'default';
    const fetchRetries = MambaSession._normaliseRetries(options.fetchRetries);

    const emit = (stage: CreateStage, progress: number, message: string): void => {
      onProgress?.({ stage, progress, message });
    };

    // Step 1 — GPU
    const { device, gpuMode } = await MambaSession._acquireDevice(options, emit);

    try {
      // Step 2 — Tokenizer
      emit('tokenizer', 0.0, 'Loading tokenizer…');
      const tokenizer = await MambaSession._buildTokenizer(options);
      emit('tokenizer', 1.0, 'Tokenizer ready');

      // Step 3 — Model & Trainer
      emit('model', 0.0, 'Building model…');
      // A tokenizer reporting a non-positive vocabulary is clamped to 1.
      const vocabSize = tokenizer.vocabSize > 0 ? tokenizer.vocabSize : 1;
      const config = resolveModelConfig(options, vocabSize);
      const model = new HybridMambaModel(device, config);
      const trainer = new MambaTrainer(model, tokenizer);
      emit('model', 1.0, 'Model ready');

      // Step 4 — Checkpoint (optional)
      // A pre-read buffer takes precedence over a URL (Node/local-file path);
      // fetch() is only used when a URL is supplied.
      if (options.checkpointBuffer != null) {
        emit('weights', 0.0, 'Loading checkpoint…');
        await MambaSession._loadWeights(model, options.checkpointBuffer, 'Checkpoint buffer');
        emit('weights', 1.0, 'Checkpoint loaded');
      } else if (options.checkpointUrl != null) {
        emit('weights', 0.0, 'Fetching checkpoint…');
        const buffer = await MambaSession._fetchCheckpoint(options.checkpointUrl, fetchRetries);
        await MambaSession._loadWeights(model, buffer, 'Checkpoint file');
        emit('weights', 1.0, 'Checkpoint loaded');
      }

      return new MambaSession(device, tokenizer, model, trainer, name, gpuMode, options.idbFactory);
    } catch (err) {
      // The device was created by this call; release it so a failed setup
      // does not leak GPU resources. The original error is what callers see.
      try {
        device.destroy();
      } catch {
        /* ignore secondary failure while cleaning up */
      }
      throw err;
    }
  }

  // ── Text generation ────────────────────────────────────────────────────────

  /**
   * Generates a continuation for `prompt`.
   *
   * @param prompt  Text to continue.
   * @param options Sampling overrides; see `CompleteOptions` for defaults.
   * @returns The decoded continuation only (prompt tokens are sliced off).
   * @throws SessionError 'SESSION_DESTROYED' after `destroy()`.
   */
  async complete(prompt: string, options: CompleteOptions = {}): Promise<string> {
    this._assertNotDestroyed();

    const { maxNewTokens, ...sampling } = MambaSession._resolveSampling(options);
    const promptIds = this._tokenizer.encode(prompt);
    const outputIds = await this._model.generate(promptIds, maxNewTokens, sampling);

    // Return the continuation only (not the original prompt tokens)
    return this._tokenizer.decode(outputIds.slice(promptIds.length));
  }

  /**
   * Streaming variant of `complete()`: yields the decoded text of each
   * generated token as soon as it is produced.
   *
   * NOTE(review): every token is decoded on its own, so a character whose
   * bytes are split across tokens may not decode cleanly — depends on the
   * tokenizer's `decode`; confirm before relying on exact text equality
   * with `complete()`.
   *
   * @throws SessionError 'SESSION_DESTROYED' after `destroy()`.
   */
  async *completeStream(
    prompt: string,
    options: CompleteOptions = {},
  ): AsyncIterable<string> {
    this._assertNotDestroyed();

    const { maxNewTokens, ...sampling } = MambaSession._resolveSampling(options);
    const promptIds = this._tokenizer.encode(prompt);

    for await (const tokenId of tokenStream(this._model, promptIds, maxNewTokens, sampling)) {
      yield this._tokenizer.decode([tokenId]);
    }
  }

  // ── Embedding ──────────────────────────────────────────────────────────────

  /**
   * Returns a fixed-length (`dModel`) L2-normalised embedding for `text`,
   * derived from the model's final hidden state. Suitable for cosine-similarity
   * semantic search (see MemoryStore.recallSimilar).
   *
   * Returns a zero vector for input that encodes to no tokens.
   *
   * @throws SessionError 'SESSION_DESTROYED' after `destroy()`.
   */
  async embed(text: string): Promise<Float32Array> {
    this._assertNotDestroyed();

    const ids = this._tokenizer.encode(text);
    return ids.length === 0
      ? new Float32Array(this._model.config.dModel)
      : this._model.embed(ids);
  }

  // ── Fine-tuning ────────────────────────────────────────────────────────────

  /**
   * Trains the model on `text`.
   *
   * `fullTrain: true` forces `wsla` off and raises the default epoch count
   * to 5; an explicit `epochs` value is still honoured.
   *
   * @returns Loss values from the trainer, their count and the elapsed time.
   * @throws SessionError 'INPUT_TOO_SHORT' when `text` encodes to fewer than
   *         2 tokens; 'SESSION_DESTROYED' after `destroy()`.
   */
  async adapt(text: string, options: AdaptOptions = {}): Promise<AdaptResult> {
    this._assertNotDestroyed();

    const {
      learningRate = 1e-4,
      seqLen = 512,
      fullTrain = false,
      onProgress,
    } = options;

    // Convenience alias: fullTrain overrides wsla and the epoch default.
    const wsla = fullTrain ? false : options.wsla ?? true;
    const epochs = options.epochs ?? (fullTrain ? 5 : 3);

    if (this._tokenizer.encode(text).length < 2) {
      throw new SessionError(
        'INPUT_TOO_SHORT',
        'The input text encodes to fewer than 2 tokens and cannot be used for training.',
      );
    }

    const startTime = Date.now();
    const losses = await this._trainer.train(text, {
      epochs,
      learningRate,
      seqLen,
      wsla,
      onEpochEnd: onProgress ?? null,
    });

    return {
      losses,
      epochCount: losses.length,
      durationMs: Date.now() - startTime,
    };
  }

  // ── Evaluation ─────────────────────────────────────────────────────────────

  /**
   * Delegates to the trainer's `evaluate(text)` and returns its numeric result.
   *
   * @throws SessionError 'SESSION_DESTROYED' after `destroy()`.
   */
  async evaluate(text: string): Promise<number> {
    this._assertNotDestroyed();
    return this._trainer.evaluate(text);
  }

  // ── Persistence ────────────────────────────────────────────────────────────

  /**
   * Exports the model weights and writes them to the chosen storage target.
   *
   * @param options Target, file name and key; see `SaveOptions` for defaults.
   * @throws SessionError 'STORAGE_UNAVAILABLE' for an unknown target;
   *         'SESSION_DESTROYED' after `destroy()`.
   */
  async save(options: SaveOptions = {}): Promise<void> {
    this._assertNotDestroyed();

    const storage = options.storage ?? 'indexedDB';
    const key = options.key ?? this._name;
    const filename = options.filename ?? `${this._name}.bin`;

    const buffer = await this._model.exportWeights();

    switch (storage) {
      case 'indexedDB':
        await saveToIndexedDB(key, buffer, this._idbFactory);
        break;
      case 'download':
        await triggerDownload(filename, buffer);
        break;
      case 'fileSystem':
        await saveViaFileSystemAPI(filename, buffer);
        break;
      default:
        throw new SessionError('STORAGE_UNAVAILABLE', `Unknown storage target: "${storage as string}"`);
    }
  }

  /**
   * Reads a checkpoint from storage and loads it into the model.
   *
   * 'indexedDB' and 'fileSystem' use the persistence helpers; any other
   * `storage` value is treated as a network fetch of `options.url`.
   *
   * @returns `true` when weights were loaded, `false` when the source held no data.
   * @throws SessionError 'STORAGE_UNAVAILABLE' when a URL is needed but missing,
   *         'CHECKPOINT_FETCH_FAILED' when the fetch fails (network or HTTP),
   *         'CHECKPOINT_INVALID' when the bytes are rejected by the model,
   *         'SESSION_DESTROYED' after `destroy()`.
   */
  async load(options: LoadOptions = {}): Promise<boolean> {
    this._assertNotDestroyed();

    const storage = options.storage ?? 'indexedDB';
    const key = options.key ?? this._name;

    let buffer: ArrayBuffer | undefined;

    switch (storage) {
      case 'indexedDB': {
        buffer = await loadFromIndexedDB(key, this._idbFactory);
        break;
      }
      case 'fileSystem': {
        buffer = await loadViaFileSystemAPI();
        break;
      }
      default: {
        // Treat any other string as a URL fetch (covers custom `url` option)
        const url = options.url;
        if (!url) {
          throw new SessionError(
            'STORAGE_UNAVAILABLE',
            'load() with storage other than "indexedDB" or "fileSystem" requires a url option.',
          );
        }

        // Network-level failures are wrapped like HTTP failures so callers
        // see one error type for "could not fetch".
        let res: Response;
        try {
          res = await fetch(url);
        } catch (err) {
          throw new SessionError(
            'CHECKPOINT_FETCH_FAILED',
            `Failed to fetch checkpoint from "${url}": ${MambaSession._describeError(err)}`,
            err,
          );
        }
        if (!res.ok) {
          throw new SessionError(
            'CHECKPOINT_FETCH_FAILED',
            `Failed to fetch checkpoint from "${url}": HTTP ${res.status}`,
          );
        }
        try {
          buffer = await res.arrayBuffer();
        } catch (err) {
          throw new SessionError(
            'CHECKPOINT_FETCH_FAILED',
            `Failed to fetch checkpoint from "${url}": ${MambaSession._describeError(err)}`,
            err,
          );
        }
      }
    }

    if (buffer == null) return false;

    await MambaSession._loadWeights(this._model, buffer, 'Saved checkpoint');
    return true;
  }

  // ── Resource cleanup ───────────────────────────────────────────────────────

  /** Destroys the GPU device and marks the session unusable. Safe to call repeatedly. */
  destroy(): void {
    if (this._destroyed) return;
    this._destroyed = true;
    this._device.destroy();
  }

  // ── Escape hatch ───────────────────────────────────────────────────────────

  /** Direct access to the underlying device, model, trainer and tokenizer. */
  get internals(): SessionInternals {
    return {
      device: this._device,
      model: this._model,
      trainer: this._trainer,
      tokenizer: this._tokenizer,
    };
  }

  // ── Private helpers ────────────────────────────────────────────────────────

  /** Throws 'SESSION_DESTROYED' once `destroy()` has been called. */
  private _assertNotDestroyed(): void {
    if (!this._destroyed) return;
    throw new SessionError(
      'SESSION_DESTROYED',
      'This MambaSession has been destroyed. Create a new session with MambaSession.create().',
    );
  }

  /** Text of a thrown value; tolerates throws that are not `Error` instances. */
  private static _describeError(err: unknown): string {
    return err instanceof Error ? err.message : String(err);
  }

  /**
   * Turns the caller's `fetchRetries` into a non-negative integer. Missing or
   * non-finite values use the default, so the fetch loop always runs at least once.
   */
  private static _normaliseRetries(requested: number | undefined): number {
    if (requested == null || !Number.isFinite(requested)) {
      return MambaSession._DEFAULT_FETCH_RETRIES;
    }
    return Math.max(0, Math.floor(requested));
  }

  /** Applies the documented sampling defaults to a `CompleteOptions` object. */
  private static _resolveSampling(options: CompleteOptions): {
    maxNewTokens: number;
    temperature: number;
    topK: number;
    topP: number;
  } {
    const { maxNewTokens = 200, temperature = 0.8, topK = 50, topP = 0.9 } = options;
    return { maxNewTokens, temperature, topK, topP };
  }

  /** Step 1 of `create()`: obtains a GPUDevice from the injected adapter, WebGPU, or the CPU fallback. */
  private static async _acquireDevice(
    options: MambaSessionOptions,
    emit: (stage: CreateStage, progress: number, message: string) => void,
  ): Promise<{ device: GPUDevice; gpuMode: GpuMode }> {
    emit('gpu', 0.0, 'Initialising WebGPU…');

    if (options.gpuAdapter != null) {
      // Adapter injected externally (e.g. @webgpu/node in Node.js)
      let injected: GPUDevice;
      try {
        injected = await options.gpuAdapter.requestDevice();
      } catch (err) {
        throw new SessionError(
          'GPU_UNAVAILABLE',
          `Failed to acquire GPUDevice from provided gpuAdapter: ${MambaSession._describeError(err)}`,
          err,
        );
      }
      emit('gpu', 1.0, 'WebGPU ready (injected adapter)');
      return { device: injected, gpuMode: 'webgpu' };
    }

    let device: GPUDevice;
    let gpuMode: GpuMode = 'webgpu';
    try {
      const result = await initWebGPU({
        powerPreference: options.powerPreference ?? 'high-performance',
      });
      device = result.device;
    } catch (primaryErr) {
      device = await MambaSession._acquireFallbackDevice(options, emit, primaryErr);
      gpuMode = 'cpu-fallback';
    }

    emit('gpu', 1.0, gpuMode === 'cpu-fallback'
      ? 'CPU software adapter ready (degraded performance)'
      : 'WebGPU ready');
    return { device, gpuMode };
  }

  /** Requests a software (CPU) adapter after the primary WebGPU initialisation failed. */
  private static async _acquireFallbackDevice(
    options: MambaSessionOptions,
    emit: (stage: CreateStage, progress: number, message: string) => void,
    primaryErr: unknown,
  ): Promise<GPUDevice> {
    const primaryMessage = MambaSession._describeError(primaryErr);

    if (!options.allowCpuFallback) {
      throw new SessionError(
        'GPU_UNAVAILABLE',
        `WebGPU initialisation failed: ${primaryMessage}. ` +
        `Set allowCpuFallback: true to attempt a software (CPU) fallback.`,
        primaryErr,
      );
    }

    // Attempt software (CPU) adapter — available in Chrome with
    // --enable-unsafe-webgpu or in environments with a WARP/SwiftShader adapter.
    emit('gpu', 0.4,
      `WebGPU unavailable (${primaryMessage}); ` +
      `attempting CPU software fallback — performance will be degraded.`);

    if (typeof navigator === 'undefined' || !navigator.gpu) {
      throw new SessionError(
        'GPU_UNAVAILABLE',
        'WebGPU is not available in this environment and no software adapter can be requested.',
        primaryErr,
      );
    }

    try {
      const fallbackAdapter = await navigator.gpu.requestAdapter({
        powerPreference: 'low-power',
        forceFallbackAdapter: true,
      });
      if (!fallbackAdapter) throw new Error('requestAdapter returned null');
      return await fallbackAdapter.requestDevice();
    } catch (fallbackErr) {
      throw new SessionError(
        'GPU_UNAVAILABLE',
        `WebGPU unavailable and CPU fallback failed: ${MambaSession._describeError(fallbackErr)}`,
        fallbackErr,
      );
    }
  }

  /** Step 2 of `create()`: wraps an injected tokenizer or loads a BPETokenizer. */
  private static async _buildTokenizer(options: MambaSessionOptions): Promise<BPETokenizer> {
    const custom = options.tokenizer;
    if (custom != null) {
      // Custom tokenizer injected — skip vocabulary loading entirely and
      // present it through a BPETokenizer-shaped object so the rest of the
      // code is unchanged. Own properties are defined explicitly because
      // Object.assign would read `vocabSize` once and copy a snapshot, and
      // would fail if BPETokenizer declares `vocabSize` as a getter-only accessor.
      return Object.defineProperties(new BPETokenizer(), {
        encode: {
          value: (text: string) => custom.encode(text),
          writable: true,
          enumerable: true,
          configurable: true,
        },
        decode: {
          value: (ids: number[]) => custom.decode(ids),
          writable: true,
          enumerable: true,
          configurable: true,
        },
        vocabSize: {
          get: () => custom.vocabSize,
          enumerable: true,
          configurable: true,
        },
      });
    }

    const tokenizer = new BPETokenizer();
    try {
      if (options.vocabObject != null && options.mergesArray != null) {
        tokenizer.loadFromObjects(options.vocabObject, options.mergesArray);
      } else {
        await tokenizer.load(
          options.vocabUrl ?? DEFAULT_VOCAB_URL,
          options.mergesUrl ?? DEFAULT_MERGES_URL,
        );
      }
    } catch (err) {
      throw new SessionError(
        'TOKENIZER_LOAD_FAILED',
        `Tokenizer failed to load: ${MambaSession._describeError(err)}`,
        err,
      );
    }
    return tokenizer;
  }

  /**
   * Downloads a checkpoint, retrying with exponential backoff.
   * `retries` must be a non-negative integer; total attempts are `retries + 1`.
   */
  private static async _fetchCheckpoint(url: string, retries: number): Promise<ArrayBuffer> {
    let lastErr: unknown;

    for (let attempt = 0; attempt <= retries; attempt++) {
      try {
        const res = await fetch(url);
        if (!res.ok) {
          throw new Error(`HTTP ${res.status} ${res.statusText}`);
        }
        return await res.arrayBuffer();
      } catch (err) {
        lastErr = err;
        if (attempt < retries) {
          await sleep(RETRY_BASE_DELAY_MS * Math.pow(RETRY_BACKOFF_FACTOR, attempt));
        }
      }
    }

    throw new SessionError(
      'CHECKPOINT_FETCH_FAILED',
      `Failed to fetch checkpoint from "${url}" after ${retries + 1} attempt(s): ${MambaSession._describeError(lastErr)}`,
      lastErr,
    );
  }

  /** Loads weights into `model`, converting any failure into 'CHECKPOINT_INVALID'. */
  private static async _loadWeights(
    model: HybridMambaModel,
    buffer: ArrayBuffer,
    label: string,
  ): Promise<void> {
    try {
      await model.loadWeights(buffer);
    } catch (err) {
      throw new SessionError(
        'CHECKPOINT_INVALID',
        `${label} is invalid or incompatible: ${MambaSession._describeError(err)}`,
        err,
      );
    }
  }
}
|
|
654
|
+
|
|
655
|
+
/** Resolves after roughly `ms` milliseconds; used to space out checkpoint fetch retries. */
function sleep(ms: number): Promise<void> {
  return new Promise<void>((done) => {
    setTimeout(done, ms);
  });
}
|