effect-gpt 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. package/README.md +50 -0
  2. package/data/chat_training_data.json +55 -0
  3. package/data/pretraining_data.json +27 -0
  4. package/package.json +25 -0
  5. package/src/cli/errors.ts +51 -0
  6. package/src/cli/main.ts +163 -0
  7. package/src/config.ts +3 -0
  8. package/src/data/Dataset.ts +168 -0
  9. package/src/errors.ts +73 -0
  10. package/src/index.ts +88 -0
  11. package/src/model/Embeddings.ts +108 -0
  12. package/src/model/FeedForward.ts +121 -0
  13. package/src/model/LLM.ts +124 -0
  14. package/src/model/LayerNorm.ts +138 -0
  15. package/src/model/ModelLayer.ts +10 -0
  16. package/src/model/OutputProjection.ts +76 -0
  17. package/src/model/SelfAttention.ts +169 -0
  18. package/src/model/TransformerBlock.ts +53 -0
  19. package/src/services/Logger.ts +124 -0
  20. package/src/services/Metrics.ts +260 -0
  21. package/src/services/Random.ts +98 -0
  22. package/src/services/SeedLayer.ts +39 -0
  23. package/src/services/index.ts +32 -0
  24. package/src/tensor/Tensor2D.ts +42 -0
  25. package/src/tensor/ops.ts +371 -0
  26. package/src/tensor/random.ts +32 -0
  27. package/src/tokenize/split.ts +27 -0
  28. package/src/tokenize/tokenize.ts +28 -0
  29. package/src/training/Adam.ts +61 -0
  30. package/src/training/clip.ts +16 -0
  31. package/src/training/loss.ts +35 -0
  32. package/src/training/train.ts +203 -0
  33. package/src/vocab/Vocab.ts +79 -0
  34. package/tests/fixtures/csv_bad.csv +2 -0
  35. package/tests/fixtures/csv_good.csv +3 -0
  36. package/tests/ts/cli_error_format.test.ts +26 -0
  37. package/tests/ts/dataset.test.ts +35 -0
  38. package/tests/ts/embeddings.test.ts +81 -0
  39. package/tests/ts/errors.test.ts +36 -0
  40. package/tests/ts/feed_forward.test.ts +74 -0
  41. package/tests/ts/initNormal.test.ts +41 -0
  42. package/tests/ts/layer_norm.test.ts +96 -0
  43. package/tests/ts/llm_parameters.test.ts +96 -0
  44. package/tests/ts/llm_predict.test.ts +98 -0
  45. package/tests/ts/llm_tokenize.test.ts +69 -0
  46. package/tests/ts/output_projection.test.ts +78 -0
  47. package/tests/ts/random.test.ts +44 -0
  48. package/tests/ts/self_attention.test.ts +63 -0
  49. package/tests/ts/support/factories.ts +126 -0
  50. package/tests/ts/support/runEffect.ts +29 -0
  51. package/tests/ts/support/seed.ts +12 -0
  52. package/tests/ts/support/stubs.ts +58 -0
  53. package/tests/ts/support/tensorMatchers.ts +96 -0
  54. package/tests/ts/support.test.ts +165 -0
  55. package/tests/ts/train_loop.test.ts +229 -0
  56. package/tests/ts/transformer_block.test.ts +72 -0
  57. package/tsconfig.json +20 -0
  58. package/tsconfig.test.json +8 -0
package/src/index.ts ADDED
@@ -0,0 +1,88 @@
+ export { Vocab } from "./vocab/Vocab"
+ export { tokenize } from "./tokenize/tokenize"
+ export { Dataset, DatasetLoadError, DatasetParseError } from "./data/Dataset"
+ export {
+   TrainingError,
+   TrainingDatasetError,
+   TrainingShapeError,
+   TrainingTokenizerError,
+   TrainingOptimizerError,
+   TrainingConfigError,
+   TrainingUnknownError
+ } from "./errors"
+ export * from "./config"
+
+ export type { Tensor2D } from "./tensor/Tensor2D"
+ export * as T2D from "./tensor/Tensor2D"
+ export * as TensorOps from "./tensor/ops"
+ export { ShapeError } from "./tensor/ops"
+ export type { Rng } from "./tensor/random"
+ export { seeded } from "./tensor/random"
+ export { systemRng } from "./tensor/random"
+
+ export type { ModelLayer } from "./model/ModelLayer"
+ export { Embeddings } from "./model/Embeddings"
+ export { SelfAttention } from "./model/SelfAttention"
+ export { FeedForward } from "./model/FeedForward"
+ export { LayerNorm } from "./model/LayerNorm"
+ export { TransformerBlock } from "./model/TransformerBlock"
+ export { OutputProjection } from "./model/OutputProjection"
+ export { LLM } from "./model/LLM"
+
+ export { Adam } from "./training/Adam"
+ export { clipGlobalL2 } from "./training/clip"
+ export { softmaxRows, crossEntropyLoss, dLogits } from "./training/loss"
+ export {
+   train,
+   trainStream,
+   LLMService,
+   TrainingConfig,
+   makeLLMLayer,
+   makeTrainingConfigLayer
+ } from "./training/train"
+
+ export type {
+   LogLevel,
+   LoggerService,
+   LoggerServiceId,
+   RandomService,
+   RandomServiceId,
+   SeedService,
+   SeedServiceId,
+   MetricsService,
+   MetricsServiceId,
+   MetricsSnapshot,
+   Counter,
+   Gauge,
+   Histogram,
+   TimingResult
+ } from "./services"
+ export {
+   Logger,
+   ConsoleLoggerLive,
+   TerminalLoggerLive,
+   NullLoggerLive,
+   log,
+   debug,
+   info,
+   warn,
+   error,
+   Random,
+   SeededRandomLive,
+   SystemRandomLive,
+   next,
+   nextGaussian,
+   nextInt,
+   fork,
+   Seed,
+   SeedLayer,
+   useSeedRng,
+   Metrics,
+   InMemoryMetricsLive,
+   NoOpMetricsLive,
+   counter,
+   gauge,
+   histogram,
+   timed,
+   snapshot
+ } from "./services"
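The index above is the package's entire public surface. A minimal consumption sketch, assuming the published package name `effect-gpt` (per the page header) and assuming `seeded` takes a numeric seed; neither the Rng factory's signature nor the training modules are shown in this diff:

import * as Effect from "effect/Effect"
import { LLM, seeded } from "effect-gpt"

// Build the default single-block model with a deterministic RNG.
// seeded(42) is an assumption about the Rng factory's signature.
const llm = LLM.default(seeded(42))

// predict returns Effect.Effect<string, ShapeError>; run it to get text.
Effect.runPromise(llm.predict("hello")).then(console.log)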
package/src/model/Embeddings.ts ADDED
@@ -0,0 +1,108 @@
+ import * as Effect from "effect/Effect"
+ import * as FiberId from "effect/FiberId"
+ import type { Tensor2D } from "../tensor/Tensor2D"
+ import * as T from "../tensor/Tensor2D"
+ import * as Ops from "../tensor/ops"
+ import type { ShapeError } from "../tensor/ops"
+ import type { ModelLayer } from "./ModelLayer"
+ import { MAX_SEQ_LEN, EMBEDDING_DIM } from "../config"
+ import { Adam } from "../training/Adam"
+ import type { Rng } from "../tensor/random"
+
+ export class Embeddings implements ModelLayer {
+   readonly _tag = "Embeddings"
+   tokenEmbeddings: Tensor2D
+   positionalEmbeddings: Tensor2D
+
+   private cache = new Map<number | string, Tensor2D>()
+   private lastCache: Tensor2D | null = null
+   tokenOptimizer: Adam
+   positionalOptimizer: Adam
+
+   constructor(vocabSize: number, embeddingDim: number = EMBEDDING_DIM, maxSeqLen: number = MAX_SEQ_LEN, rng: Rng) {
+     this.tokenEmbeddings = Ops.initNormal(vocabSize, embeddingDim, 0, 0.02, rng)
+     this.positionalEmbeddings = Ops.initNormal(maxSeqLen, embeddingDim, 0, 0.02, rng)
+     this.tokenOptimizer = Adam.make(vocabSize, embeddingDim)
+     this.positionalOptimizer = Adam.make(maxSeqLen, embeddingDim)
+   }
+
+   private fiberKey(fiberId: FiberId.FiberId): number | string {
+     return FiberId.isRuntime(fiberId) ? fiberId.id : JSON.stringify(fiberId)
+   }
+
+   get parametersCount(): number {
+     return this.tokenEmbeddings.data.length + this.positionalEmbeddings.data.length
+   }
+
+   forward(input: Tensor2D): Effect.Effect<Tensor2D, ShapeError> {
+     return Effect.gen(this, function* () {
+       const fiberId = yield* Effect.fiberId
+       const key = this.fiberKey(fiberId)
+       const cloned = T.clone(input)
+       this.cache.set(key, cloned)
+       this.lastCache = cloned
+       const tokenIds: Array<number> = []
+       for (let i = 0; i < input.data.length; i++) {
+         // Match Rust's float-to-usize truncation behavior.
+         tokenIds.push(Math.trunc(input.data[i]))
+       }
+
+       const seqLen = tokenIds.length
+       if (seqLen > this.positionalEmbeddings.rows) {
+         return yield* Effect.fail(
+           new Ops.ShapeError(`Sequence length ${seqLen} exceeds maximum ${this.positionalEmbeddings.rows}`)
+         )
+       }
+
+       const tokenEmbeds = yield* Ops.gatherRows(this.tokenEmbeddings, tokenIds)
+       const posEmbeds = yield* Ops.sliceRows(this.positionalEmbeddings, 0, seqLen)
+       const combined = yield* Ops.add(tokenEmbeds, posEmbeds)
+       return combined
+     })
+   }
+
+   backward(dOut: Tensor2D, lr: number): Effect.Effect<Tensor2D, ShapeError> {
+     return Effect.gen(this, function* () {
+       const fiberId = yield* Effect.fiberId
+       const key = this.fiberKey(fiberId)
+       const cachedInput = this.cache.get(key) ?? this.lastCache
+       if (!cachedInput) {
+         return yield* Effect.fail(new Ops.ShapeError("Embeddings.backward called before forward"))
+       }
+       this.cache.delete(key)
+       this.lastCache = null
+
+       const input = cachedInput
+       const tokenIds: Array<number> = []
+       for (let i = 0; i < input.data.length; i++) {
+         tokenIds.push(Math.trunc(input.data[i]))
+       }
+
+       const tokenGrads = T.zeros(this.tokenEmbeddings.rows, this.tokenEmbeddings.cols)
+       const positionalGrads = T.zeros(this.positionalEmbeddings.rows, this.positionalEmbeddings.cols)
+
+       const seqLen = tokenIds.length
+       for (let i = 0; i < seqLen; i++) {
+         const tokenId = tokenIds[i]
+         if (tokenId < 0 || tokenId >= this.tokenEmbeddings.rows) {
+           return yield* Effect.fail(
+             new Ops.ShapeError(`Token ID ${tokenId} out of bounds for vocab size ${this.tokenEmbeddings.rows}`)
+           )
+         }
+         const rowOffset = i * dOut.cols
+         const tokenOffset = tokenId * tokenGrads.cols
+         const posOffset = i * positionalGrads.cols
+         for (let j = 0; j < dOut.cols; j++) {
+           const grad = dOut.data[rowOffset + j]
+           tokenGrads.data[tokenOffset + j] += grad
+           positionalGrads.data[posOffset + j] += grad
+         }
+       }
+
+       this.tokenOptimizer.step(this.tokenEmbeddings, tokenGrads, lr)
+       this.positionalOptimizer.step(this.positionalEmbeddings, positionalGrads, lr)
+
+       return T.clone(dOut)
+     })
+   }
+ }
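Stripped of the Effect and caching machinery, Embeddings.forward is a row gather plus a positional add: row i of the output is tokenTable[tokenIds[i]] + posTable[i]. A plain-array sketch of just that arithmetic (illustrative only; the real code routes through Ops.gatherRows, Ops.sliceRows, and Ops.add, and the toy tables below are hypothetical):

// Toy version of the embedding lookup over plain arrays.
const embed = (
  tokenTable: ReadonlyArray<ReadonlyArray<number>>, // vocabSize x dim
  posTable: ReadonlyArray<ReadonlyArray<number>>,   // maxSeqLen x dim
  tokenIds: ReadonlyArray<number>
): number[][] =>
  tokenIds.map((id, pos) => tokenTable[id].map((v, j) => v + posTable[pos][j]))

// embed([[0, 1], [2, 3]], [[10, 10], [20, 20]], [1, 0])
// => [[12, 13], [20, 21]]  (token row + position row, per position)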
package/src/model/FeedForward.ts ADDED
@@ -0,0 +1,121 @@
+ import * as Effect from "effect/Effect"
+ import * as FiberId from "effect/FiberId"
+ import type { Tensor2D } from "../tensor/Tensor2D"
+ import * as T from "../tensor/Tensor2D"
+ import * as Ops from "../tensor/ops"
+ import type { ShapeError } from "../tensor/ops"
+ import type { ModelLayer } from "./ModelLayer"
+ import { EMBEDDING_DIM, HIDDEN_DIM } from "../config"
+ import { Adam } from "../training/Adam"
+ import type { Rng } from "../tensor/random"
+
+ export class FeedForward implements ModelLayer {
+   readonly _tag = "FeedForward"
+   w1: Tensor2D
+   b1: Tensor2D
+   w2: Tensor2D
+   b2: Tensor2D
+
+   private cache = new Map<
+     number | string,
+     { input: Tensor2D; hiddenPreActivation: Tensor2D; hiddenPostActivation: Tensor2D }
+   >()
+   private lastCache:
+     | { input: Tensor2D; hiddenPreActivation: Tensor2D; hiddenPostActivation: Tensor2D }
+     | null = null
+   optimizerW1: Adam
+   optimizerB1: Adam
+   optimizerW2: Adam
+   optimizerB2: Adam
+
+   constructor(embeddingDim: number = EMBEDDING_DIM, hiddenDim: number = HIDDEN_DIM, rng: Rng) {
+     const stdW1 = Math.sqrt(2.0 / embeddingDim)
+     const stdW2 = Math.sqrt(2.0 / hiddenDim)
+
+     this.w1 = Ops.initNormal(embeddingDim, hiddenDim, 0, stdW1, rng)
+     this.b1 = T.zeros(1, hiddenDim)
+     this.w2 = Ops.initNormal(hiddenDim, embeddingDim, 0, stdW2, rng)
+     this.b2 = T.zeros(1, embeddingDim)
+     this.optimizerW1 = Adam.make(embeddingDim, hiddenDim)
+     this.optimizerB1 = Adam.make(1, hiddenDim)
+     this.optimizerW2 = Adam.make(hiddenDim, embeddingDim)
+     this.optimizerB2 = Adam.make(1, embeddingDim)
+   }
+
+   private fiberKey(fiberId: FiberId.FiberId): number | string {
+     return FiberId.isRuntime(fiberId) ? fiberId.id : JSON.stringify(fiberId)
+   }
+
+   get parametersCount(): number {
+     return this.w1.data.length + this.b1.data.length + this.w2.data.length + this.b2.data.length
+   }
+
+   forward(input: Tensor2D): Effect.Effect<Tensor2D, ShapeError> {
+     return Effect.gen(this, function* () {
+       const fiberId = yield* Effect.fiberId
+       const key = this.fiberKey(fiberId)
+
+       const h1 = yield* Ops.matMul(input, this.w1)
+       const h1Bias = yield* Ops.addRowBias(h1, this.b1)
+       const h1BiasClone = T.clone(h1Bias)
+
+       const h1Relu = Ops.relu(h1Bias)
+       const h1ReluClone = T.clone(h1Relu)
+       const cached = {
+         input: T.clone(input),
+         hiddenPreActivation: h1BiasClone,
+         hiddenPostActivation: h1ReluClone
+       }
+       this.cache.set(key, cached)
+       this.lastCache = cached
+
+       const h2 = yield* Ops.matMul(h1Relu, this.w2)
+       const h2Bias = yield* Ops.addRowBias(h2, this.b2)
+       const output = yield* Ops.add(h2Bias, input)
+       return output
+     })
+   }
+
+   backward(dOut: Tensor2D, lr: number): Effect.Effect<Tensor2D, ShapeError> {
+     return Effect.gen(this, function* () {
+       const fiberId = yield* Effect.fiberId
+       const key = this.fiberKey(fiberId)
+       const cached = this.cache.get(key) ?? this.lastCache
+       if (!cached) {
+         return yield* Effect.fail(new Ops.ShapeError("FeedForward.backward called before forward"))
+       }
+       this.cache.delete(key)
+       this.lastCache = null
+
+       const { input, hiddenPreActivation, hiddenPostActivation } = cached
+
+       const hiddenPostT = Ops.transpose(hiddenPostActivation)
+       const gradW2 = yield* Ops.matMul(hiddenPostT, dOut)
+       const gradB2 = Ops.sumCols(dOut)
+
+       const w2T = Ops.transpose(this.w2)
+       const gradHiddenPost = yield* Ops.matMul(dOut, w2T)
+
+       const reluGrad = T.zeros(hiddenPreActivation.rows, hiddenPreActivation.cols)
+       for (let i = 0; i < hiddenPreActivation.data.length; i++) {
+         reluGrad.data[i] = hiddenPreActivation.data[i] > 0 ? 1 : 0
+       }
+       const gradHiddenPre = yield* Ops.mul(gradHiddenPost, reluGrad)
+
+       const inputT = Ops.transpose(input)
+       const gradW1 = yield* Ops.matMul(inputT, gradHiddenPre)
+       const gradB1 = Ops.sumCols(gradHiddenPre)
+
+       const w1T = Ops.transpose(this.w1)
+       const gradInputFF = yield* Ops.matMul(gradHiddenPre, w1T)
+       const gradInput = yield* Ops.add(gradInputFF, dOut)
+
+       this.optimizerW2.step(this.w2, gradW2, lr)
+       this.optimizerB2.step(this.b2, gradB2, lr)
+       this.optimizerW1.step(this.w1, gradW1, lr)
+       this.optimizerB1.step(this.b1, gradB1, lr)
+
+       return gradInput
+     })
+   }
+ }
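Note the residual connection: forward returns ReLU(xW1 + b1)W2 + b2 + x, which is why backward adds dOut into gradInput at the end. A scalar finite-difference check of that "+ dOut" term (all constants below are hypothetical):

const relu = (v: number) => Math.max(0, v)
// Scalar analogue of the layer: y = relu(x*w1 + b1)*w2 + b2 + x
const f = (x: number) => relu(x * 0.5 + 0.1) * -0.3 + 0 + x
// Analytic gradient: (w1*w2 when the pre-activation is positive) + 1,
// the +1 being the residual term that backward realizes as "+ dOut".
const g = (x: number) => (x * 0.5 + 0.1 > 0 ? 0.5 * -0.3 : 0) + 1

const eps = 1e-6
const numeric = (f(1 + eps) - f(1 - eps)) / (2 * eps)
console.log(Math.abs(numeric - g(1)) < 1e-4) // true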
package/src/model/LLM.ts ADDED
@@ -0,0 +1,124 @@
+ import * as Effect from "effect/Effect"
+ import * as Option from "effect/Option"
+ import type { Tensor2D } from "../tensor/Tensor2D"
+ import * as T from "../tensor/Tensor2D"
+ import * as Ops from "../tensor/ops"
+ import type { ShapeError } from "../tensor/ops"
+ import type { ModelLayer } from "./ModelLayer"
+ import { Embeddings } from "./Embeddings"
+ import { TransformerBlock } from "./TransformerBlock"
+ import { OutputProjection } from "./OutputProjection"
+ import { Vocab } from "../vocab/Vocab"
+ import { tokenize } from "../tokenize/tokenize"
+ import { MAX_SEQ_LEN, EMBEDDING_DIM, HIDDEN_DIM } from "../config"
+ import type { Rng } from "../tensor/random"
+
+ export class LLM {
+   readonly vocab: Vocab
+   readonly network: ReadonlyArray<ModelLayer>
+
+   constructor(vocab: Vocab, network: ReadonlyArray<ModelLayer>) {
+     this.vocab = vocab
+     this.network = network
+   }
+
+   static default(rng: Rng, numTransformerBlocks = 1): LLM {
+     const vocab = Vocab.make(Vocab.defaultWords())
+     return LLM.make(vocab, rng, numTransformerBlocks)
+   }
+
+   static make(vocab: Vocab, rng: Rng, numTransformerBlocks = 3): LLM {
+     const vocabSize = vocab.words.length
+     const network: Array<ModelLayer> = [
+       new Embeddings(vocabSize, EMBEDDING_DIM, MAX_SEQ_LEN, rng),
+       ...Array.from({ length: numTransformerBlocks }, () => new TransformerBlock(EMBEDDING_DIM, HIDDEN_DIM, rng)),
+       new OutputProjection(EMBEDDING_DIM, vocabSize, rng)
+     ]
+     return new LLM(vocab, network)
+   }
+
+   networkDescription(): string {
+     return this.network.map((layer) => layer._tag).join(", ")
+   }
+
+   totalParameters(): number {
+     return this.network.reduce((sum, layer) => sum + layer.parametersCount, 0)
+   }
+
+   predict(text: string): Effect.Effect<string, ShapeError> {
+     return Effect.gen(this, function* () {
+       const outputTokens = yield* this.forward(text)
+
+       if (outputTokens.length === 0) {
+         return ""
+       }
+
+       const tokenStrs: Array<string> = []
+       for (const t of outputTokens) {
+         const decoded = this.vocab.decode(t)
+         if (Option.isSome(decoded)) {
+           tokenStrs.push(decoded.value)
+         }
+       }
+
+       return tokenStrs.join(" ")
+     })
+   }
+
+   forward(text: string): Effect.Effect<ReadonlyArray<number>, ShapeError> {
+     return Effect.gen(this, function* () {
+       const tokenized: Array<number> = [...tokenize(text, this.vocab)]
+       const outputTokens: Array<number> = []
+
+       if (tokenized.length === 0) {
+         return outputTokens
+       }
+
+       const inputLen = tokenized.length
+       if (inputLen >= MAX_SEQ_LEN) {
+         return outputTokens
+       }
+
+       const endTokenId = this.vocab.encode("</s>")
+       if (Option.isNone(endTokenId)) {
+         return yield* Effect.fail(new Ops.ShapeError("End token </s> not found in vocabulary"))
+       }
+
+       for (let step = 0; step < MAX_SEQ_LEN - inputLen; step++) {
+         if (outputTokens.length >= MAX_SEQ_LEN - 1) {
+           break
+         }
+
+         const tokenInput = T.fromArray(1, tokenized.length, tokenized)
+         let input: Tensor2D = tokenInput
+
+         for (const layer of this.network) {
+           input = yield* layer.forward(input)
+         }
+
+         const logits = input
+
+         if (logits.rows === 0) {
+           break
+         }
+
+         const lastRowStart = (logits.rows - 1) * logits.cols
+         const lastLogitData = logits.data.slice(lastRowStart, lastRowStart + logits.cols)
+         const lastLogit = T.make(1, logits.cols, lastLogitData)
+
+         const probs = Ops.softmaxRows(lastLogit)
+         const tokens = Ops.argmaxRows(probs)
+         const nextToken = tokens[tokens.length - 1]
+
+         outputTokens.push(nextToken)
+         tokenized.push(nextToken)
+
+         if (nextToken === endTokenId.value) {
+           break
+         }
+       }
+
+       return outputTokens
+     })
+   }
+ }
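Decoding is greedy: each step runs the whole network, softmaxes the last row of logits, argmaxes it, appends the token to the context, and stops at </s> or MAX_SEQ_LEN. The split between forward (raw token IDs) and predict (decoded text) looks like this in use; a hedged sketch, again assuming seeded(...) as the Rng factory and that these effects run synchronously:

import * as Effect from "effect/Effect"
import { LLM, seeded } from "effect-gpt"

const llm = LLM.default(seeded(7))
// forward: ReadonlyArray<number> of generated token IDs (no decoding)
const ids = Effect.runSync(llm.forward("the sky is"))
// predict: the same loop, with IDs decoded through the vocab and space-joined
const text = Effect.runSync(llm.predict("the sky is"))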
package/src/model/LayerNorm.ts ADDED
@@ -0,0 +1,138 @@
+ import * as Effect from "effect/Effect"
+ import * as FiberId from "effect/FiberId"
+ import type { Tensor2D } from "../tensor/Tensor2D"
+ import * as T from "../tensor/Tensor2D"
+ import * as Ops from "../tensor/ops"
+ import type { ShapeError } from "../tensor/ops"
+ import type { ModelLayer } from "./ModelLayer"
+ import { Adam } from "../training/Adam"
+
+ export class LayerNorm implements ModelLayer {
+   readonly _tag = "LayerNorm"
+   readonly epsilon: number = 1e-5
+   gamma: Tensor2D
+   beta: Tensor2D
+
+   private cache = new Map<number | string, { input: Tensor2D; mean: Tensor2D; variance: Tensor2D }>()
+   private lastCache: { input: Tensor2D; mean: Tensor2D; variance: Tensor2D } | null = null
+   optimizerGamma: Adam
+   optimizerBeta: Adam
+
+   constructor(embeddingDim: number) {
+     this.gamma = T.ones(1, embeddingDim)
+     this.beta = T.zeros(1, embeddingDim)
+     this.optimizerGamma = Adam.make(1, embeddingDim)
+     this.optimizerBeta = Adam.make(1, embeddingDim)
+   }
+
+   private fiberKey(fiberId: FiberId.FiberId): number | string {
+     return FiberId.isRuntime(fiberId) ? fiberId.id : JSON.stringify(fiberId)
+   }
+
+   get parametersCount(): number {
+     return this.gamma.data.length + this.beta.data.length
+   }
+
+   forward(input: Tensor2D): Effect.Effect<Tensor2D, ShapeError> {
+     return Effect.gen(this, function* () {
+       const mean = Ops.meanRows(input)
+       const variance = Ops.varRows(input)
+
+       const fiberId = yield* Effect.fiberId
+       const key = this.fiberKey(fiberId)
+       const cached = {
+         input: T.clone(input),
+         mean: T.clone(mean),
+         variance: T.clone(variance)
+       }
+       this.cache.set(key, cached)
+       this.lastCache = cached
+
+       // Use sqrt(variance + epsilon) for numerical stability
+       const rstd = Ops.mapScalar(variance, (v) => 1.0 / Math.sqrt(v + this.epsilon))
+       const centered = yield* Ops.broadcastSubCol(input, mean)
+       const normalized = yield* Ops.broadcastMulCol(centered, rstd)
+       const scaled = yield* Ops.broadcastMulRow(normalized, this.gamma)
+       const shifted = yield* Ops.broadcastAddRow(scaled, this.beta)
+       return shifted
+     })
+   }
+
+   backward(dOut: Tensor2D, lr: number): Effect.Effect<Tensor2D, ShapeError> {
+     return Effect.gen(this, function* () {
+       const fiberId = yield* Effect.fiberId
+       const key = this.fiberKey(fiberId)
+       const cached = this.cache.get(key) ?? this.lastCache
+       if (!cached) {
+         return yield* Effect.fail(new Ops.ShapeError("LayerNorm.backward called before forward"))
+       }
+       this.cache.delete(key)
+       this.lastCache = null
+
+       const { input, mean, variance } = cached
+       const rows = input.rows
+       const cols = input.cols
+       const nFeatures = cols
+
+       const normalized = T.zeros(rows, cols)
+       const gradNormalized = T.zeros(rows, cols)
+       for (let i = 0; i < rows; i++) {
+         const meanVal = mean.data[i]
+         // Consistent with forward: rstd = 1 / sqrt(variance + epsilon)
+         const rstd = 1.0 / Math.sqrt(variance.data[i] + this.epsilon)
+         for (let j = 0; j < cols; j++) {
+           const idx = i * cols + j
+           const norm = (input.data[idx] - meanVal) * rstd
+           normalized.data[idx] = norm
+           gradNormalized.data[idx] = this.gamma.data[j] * dOut.data[idx]
+         }
+       }
+
+       const gradGamma = T.zeros(1, cols)
+       const gradBeta = T.zeros(1, cols)
+       for (let j = 0; j < cols; j++) {
+         let sumGamma = 0
+         let sumBeta = 0
+         for (let i = 0; i < rows; i++) {
+           const idx = i * cols + j
+           sumGamma += normalized.data[idx] * dOut.data[idx]
+           sumBeta += dOut.data[idx]
+         }
+         gradGamma.data[j] = sumGamma
+         gradBeta.data[j] = sumBeta
+       }
+
+       const gradInput = T.zeros(rows, cols)
+       for (let i = 0; i < rows; i++) {
+         const meanVal = mean.data[i]
+         const varPlusEps = variance.data[i] + this.epsilon
+         const rstd = 1.0 / Math.sqrt(varPlusEps)
+
+         let sumGradNormalized = 0
+         let sumGradNormTimesNorm = 0
+         const rowOffset = i * cols
+
+         for (let j = 0; j < cols; j++) {
+           const idx = rowOffset + j
+           sumGradNormalized += gradNormalized.data[idx]
+           sumGradNormTimesNorm += gradNormalized.data[idx] * normalized.data[idx]
+         }
+
+         // Gradient of LayerNorm: dL/dx = rstd * (dL/dnorm - mean(dL/dnorm) - norm * mean(dL/dnorm * norm))
+         for (let j = 0; j < cols; j++) {
+           const idx = rowOffset + j
+           gradInput.data[idx] =
+             rstd *
+             (gradNormalized.data[idx] -
+               sumGradNormalized / nFeatures -
+               (normalized.data[idx] * sumGradNormTimesNorm) / nFeatures)
+         }
+       }
+
+       this.optimizerGamma.step(this.gamma, gradGamma, lr)
+       this.optimizerBeta.step(this.beta, gradBeta, lr)
+
+       return gradInput
+     })
+   }
+ }
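The final loop is the standard LayerNorm input gradient spelled out in the comment above. In the code's own names, with H = nFeatures, r = rstd, x̂ = normalized, and g = gradNormalized:

\[
\frac{\partial L}{\partial x_j}
  = r\left(g_j - \frac{1}{H}\sum_{k=1}^{H} g_k - \hat{x}_j \cdot \frac{1}{H}\sum_{k=1}^{H} g_k \hat{x}_k\right),
\qquad r = \frac{1}{\sqrt{\sigma^2 + \epsilon}},
\qquad g_j = \gamma_j \frac{\partial L}{\partial y_j}
\]

The two correction terms are exactly sumGradNormalized / nFeatures and (normalized * sumGradNormTimesNorm) / nFeatures in the code; they account for the mean and variance both depending on every element of the row.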
package/src/model/ModelLayer.ts ADDED
@@ -0,0 +1,10 @@
+ import type * as Effect from "effect/Effect"
+ import type { Tensor2D } from "../tensor/Tensor2D"
+ import type { ShapeError } from "../tensor/ops"
+
+ export interface ModelLayer {
+   readonly _tag: string
+   readonly parametersCount: number
+   forward(input: Tensor2D): Effect.Effect<Tensor2D, ShapeError>
+   backward(dOut: Tensor2D, lr: number): Effect.Effect<Tensor2D, ShapeError>
+ }
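Every layer in this package (Embeddings, FeedForward, LayerNorm, OutputProjection, and the attention blocks) implements this interface. A minimal conforming sketch, a hypothetical identity layer that is not part of the package:

import * as Effect from "effect/Effect"
import type { Tensor2D } from "../tensor/Tensor2D"
import type { ShapeError } from "../tensor/ops"
import type { ModelLayer } from "./ModelLayer"

// Passes activations forward and gradients backward unchanged; no parameters.
class Identity implements ModelLayer {
  readonly _tag = "Identity"
  readonly parametersCount = 0
  forward(input: Tensor2D): Effect.Effect<Tensor2D, ShapeError> {
    return Effect.succeed(input)
  }
  backward(dOut: Tensor2D, _lr: number): Effect.Effect<Tensor2D, ShapeError> {
    return Effect.succeed(dOut)
  }
}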
package/src/model/OutputProjection.ts ADDED
@@ -0,0 +1,76 @@
+ import * as Effect from "effect/Effect"
+ import * as FiberId from "effect/FiberId"
+ import type { Tensor2D } from "../tensor/Tensor2D"
+ import * as T from "../tensor/Tensor2D"
+ import * as Ops from "../tensor/ops"
+ import type { ShapeError } from "../tensor/ops"
+ import type { ModelLayer } from "./ModelLayer"
+ import { EMBEDDING_DIM } from "../config"
+ import { Adam } from "../training/Adam"
+ import type { Rng } from "../tensor/random"
+
+ export class OutputProjection implements ModelLayer {
+   readonly _tag = "OutputProjection"
+   wOut: Tensor2D
+   bOut: Tensor2D
+
+   private cache = new Map<number | string, Tensor2D>()
+   private lastCache: Tensor2D | null = null
+   optimizerWOut: Adam
+
+   constructor(embeddingDim: number = EMBEDDING_DIM, vocabSize: number, rng: Rng) {
+     const std = Math.sqrt(2.0 / embeddingDim)
+     this.wOut = Ops.initNormal(embeddingDim, vocabSize, 0, std, rng)
+     this.bOut = T.zeros(1, vocabSize)
+     this.optimizerWOut = Adam.make(embeddingDim, vocabSize)
+   }
+
+   private fiberKey(fiberId: FiberId.FiberId): number | string {
+     return FiberId.isRuntime(fiberId) ? fiberId.id : JSON.stringify(fiberId)
+   }
+
+   get parametersCount(): number {
+     return this.wOut.data.length + this.bOut.data.length
+   }
+
+   forward(input: Tensor2D): Effect.Effect<Tensor2D, ShapeError> {
+     return Effect.gen(this, function* () {
+       const fiberId = yield* Effect.fiberId
+       const key = this.fiberKey(fiberId)
+       const cloned = T.clone(input)
+       this.cache.set(key, cloned)
+       this.lastCache = cloned
+       const projected = yield* Ops.matMul(input, this.wOut)
+       const output = yield* Ops.addRowBias(projected, this.bOut)
+       return output
+     })
+   }
+
+   backward(dOut: Tensor2D, lr: number): Effect.Effect<Tensor2D, ShapeError> {
+     return Effect.gen(this, function* () {
+       const fiberId = yield* Effect.fiberId
+       const key = this.fiberKey(fiberId)
+       const cachedInput = this.cache.get(key) ?? this.lastCache
+       if (!cachedInput) {
+         return yield* Effect.fail(new Ops.ShapeError("OutputProjection.backward called before forward"))
+       }
+       this.cache.delete(key)
+       this.lastCache = null
+
+       const input = cachedInput
+       const inputT = Ops.transpose(input)
+       const gradWOut = yield* Ops.matMul(inputT, dOut)
+       const gradBOut = Ops.sumCols(dOut)
+
+       const wOutT = Ops.transpose(this.wOut)
+       const gradInput = yield* Ops.matMul(dOut, wOutT)
+
+       this.optimizerWOut.step(this.wOut, gradWOut, lr)
+       for (let j = 0; j < this.bOut.data.length; j++) {
+         this.bOut.data[j] -= lr * gradBOut.data[j]
+       }
+
+       return gradInput
+     })
+   }
+ }
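The backward pass here is the textbook linear-layer gradient. With X the cached input and Y = XW + b:

\[
\frac{\partial L}{\partial W} = X^{\top}\,\frac{\partial L}{\partial Y},
\qquad
\frac{\partial L}{\partial b_j} = \sum_i \left(\frac{\partial L}{\partial Y}\right)_{ij},
\qquad
\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y}\,W^{\top}
\]

which is exactly gradWOut = matMul(inputT, dOut), gradBOut = sumCols(dOut), and gradInput = matMul(dOut, wOutT). One asymmetry worth noting: wOut is updated through its Adam optimizer, while bOut takes a plain SGD step in the loop above; unlike FeedForward and LayerNorm, this layer allocates no Adam state for its bias.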