@stellarapp/tfjs-stellar 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/jest.config.ts ADDED
@@ -0,0 +1,203 @@
1
+ /**
2
+ * For a detailed explanation regarding each configuration property, visit:
3
+ * https://jestjs.io/docs/configuration
4
+ */
5
+
6
+ import type { Config } from 'jest';
7
+
8
/**
 * Jest configuration for running the TypeScript sources as native ES modules via ts-jest.
 *
 * Only options that differ from Jest's defaults (or are intentionally pinned) are listed;
 * see https://jestjs.io/docs/configuration for everything else.
 *
 * NOTE(review): Jest's ESM support requires Node to be started with
 * `--experimental-vm-modules`; the package's `npx jest` script does not pass it — confirm
 * how the suite is launched.
 */
const config: Config = {
  // Modules run once per test file before the test framework is installed.
  setupFiles: [],

  // package.json declares "type": "module", so transformed .ts files are loaded as ESM.
  extensionsToTreatAsEsm: [".ts"],

  // A map from regular expressions to paths to transformers.
  // The backslash must be doubled inside a string literal: the former "^.+\.ts?$"
  // collapsed to /^.+.ts?$/, which matched any path ending in "t" or "ts" (e.g. "src/foot").
  transform: {
    "^.+\\.ts$": ["ts-jest", { useESM: true }],
  },

  // Regexp patterns matched against test paths; matching tests are skipped
  // (dependencies and any end-to-end suites).
  testPathIgnorePatterns: ["/node_modules/", "e2e"],

  // Resolve the "@/..." import alias used by the sources to <rootDir>/src/...
  moduleNameMapper: {
    "^@/(.*$)": "<rootDir>/src/$1",
  },

  // Automatically clear mock calls, instances, contexts and results before every test.
  clearMocks: true,

  // Coverage is opt-in; when enabled (e.g. `--coverage`) reports go to ./coverage.
  collectCoverage: false,
  coverageDirectory: "coverage",

  // The test environment that will be used for testing.
  testEnvironment: "node",
};

export default config;
package/package.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "name": "@stellarapp/tfjs-stellar",
3
+ "version": "1.0.0",
4
+ "description": "An extension of TensorFlow.js for implementing large language models.",
5
+ "license": "ISC",
6
+ "author": "",
7
+ "type": "module",
8
+ "main": "index.ts",
9
+ "scripts": {
10
+ "test": "npx jest"
11
+ },
12
+ "devDependencies": {
13
+ "@tensorflow/tfjs": "^4.22.0",
14
+ "@types/jest": "^30.0.0",
15
+ "@types/node": "^26.0.0",
16
+ "jest": "^30.4.2",
17
+ "ts-jest": "^29.4.11",
18
+ "tsx": "^4.22.4",
19
+ "typescript": "^6.0.3"
20
+ },
21
+ "peerDependencies": {
22
+ "@tensorflow/tfjs": "*"
23
+ }
24
+ }
package/src/index.ts ADDED
@@ -0,0 +1,93 @@
1
// Public entry point of @stellarapp/tfjs-stellar.
//
// Sub-modules are exposed as namespaces, and every layer/model class is also re-exported at
// the top level together with its constructor-args type. Each class is additionally imported
// locally because the keras-style factory functions further down in this file construct it.
export * as models from "./models";
export * as losses from "./losses";
export * as metrics from "./metrics";

import { MultiHeadAttention, type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
export { MultiHeadAttention, type MultiHeadAttentionArgs } from "@/layers/multihead_attention";

import { CachedRoPEMultiHeadAttention } from "@/layers/cached_rope_multihead_attention";
export { CachedRoPEMultiHeadAttention } from "@/layers/cached_rope_multihead_attention";

import { TransformerEncoder, type TransformerEncoderArgs, } from "@/layers/transformer_encoder";
export { TransformerEncoder, type TransformerEncoderArgs, } from "@/layers/transformer_encoder";

import { TransformerDecoder, type TransformerDecoderArgs, } from "@/layers/transformer_decoder";
export { TransformerDecoder, type TransformerDecoderArgs, } from "@/layers/transformer_decoder";

import { TokenAndPositionalEmbedding, type TokenAndPositionalEmbeddingArgs } from "@/layers/token_and_positional_embedding";
export { TokenAndPositionalEmbedding, type TokenAndPositionalEmbeddingArgs } from "@/layers/token_and_positional_embedding";

import { PositionalEncoding, type PositionalEncodingArgs } from "@/layers/positional_encoding";
export { PositionalEncoding, type PositionalEncodingArgs } from "@/layers/positional_encoding";

// Note: the class is GPT2DecoderBlock internally but is published under the name GPTDecoderBlock.
import { GPT2DecoderBlock, type GPTDecoderBlockArgs } from "@/layers/gpt_decoder_block";
export { GPT2DecoderBlock as GPTDecoderBlock, type GPTDecoderBlockArgs } from "@/layers/gpt_decoder_block";

import { LlmModel, type LlmModelArgs } from "@/models/llm_model";
export { LlmModel, type LlmModelArgs } from "@/models/llm_model";

// UNetModel was the only factory-backed class not re-exported, which left callers of the
// `unetModel` factory unable to name `UNetModelArgs` from the package's top level.
import { UNetModel, type UNetModelArgs } from "@/models/u_net";
export { UNetModel, type UNetModelArgs } from "@/models/u_net";

import { RotaryPositionEmbedding, type RotaryPositionEmbeddingArgs } from "@/layers/rotary_position_embedding";
export { RotaryPositionEmbedding, type RotaryPositionEmbeddingArgs } from "@/layers/rotary_position_embedding";

import { GptModel, type GptModelArgs } from "@/models/gpt_model";
export { GptModel, type GptModelArgs } from "@/models/gpt_model";
37
+
38
+
39
// Keras-like factory helpers, mirroring TFJS's `tf.layers.<name>(args)` style.
// Each one is a thin wrapper: it constructs the corresponding class with `args` unchanged.

// ---- Models ---------------------------------------------------------------

/** Builds an {@link LlmModel}; equivalent to `new LlmModel(args)`. */
export function llmModel(args: LlmModelArgs) {
  const model = new LlmModel(args);
  return model;
}

/** Builds a {@link GptModel}; equivalent to `new GptModel(args)`. */
export function gptModel(args: GptModelArgs) {
  const model = new GptModel(args);
  return model;
}

/** Builds a {@link UNetModel}; equivalent to `new UNetModel(args)`. */
export function unetModel(args: UNetModelArgs) {
  const model = new UNetModel(args);
  return model;
}

// ---- Embedding / positional layers ----------------------------------------

/** Builds a {@link TokenAndPositionalEmbedding} layer; equivalent to `new TokenAndPositionalEmbedding(args)`. */
export function tokenAndPositionalEmbedding(args: TokenAndPositionalEmbeddingArgs) {
  const layer = new TokenAndPositionalEmbedding(args);
  return layer;
}

/** Builds a {@link PositionalEncoding} layer; equivalent to `new PositionalEncoding(args)`. */
export function positionalEncoding(args: PositionalEncodingArgs) {
  const layer = new PositionalEncoding(args);
  return layer;
}

/** Builds a {@link RotaryPositionEmbedding} layer; equivalent to `new RotaryPositionEmbedding(args)`. */
export function rotaryPositionEmbedding(args: RotaryPositionEmbeddingArgs) {
  const layer = new RotaryPositionEmbedding(args);
  return layer;
}

// ---- Attention layers -----------------------------------------------------

/** Builds a {@link MultiHeadAttention} layer; equivalent to `new MultiHeadAttention(args)`. */
export function multiheadAttention(args: MultiHeadAttentionArgs) {
  const layer = new MultiHeadAttention(args);
  return layer;
}

/**
 * Builds a {@link CachedRoPEMultiHeadAttention} layer; equivalent to
 * `new CachedRoPEMultiHeadAttention(args)`.
 * NOTE(review): typed with `MultiHeadAttentionArgs`; confirm the cached layer accepts
 * no additional options of its own.
 */
export function cachedRopeMultiheadAttention(args: MultiHeadAttentionArgs) {
  const layer = new CachedRoPEMultiHeadAttention(args);
  return layer;
}

// ---- Transformer blocks ---------------------------------------------------

/** Builds a {@link TransformerEncoder} layer; equivalent to `new TransformerEncoder(args)`. */
export function transformerEncoder(args: TransformerEncoderArgs) {
  const layer = new TransformerEncoder(args);
  return layer;
}

/** Builds a {@link TransformerDecoder} layer; equivalent to `new TransformerDecoder(args)`. */
export function transformerDecoder(args: TransformerDecoderArgs) {
  const layer = new TransformerDecoder(args);
  return layer;
}

/** Builds a GPT-2 style decoder block; equivalent to `new GPT2DecoderBlock(args)`. */
export function gpt2DecoderBlock(args: GPTDecoderBlockArgs) {
  const layer = new GPT2DecoderBlock(args);
  return layer;
}
@@ -0,0 +1,205 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+
3
+
4
/**
 * Constructor options for {@link KvCache}. The key and value caches are each allocated
 * with shape `[batchSize, numHeads, maxSequenceLength, headDim]`.
 */
export interface KvCacheArgs {
  /** Batch dimension (axis 0) of the cached key/value tensors. */
  batchSize: number;
  /** Maximum number of tokens the cache can hold (axis 2). */
  maxSequenceLength: number;
  /** Number of key/value heads (axis 1). */
  numHeads: number;
  /** Per-head key/value vector size (axis 3). */
  headDim: number;
  /** Element type of the cache tensors; `KvCache` falls back to `"float32"` when omitted. */
  dtype?: tf.DataType;
}
11
+
12
+
13
/**
 * A container for KV caches, keyed by a caller-chosen string id. A model should create one
 * container and share it between its attention layers, each of which keeps its own cache in
 * it (presumably keyed by layer name — confirm in the attention layer). Every cache created
 * here shares the container's `maxSequenceLength`.
 */
export class KvCacheContainer {
  protected caches = new Map<string, KvCache>();
  protected max_sequence_length: number;

  /**
   * @param maxSequenceLength Token capacity given to every cache created by this container.
   * @throws Error if `maxSequenceLength` is not a positive integer.
   */
  constructor(maxSequenceLength: number) {
    // `!maxSequenceLength` only rejected 0/NaN/undefined; negative, fractional or infinite
    // lengths slipped through and only failed later inside tf.zeros().
    if (!Number.isInteger(maxSequenceLength) || maxSequenceLength <= 0) {
      throw Error(`KvCacheContainer: expected KV cache maximum sequence length to be a positive integer, got: ${String(maxSequenceLength)}`);
    }

    this.max_sequence_length = maxSequenceLength;
  }

  /**
   * Creates (or replaces) the cache stored under `id`, sized with this container's maximum
   * sequence length.
   */
  public create(id: string, args: Omit<KvCacheArgs, "maxSequenceLength">) {
    const new_cache = new KvCache({
      ...args,
      maxSequenceLength: this.max_sequence_length
    });

    // Release the cache being replaced: once overwritten in the map its variables are
    // unreachable, and TFJS only frees tensor memory on an explicit dispose().
    this.caches.get(id)?.dispose();
    this.caches.set(id, new_cache);
  }

  /**
   * Appends `key`/`value` to the cache stored under `id` and returns the cached keys/values
   * for every position written so far.
   *
   * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`.
   *
   * @returns `{ keyCache, valueCache }`, each sliced to `[batch, heads, size, head_dim]`
   *   (new tensors, not the cache's own variables), or `undefined` when no cache exists for `id`.
   * @throws Error (from `KvCache.update`) when the new tokens would exceed the cache capacity.
   */
  public update(id: string, key: tf.Tensor4D, value: tf.Tensor4D) {
    const kv_cache = this.caches.get(id);

    if (!kv_cache) {
      return undefined;
    }

    const { keyCache, valueCache } = kv_cache.update(key, value);

    // slicing to get only the past key and value projections, but normally
    // in TensorFlow and PyTorch the full cache is returned and masked for
    // graph purposes
    return tf.tidy(() => {
      const k_cache = keyCache.slice(
        [0, 0, 0, 0],
        [keyCache.shape[0], keyCache.shape[1], kv_cache.size, keyCache.shape[3]]);
      const v_cache = valueCache.slice(
        [0, 0, 0, 0],
        [valueCache.shape[0], valueCache.shape[1], kv_cache.size, valueCache.shape[3]]);

      return {
        keyCache: k_cache,
        valueCache: v_cache
      };
    });
  }

  /** Rewinds every cache to size 0 (zero-filled) while keeping them allocated. */
  public reset() {
    this.caches.forEach(cache => {
      cache.reset();
    });
  }

  /** Disposes every cache's tensors and forgets them. */
  public dispose() {
    this.caches.forEach(cache => {
      cache.dispose();
    });
    // Drop the disposed caches so `size` reports 0 and a later update() returns undefined
    // instead of operating on freed variables.
    this.caches.clear();
  }

  /**
   * Number of tokens cached so far. All caches are expected to advance in lockstep, so the
   * first one is used; 0 when the container holds no caches.
   */
  public get size() {
    return this.caches.values().next().value?.size ?? 0;
  }

  /** Token capacity shared by every cache in this container. */
  public get maxSequenceLength() {
    return this.max_sequence_length;
  }
}
95
+
96
+
97
/**
 * A fixed-capacity key/value cache for a single attention layer.
 *
 * Keys and values live in two non-trainable `tf.Variable`s of shape
 * `[batchSize, numHeads, maxSequenceLength, headDim]`. Each `update()` writes the new tokens
 * at positions `[size, size + seq)` along axis 2 and advances `size`; positions at or beyond
 * `size` remain zero-filled.
 */
export class KvCache {

  protected key_cache: tf.Variable<tf.Rank.R4>;
  protected value_cache: tf.Variable<tf.Rank.R4>;

  // the size of the KV cache, represents the number of tokens since the first chat token
  protected current_position: number = 0;

  protected batch_size: number;
  protected max_sequence_length: number;
  protected num_kv_heads: number;
  protected head_dim: number;

  constructor({ batchSize, maxSequenceLength, numHeads, headDim, dtype = "float32" }: KvCacheArgs) {
    const cache_shape = [batchSize, numHeads, maxSequenceLength, headDim] as [number, number, number, number];

    // A tf.Variable shares its initial tensor's buffer and takes its own reference to it, so
    // the temporary zeros tensor must be disposed (tidy does that; variables survive tidy),
    // otherwise the buffer would outlive dispose(). `false` = not trainable.
    this.key_cache = tf.tidy(() => tf.variable(tf.zeros(cache_shape, dtype), false));
    this.value_cache = tf.tidy(() => tf.variable(tf.zeros(cache_shape, dtype), false));

    this.batch_size = batchSize;
    this.max_sequence_length = maxSequenceLength;
    this.num_kv_heads = numHeads;
    this.head_dim = headDim;
  }

  /**
   * Writes `key`/`value` at the current position and advances it by their sequence length.
   *
   * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`.
   *
   * @returns The cache's own full-length variables (`[batch, heads, maxSequenceLength, head_dim]`);
   *   callers must not dispose them.
   * @throws Error if the batch size differs from the cache's, if key and value sequence
   *   lengths differ, or if the cache capacity would be exceeded. Nothing is modified on error.
   */
  public update(key: tf.Tensor4D, value: tf.Tensor4D) {
    const batch_size = key.shape[0];
    const seq_len = key.shape[2];

    // mergeIntoCache() slices the stored cache with the configured batch size, so any other
    // batch size (including a smaller one) would only fail inside tf.concat with an opaque error.
    if (batch_size !== this.key_cache.shape[0]) {
      throw Error(`The current KV cache has been set up with a batch size of` +
        ` ${this.key_cache.shape[0]}, but found new key tensors with batch size ${batch_size}`);
    }

    // The write position advances by the key length only; a value tensor of a different length
    // would silently misalign the value cache relative to the key cache.
    if (value.shape[2] !== seq_len) {
      throw Error(`KvCache: key and value tensors must have the same sequence length,` +
        ` got ${seq_len} and ${value.shape[2]}`);
    }

    if (this.current_position + seq_len > this.max_sequence_length) {
      throw Error(`The KV cache has exceeded its maximum sequence length of ${this.max_sequence_length}. Use a larger value.`);
    }

    const new_key_cache = this.mergeIntoCache(key, this.key_cache);
    const new_value_cache = this.mergeIntoCache(value, this.value_cache);

    this.key_cache.assign(new_key_cache);
    this.value_cache.assign(new_value_cache);

    // assign() takes its own reference to the new data, so the merged tensors can be released.
    new_key_cache.dispose();
    new_value_cache.dispose();

    // advance the pointer to reflect the updated cache's current size
    this.current_position += seq_len;

    return {
      keyCache: this.key_cache,
      valueCache: this.value_cache,
    };
  }

  /**
   * Returns a copy of `current_cache` with `new_value` written at the current position along
   * axis 2 (the sequence axis); the total length stays `max_sequence_length`.
   */
  protected mergeIntoCache(new_value: tf.Tensor4D, current_cache: tf.Tensor4D) {
    const seq_len = new_value.shape[2];

    return tf.tidy(() => {

      const historical = current_cache.slice(
        [0, 0, 0, 0],
        [this.batch_size, this.num_kv_heads, this.current_position, this.head_dim]);

      const future = current_cache.slice(
        [0, 0, this.current_position + seq_len, 0],
        [this.batch_size, this.num_kv_heads, this.max_sequence_length - this.current_position - seq_len, this.head_dim]);

      // merge the new tensor into the current cache to create a new, larger, cache,
      // this is different from Python implementations because TFJS tensors are immutable,
      // because we cannot update a slice, we must slice and concat
      return tf.concat([historical, new_value, future], 2);
    });
  }

  /** Rewinds the cache to size 0 and zero-fills both variables (shape and dtype unchanged). */
  public reset(): void {
    this.current_position = 0;

    tf.tidy(() => {
      // zerosLike keeps the cache dtype; plain tf.zeros(shape) is always float32, which
      // Variable.assign rejects for caches created with another dtype.
      this.key_cache.assign(tf.zerosLike(this.key_cache));
      this.value_cache.assign(tf.zerosLike(this.value_cache));
    });
  }

  /** Releases both cache variables; the cache must not be used afterwards. */
  public dispose(): void {
    this.key_cache.dispose();
    this.value_cache.dispose();
  }

  /**
   * The size of the KV cache, also the number of tokens since the first one.
   */
  get size(): number {
    return this.current_position;
  }

}
@@ -0,0 +1,59 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+
3
+ import { KvCacheContainer } from '@/kv_cache';
4
+ import { CachedRoPEMultiHeadAttention } from '@/layers/cached_rope_multihead_attention';
5
+
6
+
7
// Report "not Node" to TFJS so it skips its console hint recommending the faster native
// backend (@tensorflow/tfjs-node); see
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
// NOTE(review): IS_NODE may also be consulted by other TFJS environment checks; confirm
// nothing else in these tests depends on it.
const tfEnvironment = tf.env();
tfEnvironment.set('IS_NODE', false);
10
+
11
+
12
describe("CachedRoPEMultiHeadAttention tests", () => {

  test("aggregate forward passes output are identical normal multihead attention", () => {
    compareNormalWithCachedAttention(tf.randomUniform<tf.Rank.R3>([2, 10, 16]), 123);
    compareNormalWithCachedAttention(tf.randomUniform<tf.Rank.R3>([1, 10, 16]), 123);
    compareNormalWithCachedAttention(tf.randomUniform<tf.Rank.R3>([1, 1, 16]), 123);
    compareNormalWithCachedAttention(tf.randomUniform<tf.Rank.R3>([3, 2, 16]), 123);

    // input exceeds KV cache size
    expect(() => compareNormalWithCachedAttention(tf.randomUniform<tf.Rank.R3>([1, 10, 16]), 5)).toThrow();

    /** Largest absolute element-wise difference between two same-shaped tensors. */
    function maxAbsDifference(a: tf.Tensor, b: tf.Tensor) {
      return a.sub(b).abs().max().dataSync()[0];
    }

    /**
     * Runs `input` through one attention layer in a single full-sequence pass, then feeds it
     * token by token through two identically-configured, identically-weighted layers that
     * share one KvCacheContainer, and checks that the concatenated per-token outputs match
     * the full pass.
     */
    function compareNormalWithCachedAttention(input: tf.Tensor3D, max_sequence_length: number) {
      const embed_dim = input.shape[2];
      const batch = input.shape[0];
      const heads = 2;

      const kv_cache = new KvCacheContainer(max_sequence_length);

      const normal_mha = new CachedRoPEMultiHeadAttention({ numHeads: heads, embedDim: embed_dim, causal: true });
      const normal_mha_output = normal_mha.apply(input) as tf.Tensor;

      // initialize cached attention with identical configuration and weights
      const cached_mha1 = new CachedRoPEMultiHeadAttention({ ...normal_mha.getConfig(), name: "cache_test1" });
      cached_mha1.build(input.shape);
      cached_mha1.setWeights(normal_mha.getWeights());

      const cached_mha2 = new CachedRoPEMultiHeadAttention({ ...normal_mha.getConfig(), name: "cache_test2" });
      cached_mha2.build(input.shape);
      cached_mha2.setWeights(normal_mha.getWeights());

      const cached_mha_outputs1: tf.Tensor[] = [];
      const cached_mha_outputs2: tf.Tensor[] = [];

      for (let i = 0; i < input.shape[1]; i++) {
        const current_token = input.slice([0, i, 0], [batch, 1, embed_dim]);

        cached_mha_outputs1.push(cached_mha1.apply(current_token, { kvCache: kv_cache }) as tf.Tensor);
        cached_mha_outputs2.push(cached_mha2.apply(current_token, { kvCache: kv_cache }) as tf.Tensor);
      }

      // Previously written as `expect(a == b)` with no matcher, which asserts nothing.
      // Both layers keep separate caches in the shared container (presumably keyed by layer
      // name); the container reports the size of its first cache.
      expect(kv_cache.size).toBe(input.shape[1]);

      // Compare by largest absolute difference: a sum of signed differences can cancel out
      // (or be arbitrarily negative) and still pass a "< tolerance" check.
      expect(maxAbsDifference(normal_mha_output, tf.concat(cached_mha_outputs1, 1))).toBeLessThan(1e-5);
      expect(maxAbsDifference(normal_mha_output, tf.concat(cached_mha_outputs2, 1))).toBeLessThan(1e-5);

      kv_cache.dispose();
    }
  });
});