effect-gpt 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +50 -0
- package/data/chat_training_data.json +55 -0
- package/data/pretraining_data.json +27 -0
- package/package.json +25 -0
- package/src/cli/errors.ts +51 -0
- package/src/cli/main.ts +163 -0
- package/src/config.ts +3 -0
- package/src/data/Dataset.ts +168 -0
- package/src/errors.ts +73 -0
- package/src/index.ts +88 -0
- package/src/model/Embeddings.ts +108 -0
- package/src/model/FeedForward.ts +121 -0
- package/src/model/LLM.ts +124 -0
- package/src/model/LayerNorm.ts +138 -0
- package/src/model/ModelLayer.ts +10 -0
- package/src/model/OutputProjection.ts +76 -0
- package/src/model/SelfAttention.ts +169 -0
- package/src/model/TransformerBlock.ts +53 -0
- package/src/services/Logger.ts +124 -0
- package/src/services/Metrics.ts +260 -0
- package/src/services/Random.ts +98 -0
- package/src/services/SeedLayer.ts +39 -0
- package/src/services/index.ts +32 -0
- package/src/tensor/Tensor2D.ts +42 -0
- package/src/tensor/ops.ts +371 -0
- package/src/tensor/random.ts +32 -0
- package/src/tokenize/split.ts +27 -0
- package/src/tokenize/tokenize.ts +28 -0
- package/src/training/Adam.ts +61 -0
- package/src/training/clip.ts +16 -0
- package/src/training/loss.ts +35 -0
- package/src/training/train.ts +203 -0
- package/src/vocab/Vocab.ts +79 -0
- package/tests/fixtures/csv_bad.csv +2 -0
- package/tests/fixtures/csv_good.csv +3 -0
- package/tests/ts/cli_error_format.test.ts +26 -0
- package/tests/ts/dataset.test.ts +35 -0
- package/tests/ts/embeddings.test.ts +81 -0
- package/tests/ts/errors.test.ts +36 -0
- package/tests/ts/feed_forward.test.ts +74 -0
- package/tests/ts/initNormal.test.ts +41 -0
- package/tests/ts/layer_norm.test.ts +96 -0
- package/tests/ts/llm_parameters.test.ts +96 -0
- package/tests/ts/llm_predict.test.ts +98 -0
- package/tests/ts/llm_tokenize.test.ts +69 -0
- package/tests/ts/output_projection.test.ts +78 -0
- package/tests/ts/random.test.ts +44 -0
- package/tests/ts/self_attention.test.ts +63 -0
- package/tests/ts/support/factories.ts +126 -0
- package/tests/ts/support/runEffect.ts +29 -0
- package/tests/ts/support/seed.ts +12 -0
- package/tests/ts/support/stubs.ts +58 -0
- package/tests/ts/support/tensorMatchers.ts +96 -0
- package/tests/ts/support.test.ts +165 -0
- package/tests/ts/train_loop.test.ts +229 -0
- package/tests/ts/transformer_block.test.ts +72 -0
- package/tsconfig.json +20 -0
- package/tsconfig.test.json +8 -0
package/src/index.ts
ADDED
@@ -0,0 +1,88 @@
+export { Vocab } from "./vocab/Vocab"
+export { tokenize } from "./tokenize/tokenize"
+export { Dataset, DatasetLoadError, DatasetParseError } from "./data/Dataset"
+export {
+  TrainingError,
+  TrainingDatasetError,
+  TrainingShapeError,
+  TrainingTokenizerError,
+  TrainingOptimizerError,
+  TrainingConfigError,
+  TrainingUnknownError
+} from "./errors"
+export * from "./config"
+
+export type { Tensor2D } from "./tensor/Tensor2D"
+export * as T2D from "./tensor/Tensor2D"
+export * as TensorOps from "./tensor/ops"
+export { ShapeError } from "./tensor/ops"
+export type { Rng } from "./tensor/random"
+export { seeded } from "./tensor/random"
+export { systemRng } from "./tensor/random"
+
+export type { ModelLayer } from "./model/ModelLayer"
+export { Embeddings } from "./model/Embeddings"
+export { SelfAttention } from "./model/SelfAttention"
+export { FeedForward } from "./model/FeedForward"
+export { LayerNorm } from "./model/LayerNorm"
+export { TransformerBlock } from "./model/TransformerBlock"
+export { OutputProjection } from "./model/OutputProjection"
+export { LLM } from "./model/LLM"
+
+export { Adam } from "./training/Adam"
+export { clipGlobalL2 } from "./training/clip"
+export { softmaxRows, crossEntropyLoss, dLogits } from "./training/loss"
+export {
+  train,
+  trainStream,
+  LLMService,
+  TrainingConfig,
+  makeLLMLayer,
+  makeTrainingConfigLayer
+} from "./training/train"
+
+export type {
+  LogLevel,
+  LoggerService,
+  LoggerServiceId,
+  RandomService,
+  RandomServiceId,
+  SeedService,
+  SeedServiceId,
+  MetricsService,
+  MetricsServiceId,
+  MetricsSnapshot,
+  Counter,
+  Gauge,
+  Histogram,
+  TimingResult
+} from "./services"
+export {
+  Logger,
+  ConsoleLoggerLive,
+  TerminalLoggerLive,
+  NullLoggerLive,
+  log,
+  debug,
+  info,
+  warn,
+  error,
+  Random,
+  SeededRandomLive,
+  SystemRandomLive,
+  next,
+  nextGaussian,
+  nextInt,
+  fork,
+  Seed,
+  SeedLayer,
+  useSeedRng,
+  Metrics,
+  InMemoryMetricsLive,
+  NoOpMetricsLive,
+  counter,
+  gauge,
+  histogram,
+  timed,
+  snapshot
+} from "./services"
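The file above is the package's public entry point. As a quick orientation, here is a minimal usage sketch in TypeScript. It assumes the package root re-exports src/index.ts and that seeded accepts a numeric seed; neither detail is shown in this diff, so treat the exact calls as illustrative rather than definitive.

import * as Effect from "effect/Effect"
import { LLM, seeded } from "effect-gpt"

// Assumption: seeded(42) builds a deterministic Rng (signature not shown in this diff).
const rng = seeded(42)

// LLM.default wires up Embeddings -> TransformerBlock -> OutputProjection over the default vocabulary.
const llm = LLM.default(rng)
console.log(llm.networkDescription(), llm.totalParameters())

// predict returns Effect.Effect<string, ShapeError>, so it has to be run explicitly.
Effect.runPromise(llm.predict("hello")).then(console.log)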
package/src/model/Embeddings.ts
ADDED
@@ -0,0 +1,108 @@
+import * as Effect from "effect/Effect"
+import * as FiberId from "effect/FiberId"
+import type { Tensor2D } from "../tensor/Tensor2D"
+import * as T from "../tensor/Tensor2D"
+import * as Ops from "../tensor/ops"
+import type { ShapeError } from "../tensor/ops"
+import type { ModelLayer } from "./ModelLayer"
+import { MAX_SEQ_LEN, EMBEDDING_DIM } from "../config"
+import { Adam } from "../training/Adam"
+import type { Rng } from "../tensor/random"
+
+export class Embeddings implements ModelLayer {
+  readonly _tag = "Embeddings"
+  tokenEmbeddings: Tensor2D
+  positionalEmbeddings: Tensor2D
+
+  private cache = new Map<number | string, Tensor2D>()
+  private lastCache: Tensor2D | null = null
+  tokenOptimizer: Adam
+  positionalOptimizer: Adam
+
+  constructor(vocabSize: number, embeddingDim: number = EMBEDDING_DIM, maxSeqLen: number = MAX_SEQ_LEN, rng: Rng) {
+    this.tokenEmbeddings = Ops.initNormal(vocabSize, embeddingDim, 0, 0.02, rng)
+    this.positionalEmbeddings = Ops.initNormal(maxSeqLen, embeddingDim, 0, 0.02, rng)
+    this.tokenOptimizer = Adam.make(vocabSize, embeddingDim)
+    this.positionalOptimizer = Adam.make(maxSeqLen, embeddingDim)
+  }
+
+  private fiberKey(fiberId: FiberId.FiberId): number | string {
+    return FiberId.isRuntime(fiberId) ? fiberId.id : JSON.stringify(fiberId)
+  }
+
+  get parametersCount(): number {
+    return this.tokenEmbeddings.data.length + this.positionalEmbeddings.data.length
+  }
+
+  forward(input: Tensor2D): Effect.Effect<Tensor2D, ShapeError> {
+    return Effect.gen(this, function* () {
+      const fiberId = yield* Effect.fiberId
+      const key = this.fiberKey(fiberId)
+      const cloned = T.clone(input)
+      this.cache.set(key, cloned)
+      this.lastCache = cloned
+      const tokenIds: Array<number> = []
+      for (let i = 0; i < input.data.length; i++) {
+        // Match Rust's float-to-usize truncation behavior.
+        tokenIds.push(Math.trunc(input.data[i]))
+      }
+
+      const seqLen = tokenIds.length
+      if (seqLen > this.positionalEmbeddings.rows) {
+        return yield* Effect.fail(
+          new Ops.ShapeError(`Sequence length ${seqLen} exceeds maximum ${this.positionalEmbeddings.rows}`)
+        )
+      }
+
+      const tokenEmbeds = yield* Ops.gatherRows(this.tokenEmbeddings, tokenIds)
+      const posEmbeds = yield* Ops.sliceRows(this.positionalEmbeddings, 0, seqLen)
+      const combined = yield* Ops.add(tokenEmbeds, posEmbeds)
+      return combined
+    })
+  }
+
+  backward(dOut: Tensor2D, lr: number): Effect.Effect<Tensor2D, ShapeError> {
+    return Effect.gen(this, function* () {
+      const fiberId = yield* Effect.fiberId
+      const key = this.fiberKey(fiberId)
+      const cachedInput = this.cache.get(key) ?? this.lastCache
+      if (!cachedInput) {
+        return yield* Effect.fail(new Ops.ShapeError("Embeddings.backward called before forward"))
+      }
+      this.cache.delete(key)
+      this.lastCache = null
+
+      const input = cachedInput
+      const tokenIds: Array<number> = []
+      for (let i = 0; i < input.data.length; i++) {
+        tokenIds.push(Math.trunc(input.data[i]))
+      }
+
+      const tokenGrads = T.zeros(this.tokenEmbeddings.rows, this.tokenEmbeddings.cols)
+      const positionalGrads = T.zeros(this.positionalEmbeddings.rows, this.positionalEmbeddings.cols)
+
+      const seqLen = tokenIds.length
+      for (let i = 0; i < seqLen; i++) {
+        const tokenId = tokenIds[i]
+        if (tokenId < 0 || tokenId >= this.tokenEmbeddings.rows) {
+          return yield* Effect.fail(
+            new Ops.ShapeError(`Token ID ${tokenId} out of bounds for vocab size ${this.tokenEmbeddings.rows}`)
+          )
+        }
+        const rowOffset = i * dOut.cols
+        const tokenOffset = tokenId * tokenGrads.cols
+        const posOffset = i * positionalGrads.cols
+        for (let j = 0; j < dOut.cols; j++) {
+          const grad = dOut.data[rowOffset + j]
+          tokenGrads.data[tokenOffset + j] += grad
+          positionalGrads.data[posOffset + j] += grad
+        }
+      }
+
+      this.tokenOptimizer.step(this.tokenEmbeddings, tokenGrads, lr)
+      this.positionalOptimizer.step(this.positionalEmbeddings, positionalGrads, lr)
+
+      return T.clone(dOut)
+    })
+  }
+}
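Embeddings.forward looks up one row of tokenEmbeddings per input token, adds the first seqLen rows of positionalEmbeddings, and caches the input per fiber so a matching backward call can accumulate gradients into both tables. A minimal sketch of calling it directly, assuming seeded as in the earlier example:

import * as Effect from "effect/Effect"
import { Embeddings, T2D, seeded } from "effect-gpt"

// vocabSize = 32, embeddingDim = 16, maxSeqLen = 64; the rng construction is assumed.
const emb = new Embeddings(32, 16, 64, seeded(7))

// A 1x3 tensor whose entries are token ids, mirroring how LLM.forward builds its input.
const ids = T2D.fromArray(1, 3, [1, 2, 3])

// forward gathers three token rows, adds positional rows 0..2, and yields a 3x16 tensor.
const out = Effect.runSync(emb.forward(ids))
console.log(out.rows, out.cols) // 3 16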
package/src/model/FeedForward.ts
ADDED
@@ -0,0 +1,121 @@
+import * as Effect from "effect/Effect"
+import * as FiberId from "effect/FiberId"
+import type { Tensor2D } from "../tensor/Tensor2D"
+import * as T from "../tensor/Tensor2D"
+import * as Ops from "../tensor/ops"
+import type { ShapeError } from "../tensor/ops"
+import type { ModelLayer } from "./ModelLayer"
+import { EMBEDDING_DIM, HIDDEN_DIM } from "../config"
+import { Adam } from "../training/Adam"
+import type { Rng } from "../tensor/random"
+
+export class FeedForward implements ModelLayer {
+  readonly _tag = "FeedForward"
+  w1: Tensor2D
+  b1: Tensor2D
+  w2: Tensor2D
+  b2: Tensor2D
+
+  private cache = new Map<
+    number | string,
+    { input: Tensor2D; hiddenPreActivation: Tensor2D; hiddenPostActivation: Tensor2D }
+  >()
+  private lastCache:
+    | { input: Tensor2D; hiddenPreActivation: Tensor2D; hiddenPostActivation: Tensor2D }
+    | null = null
+  optimizerW1: Adam
+  optimizerB1: Adam
+  optimizerW2: Adam
+  optimizerB2: Adam
+
+  constructor(embeddingDim: number = EMBEDDING_DIM, hiddenDim: number = HIDDEN_DIM, rng: Rng) {
+    const stdW1 = Math.sqrt(2.0 / embeddingDim)
+    const stdW2 = Math.sqrt(2.0 / hiddenDim)
+
+    this.w1 = Ops.initNormal(embeddingDim, hiddenDim, 0, stdW1, rng)
+    this.b1 = T.zeros(1, hiddenDim)
+    this.w2 = Ops.initNormal(hiddenDim, embeddingDim, 0, stdW2, rng)
+    this.b2 = T.zeros(1, embeddingDim)
+    this.optimizerW1 = Adam.make(embeddingDim, hiddenDim)
+    this.optimizerB1 = Adam.make(1, hiddenDim)
+    this.optimizerW2 = Adam.make(hiddenDim, embeddingDim)
+    this.optimizerB2 = Adam.make(1, embeddingDim)
+  }
+
+  private fiberKey(fiberId: FiberId.FiberId): number | string {
+    return FiberId.isRuntime(fiberId) ? fiberId.id : JSON.stringify(fiberId)
+  }
+
+  get parametersCount(): number {
+    return this.w1.data.length + this.b1.data.length + this.w2.data.length + this.b2.data.length
+  }
+
+  forward(input: Tensor2D): Effect.Effect<Tensor2D, ShapeError> {
+    return Effect.gen(this, function* () {
+      const fiberId = yield* Effect.fiberId
+      const key = this.fiberKey(fiberId)
+
+      const h1 = yield* Ops.matMul(input, this.w1)
+      const h1Bias = yield* Ops.addRowBias(h1, this.b1)
+      const h1BiasClone = T.clone(h1Bias)
+
+      const h1Relu = Ops.relu(h1Bias)
+      const h1ReluClone = T.clone(h1Relu)
+      const cached = {
+        input: T.clone(input),
+        hiddenPreActivation: h1BiasClone,
+        hiddenPostActivation: h1ReluClone
+      }
+      this.cache.set(key, cached)
+      this.lastCache = cached
+
+      const h2 = yield* Ops.matMul(h1Relu, this.w2)
+      const h2Bias = yield* Ops.addRowBias(h2, this.b2)
+      const output = yield* Ops.add(h2Bias, input)
+      return output
+    })
+  }
+
+  backward(dOut: Tensor2D, lr: number): Effect.Effect<Tensor2D, ShapeError> {
+    return Effect.gen(this, function* () {
+      const fiberId = yield* Effect.fiberId
+      const key = this.fiberKey(fiberId)
+      const cached = this.cache.get(key) ?? this.lastCache
+      if (!cached) {
+        return yield* Effect.fail(new Ops.ShapeError("FeedForward.backward called before forward"))
+      }
+      this.cache.delete(key)
+      this.lastCache = null
+
+      const { input, hiddenPreActivation, hiddenPostActivation } = cached
+
+      const hiddenPostT = Ops.transpose(hiddenPostActivation)
+      const gradW2 = yield* Ops.matMul(hiddenPostT, dOut)
+      const gradB2 = Ops.sumCols(dOut)
+
+      const w2T = Ops.transpose(this.w2)
+      const gradHiddenPost = yield* Ops.matMul(dOut, w2T)
+
+      const reluGrad = T.zeros(hiddenPreActivation.rows, hiddenPreActivation.cols)
+      for (let i = 0; i < hiddenPreActivation.data.length; i++) {
+        reluGrad.data[i] = hiddenPreActivation.data[i] > 0 ? 1 : 0
+      }
+      const gradHiddenPre = yield* Ops.mul(gradHiddenPost, reluGrad)
+
+      const inputT = Ops.transpose(input)
+      const gradW1 = yield* Ops.matMul(inputT, gradHiddenPre)
+      const gradB1 = Ops.sumCols(gradHiddenPre)
+
+      const w1T = Ops.transpose(this.w1)
+      const gradInputFF = yield* Ops.matMul(gradHiddenPre, w1T)
+      const gradInput = yield* Ops.add(gradInputFF, dOut)
+
+      this.optimizerW2.step(this.w2, gradW2, lr)
+      this.optimizerB2.step(this.b2, gradB2, lr)
+      this.optimizerW1.step(this.w1, gradW1, lr)
+      this.optimizerB1.step(this.b1, gradB1, lr)
+
+      return gradInput
+    })
+  }
+}
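FeedForward is a position-wise MLP with a residual connection: forward computes relu(x * w1 + b1) * w2 + b2 + x, and backward routes the gradient through both the MLP path and the residual path before stepping all four Adam optimizers. A small shape-check sketch, under the same assumptions as the previous examples:

import * as Effect from "effect/Effect"
import { FeedForward, T2D, seeded } from "effect-gpt"

// embeddingDim = 4, hiddenDim = 8; the residual add requires the input to have 4 columns.
const ff = new FeedForward(4, 8, seeded(1))

const x = T2D.zeros(2, 4)
const y = Effect.runSync(ff.forward(x))
console.log(y.rows, y.cols) // 2 4: output shape matches the input, thanks to the residual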
package/src/model/LLM.ts
ADDED
@@ -0,0 +1,124 @@
+import * as Effect from "effect/Effect"
+import * as Option from "effect/Option"
+import type { Tensor2D } from "../tensor/Tensor2D"
+import * as T from "../tensor/Tensor2D"
+import * as Ops from "../tensor/ops"
+import type { ShapeError } from "../tensor/ops"
+import type { ModelLayer } from "./ModelLayer"
+import { Embeddings } from "./Embeddings"
+import { TransformerBlock } from "./TransformerBlock"
+import { OutputProjection } from "./OutputProjection"
+import { Vocab } from "../vocab/Vocab"
+import { tokenize } from "../tokenize/tokenize"
+import { MAX_SEQ_LEN, EMBEDDING_DIM, HIDDEN_DIM } from "../config"
+import type { Rng } from "../tensor/random"
+
+export class LLM {
+  readonly vocab: Vocab
+  readonly network: ReadonlyArray<ModelLayer>
+
+  constructor(vocab: Vocab, network: ReadonlyArray<ModelLayer>) {
+    this.vocab = vocab
+    this.network = network
+  }
+
+  static default(rng: Rng, numTransformerBlocks = 1): LLM {
+    const vocab = Vocab.make(Vocab.defaultWords())
+    return LLM.make(vocab, rng, numTransformerBlocks)
+  }
+
+  static make(vocab: Vocab, rng: Rng, numTransformerBlocks = 3): LLM {
+    const vocabSize = vocab.words.length
+    const network: Array<ModelLayer> = [
+      new Embeddings(vocabSize, EMBEDDING_DIM, MAX_SEQ_LEN, rng),
+      ...Array.from({ length: numTransformerBlocks }, () => new TransformerBlock(EMBEDDING_DIM, HIDDEN_DIM, rng)),
+      new OutputProjection(EMBEDDING_DIM, vocabSize, rng)
+    ]
+    return new LLM(vocab, network)
+  }
+
+  networkDescription(): string {
+    return this.network.map((layer) => layer._tag).join(", ")
+  }
+
+  totalParameters(): number {
+    return this.network.reduce((sum, layer) => sum + layer.parametersCount, 0)
+  }
+
+  predict(text: string): Effect.Effect<string, ShapeError> {
+    return Effect.gen(this, function* () {
+      const outputTokens = yield* this.forward(text)
+
+      if (outputTokens.length === 0) {
+        return ""
+      }
+
+      const tokenStrs: Array<string> = []
+      for (const t of outputTokens) {
+        const decoded = this.vocab.decode(t)
+        if (Option.isSome(decoded)) {
+          tokenStrs.push(decoded.value)
+        }
+      }
+
+      return tokenStrs.join(" ")
+    })
+  }
+
+  forward(text: string): Effect.Effect<ReadonlyArray<number>, ShapeError> {
+    return Effect.gen(this, function* () {
+      const tokenized: Array<number> = [...tokenize(text, this.vocab)]
+      const outputTokens: Array<number> = []
+
+      if (tokenized.length === 0) {
+        return outputTokens
+      }
+
+      const inputLen = tokenized.length
+      if (inputLen >= MAX_SEQ_LEN) {
+        return outputTokens
+      }
+
+      const endTokenId = this.vocab.encode("</s>")
+      if (Option.isNone(endTokenId)) {
+        return yield* Effect.fail(new Ops.ShapeError("End token </s> not found in vocabulary"))
+      }
+
+      for (let step = 0; step < MAX_SEQ_LEN - inputLen; step++) {
+        if (outputTokens.length >= MAX_SEQ_LEN - 1) {
+          break
+        }
+
+        const tokenInput = T.fromArray(1, tokenized.length, tokenized)
+        let input: Tensor2D = tokenInput
+
+        for (const layer of this.network) {
+          input = yield* layer.forward(input)
+        }
+
+        const logits = input
+
+        if (logits.rows === 0) {
+          break
+        }
+
+        const lastRowStart = (logits.rows - 1) * logits.cols
+        const lastLogitData = logits.data.slice(lastRowStart, lastRowStart + logits.cols)
+        const lastLogit = T.make(1, logits.cols, lastLogitData)
+
+        const probs = Ops.softmaxRows(lastLogit)
+        const tokens = Ops.argmaxRows(probs)
+        const nextToken = tokens[tokens.length - 1]
+
+        outputTokens.push(nextToken)
+        tokenized.push(nextToken)
+
+        if (nextToken === endTokenId.value) {
+          break
+        }
+      }
+
+      return outputTokens
+    })
+  }
+}
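LLM.forward performs greedy decoding: each step runs the whole network on the prompt plus the tokens generated so far, takes the argmax of the last row of the logits, and stops on "</s>" or when MAX_SEQ_LEN is reached. A minimal sketch of driving it, with the usual caveat that seeded is assumed to take a numeric seed:

import * as Effect from "effect/Effect"
import { LLM, Vocab, seeded } from "effect-gpt"

// A fresh, untrained model over the default vocabulary with one transformer block.
const llm = LLM.make(Vocab.make(Vocab.defaultWords()), seeded(3), 1)

// Greedy decoding; an untrained model produces essentially arbitrary tokens.
const reply = Effect.runSync(llm.predict("hello world"))
console.log(reply)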
package/src/model/LayerNorm.ts
ADDED
@@ -0,0 +1,138 @@
+import * as Effect from "effect/Effect"
+import * as FiberId from "effect/FiberId"
+import type { Tensor2D } from "../tensor/Tensor2D"
+import * as T from "../tensor/Tensor2D"
+import * as Ops from "../tensor/ops"
+import type { ShapeError } from "../tensor/ops"
+import type { ModelLayer } from "./ModelLayer"
+import { Adam } from "../training/Adam"
+
+export class LayerNorm implements ModelLayer {
+  readonly _tag = "LayerNorm"
+  readonly epsilon: number = 1e-5
+  gamma: Tensor2D
+  beta: Tensor2D
+
+  private cache = new Map<number | string, { input: Tensor2D; mean: Tensor2D; variance: Tensor2D }>()
+  private lastCache: { input: Tensor2D; mean: Tensor2D; variance: Tensor2D } | null = null
+  optimizerGamma: Adam
+  optimizerBeta: Adam
+
+  constructor(embeddingDim: number) {
+    this.gamma = T.ones(1, embeddingDim)
+    this.beta = T.zeros(1, embeddingDim)
+    this.optimizerGamma = Adam.make(1, embeddingDim)
+    this.optimizerBeta = Adam.make(1, embeddingDim)
+  }
+
+  private fiberKey(fiberId: FiberId.FiberId): number | string {
+    return FiberId.isRuntime(fiberId) ? fiberId.id : JSON.stringify(fiberId)
+  }
+
+  get parametersCount(): number {
+    return this.gamma.data.length + this.beta.data.length
+  }
+
+  forward(input: Tensor2D): Effect.Effect<Tensor2D, ShapeError> {
+    return Effect.gen(this, function* () {
+      const mean = Ops.meanRows(input)
+      const variance = Ops.varRows(input)
+
+      const fiberId = yield* Effect.fiberId
+      const key = this.fiberKey(fiberId)
+      const cached = {
+        input: T.clone(input),
+        mean: T.clone(mean),
+        variance: T.clone(variance)
+      }
+      this.cache.set(key, cached)
+      this.lastCache = cached
+
+      // Use sqrt(variance + epsilon) for numerical stability
+      const rstd = Ops.mapScalar(variance, (v) => 1.0 / Math.sqrt(v + this.epsilon))
+      const centered = yield* Ops.broadcastSubCol(input, mean)
+      const normalized = yield* Ops.broadcastMulCol(centered, rstd)
+      const scaled = yield* Ops.broadcastMulRow(normalized, this.gamma)
+      const shifted = yield* Ops.broadcastAddRow(scaled, this.beta)
+      return shifted
+    })
+  }
+
+  backward(dOut: Tensor2D, lr: number): Effect.Effect<Tensor2D, ShapeError> {
+    return Effect.gen(this, function* () {
+      const fiberId = yield* Effect.fiberId
+      const key = this.fiberKey(fiberId)
+      const cached = this.cache.get(key) ?? this.lastCache
+      if (!cached) {
+        return yield* Effect.fail(new Ops.ShapeError("LayerNorm.backward called before forward"))
+      }
+      this.cache.delete(key)
+      this.lastCache = null
+
+      const { input, mean, variance } = cached
+      const rows = input.rows
+      const cols = input.cols
+      const nFeatures = cols
+
+      const normalized = T.zeros(rows, cols)
+      const gradNormalized = T.zeros(rows, cols)
+      for (let i = 0; i < rows; i++) {
+        const meanVal = mean.data[i]
+        // Consistent with forward: rstd = 1 / sqrt(variance + epsilon)
+        const rstd = 1.0 / Math.sqrt(variance.data[i] + this.epsilon)
+        for (let j = 0; j < cols; j++) {
+          const idx = i * cols + j
+          const norm = (input.data[idx] - meanVal) * rstd
+          normalized.data[idx] = norm
+          gradNormalized.data[idx] = this.gamma.data[j] * dOut.data[idx]
+        }
+      }
+
+      const gradGamma = T.zeros(1, cols)
+      const gradBeta = T.zeros(1, cols)
+      for (let j = 0; j < cols; j++) {
+        let sumGamma = 0
+        let sumBeta = 0
+        for (let i = 0; i < rows; i++) {
+          const idx = i * cols + j
+          sumGamma += normalized.data[idx] * dOut.data[idx]
+          sumBeta += dOut.data[idx]
+        }
+        gradGamma.data[j] = sumGamma
+        gradBeta.data[j] = sumBeta
+      }
+
+      const gradInput = T.zeros(rows, cols)
+      for (let i = 0; i < rows; i++) {
+        const meanVal = mean.data[i]
+        const varPlusEps = variance.data[i] + this.epsilon
+        const rstd = 1.0 / Math.sqrt(varPlusEps)
+
+        let sumGradNormalized = 0
+        let sumGradNormTimesNorm = 0
+        const rowOffset = i * cols
+
+        for (let j = 0; j < cols; j++) {
+          const idx = rowOffset + j
+          sumGradNormalized += gradNormalized.data[idx]
+          sumGradNormTimesNorm += gradNormalized.data[idx] * normalized.data[idx]
+        }
+
+        // Gradient of LayerNorm: dL/dx = rstd * (dL/dnorm - mean(dL/dnorm) - norm * mean(dL/dnorm * norm))
+        for (let j = 0; j < cols; j++) {
+          const idx = rowOffset + j
+          gradInput.data[idx] =
+            rstd *
+            (gradNormalized.data[idx] -
+              sumGradNormalized / nFeatures -
+              (normalized.data[idx] * sumGradNormTimesNorm) / nFeatures)
+        }
+      }
+
+      this.optimizerGamma.step(this.gamma, gradGamma, lr)
+      this.optimizerBeta.step(this.beta, gradBeta, lr)
+
+      return gradInput
+    })
+  }
+}
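LayerNorm normalizes each row to zero mean and unit variance (using 1/sqrt(variance + epsilon)), then applies the learned per-feature scale gamma and shift beta; backward implements dL/dx = rstd * (dL/dnorm - mean_j(dL/dnorm) - norm * mean_j(dL/dnorm * norm)), as the inline comment states. A tiny numeric sketch of the forward pass with the initial gamma = 1, beta = 0:

import * as Effect from "effect/Effect"
import { LayerNorm, T2D } from "effect-gpt"

const ln = new LayerNorm(4)
const x = T2D.fromArray(1, 4, [1, 2, 3, 4])

// With the initial parameters the output row has approximately zero mean and unit variance.
const y = Effect.runSync(ln.forward(x))
const mean = y.data.reduce((a, b) => a + b, 0) / y.data.length
console.log(mean) // ~0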
package/src/model/ModelLayer.ts
ADDED
@@ -0,0 +1,10 @@
+import type * as Effect from "effect/Effect"
+import type { Tensor2D } from "../tensor/Tensor2D"
+import type { ShapeError } from "../tensor/ops"
+
+export interface ModelLayer {
+  readonly _tag: string
+  readonly parametersCount: number
+  forward(input: Tensor2D): Effect.Effect<Tensor2D, ShapeError>
+  backward(dOut: Tensor2D, lr: number): Effect.Effect<Tensor2D, ShapeError>
+}
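Every layer in this diff implements the ModelLayer contract, which is what lets LLM.forward fold an input tensor through this.network generically. A hypothetical identity layer (not part of the package) shows the minimum required surface:

import * as Effect from "effect/Effect"
import type { ModelLayer, Tensor2D } from "effect-gpt"
import { T2D } from "effect-gpt"

// Hypothetical example: a parameter-free layer that passes tensors through unchanged.
class Identity implements ModelLayer {
  readonly _tag = "Identity"
  get parametersCount(): number {
    return 0
  }
  forward(input: Tensor2D) {
    return Effect.succeed(T2D.clone(input))
  }
  backward(dOut: Tensor2D, _lr: number) {
    return Effect.succeed(T2D.clone(dOut))
  }
}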
package/src/model/OutputProjection.ts
ADDED
@@ -0,0 +1,76 @@
+import * as Effect from "effect/Effect"
+import * as FiberId from "effect/FiberId"
+import type { Tensor2D } from "../tensor/Tensor2D"
+import * as T from "../tensor/Tensor2D"
+import * as Ops from "../tensor/ops"
+import type { ShapeError } from "../tensor/ops"
+import type { ModelLayer } from "./ModelLayer"
+import { EMBEDDING_DIM } from "../config"
+import { Adam } from "../training/Adam"
+import type { Rng } from "../tensor/random"
+
+export class OutputProjection implements ModelLayer {
+  readonly _tag = "OutputProjection"
+  wOut: Tensor2D
+  bOut: Tensor2D
+
+  private cache = new Map<number | string, Tensor2D>()
+  private lastCache: Tensor2D | null = null
+  optimizerWOut: Adam
+
+  constructor(embeddingDim: number = EMBEDDING_DIM, vocabSize: number, rng: Rng) {
+    const std = Math.sqrt(2.0 / embeddingDim)
+    this.wOut = Ops.initNormal(embeddingDim, vocabSize, 0, std, rng)
+    this.bOut = T.zeros(1, vocabSize)
+    this.optimizerWOut = Adam.make(embeddingDim, vocabSize)
+  }
+
+  private fiberKey(fiberId: FiberId.FiberId): number | string {
+    return FiberId.isRuntime(fiberId) ? fiberId.id : JSON.stringify(fiberId)
+  }
+
+  get parametersCount(): number {
+    return this.wOut.data.length + this.bOut.data.length
+  }
+
+  forward(input: Tensor2D): Effect.Effect<Tensor2D, ShapeError> {
+    return Effect.gen(this, function* () {
+      const fiberId = yield* Effect.fiberId
+      const key = this.fiberKey(fiberId)
+      const cloned = T.clone(input)
+      this.cache.set(key, cloned)
+      this.lastCache = cloned
+      const projected = yield* Ops.matMul(input, this.wOut)
+      const output = yield* Ops.addRowBias(projected, this.bOut)
+      return output
+    })
+  }
+
+  backward(dOut: Tensor2D, lr: number): Effect.Effect<Tensor2D, ShapeError> {
+    return Effect.gen(this, function* () {
+      const fiberId = yield* Effect.fiberId
+      const key = this.fiberKey(fiberId)
+      const cachedInput = this.cache.get(key) ?? this.lastCache
+      if (!cachedInput) {
+        return yield* Effect.fail(new Ops.ShapeError("OutputProjection.backward called before forward"))
+      }
+      this.cache.delete(key)
+      this.lastCache = null
+
+      const input = cachedInput
+      const inputT = Ops.transpose(input)
+      const gradWOut = yield* Ops.matMul(inputT, dOut)
+      const gradBOut = Ops.sumCols(dOut)
+
+      const wOutT = Ops.transpose(this.wOut)
+      const gradInput = yield* Ops.matMul(dOut, wOutT)
+
+      this.optimizerWOut.step(this.wOut, gradWOut, lr)
+      for (let j = 0; j < this.bOut.data.length; j++) {
+        this.bOut.data[j] -= lr * gradBOut.data[j]
+      }
+
+      return gradInput
+    })
+  }
+}