active-inference 0.0.1 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -9,7 +9,7 @@ Active Inference is a theory of how biological agents perceive and act in the wo
9
9
  - **Risk**: avoiding unpreferred outcomes
10
10
  - **Ambiguity**: seeking informative observations
11
11
 
12
- This library provides building blocks for creating agents that learn from observations and plan actions using these principles.
12
+ This library provides building blocks for creating agents that perceive, learn, plan, and act using these principles.
13
13
 
14
14
  ## Installation
15
15
 
@@ -61,6 +61,7 @@ const action = agent.step('see_reward');
61
61
  | `preferences` | Log probabilities of preferred observations |
62
62
  | `planningHorizon` | Steps to look ahead (default: 1) |
63
63
  | `precision` | Action selection temperature (default: 1) |
64
+ | `habits` | Prior over actions / E matrix (default: uniform) |
64
65
  | `seed` | Random seed for reproducibility |
65
66
 
66
67
  ### Agent
@@ -75,6 +76,58 @@ const action = agent.step('see_reward');
75
76
  | `freeEnergy` | Variational Free Energy |
76
77
  | `exportBelief()` | Get full belief distribution |
77
78
 
79
+ ## Learning
80
+
81
+ The library supports **Dirichlet-categorical learning** — agents that update their generative models from experience. Instead of fixed probability matrices, learnable models maintain pseudo-count concentrations that are refined over time.
82
+
83
+ - `DirichletObservation` and `DirichletTransition` are drop-in replacements for their Discrete counterparts. Learning happens automatically on every `step()` call.
84
+ - `DirichletPreferences` provides learnable preferred observations — call `.learn()` manually and pass `.preferences` to the agent config.
85
+
86
+ Low concentrations encode weak priors (fast learning). High concentrations encode strong priors (resistant to change).
87
+
88
+ ```typescript
89
+ import {
90
+ createAgent,
91
+ DiscreteBelief,
92
+ DirichletTransition,
93
+ DirichletObservation,
94
+ } from 'active-inference';
95
+
96
+ const agent = createAgent({
97
+ belief: new DiscreteBelief({ safe: 0.5, danger: 0.5 }),
98
+ transitionModel: new DirichletTransition({
99
+ flee: {
100
+ safe: { safe: 1, danger: 1 },
101
+ danger: { safe: 1, danger: 1 },
102
+ },
103
+ stay: {
104
+ safe: { safe: 1, danger: 1 },
105
+ danger: { safe: 1, danger: 1 },
106
+ },
107
+ }),
108
+ observationModel: new DirichletObservation({
109
+ see_safe: { safe: 1, danger: 1 },
110
+ see_danger: { safe: 1, danger: 1 },
111
+ }),
112
+ preferences: { see_safe: 0, see_danger: -5 },
113
+ seed: 42,
114
+ });
115
+
116
+ // Models update automatically on each step
117
+ const action = agent.step('see_safe');
118
+ ```
119
+
120
+ ## Examples
121
+
122
+ ### Cart-Pole Balancing
123
+
124
+ Interactive browser demo — an Active Inference agent balances an inverted pendulum using a 49-state generative model with a 3-step planning horizon.
125
+
126
+ ```bash
127
+ npm run build:examples
128
+ open examples/cart-pole/index.html
129
+ ```
130
+
78
131
  ## Contributing
79
132
 
80
133
  ```bash
package/dist/index.d.ts CHANGED
@@ -2,4 +2,8 @@ export { Agent } from './models/agent.model';
2
2
  export { DiscreteBelief } from './beliefs/discrete.belief';
3
3
  export { DiscreteTransition } from './transition/discrete.transition';
4
4
  export { DiscreteObservation } from './observation/discrete.observation';
5
+ export { DirichletObservation } from './observation/dirichlet.observation';
6
+ export { DirichletTransition } from './transition/dirichlet.transition';
7
+ export { DirichletPreferences } from './preferences/dirichlet.preferences';
5
8
  export { createAgent, AgentConfig } from './factory';
9
+ export type { ILearnable } from './models/learnable.model';
package/dist/index.js CHANGED
@@ -2,4 +2,7 @@ export { Agent } from './models/agent.model';
2
2
  export { DiscreteBelief } from './beliefs/discrete.belief';
3
3
  export { DiscreteTransition } from './transition/discrete.transition';
4
4
  export { DiscreteObservation } from './observation/discrete.observation';
5
+ export { DirichletObservation } from './observation/dirichlet.observation';
6
+ export { DirichletTransition } from './transition/dirichlet.transition';
7
+ export { DirichletPreferences } from './preferences/dirichlet.preferences';
5
8
  export { createAgent } from './factory';
@@ -82,6 +82,8 @@ export declare class Agent<A extends string = string, O extends string = string,
82
82
  private _planningHorizon;
83
83
  private _precision;
84
84
  private _habits;
85
+ private _previousBelief;
86
+ private _previousAction;
85
87
  /**
86
88
  * Create a new Active Inference agent.
87
89
  *
@@ -95,6 +97,7 @@ export declare class Agent<A extends string = string, O extends string = string,
95
97
  * @param habits - Prior over actions (default: uniform)
96
98
  */
97
99
  constructor(belief: Belief<S>, transitionModel: ITransitionModel<A, S>, observationModel: IObservationModel<O, S>, preferences: Preferences<O>, random?: Random, planningHorizon?: number, precision?: number, habits?: Partial<Habits<A>>);
100
+ private get resolvedPreferences();
98
101
  /**
99
102
  * Most likely hidden state (Maximum A Posteriori estimate).
100
103
  */
@@ -205,6 +208,14 @@ export declare class Agent<A extends string = string, O extends string = string,
205
208
  * @returns Variational Free Energy (can be negative)
206
209
  */
207
210
  get freeEnergy(): number;
211
+ /**
212
+ * Update learnable models from the current observation and belief.
213
+ *
214
+ * Called after observe() so the posterior belief is available.
215
+ * - A-learning: update observation model with (observation, posterior)
216
+ * - B-learning: update transition model with (previous_action, previous_belief, posterior)
217
+ */
218
+ private updateModels;
208
219
  /**
209
220
  * Generate all possible policies (action sequences) of given depth.
210
221
  * For depth=2 with actions [a,b]: [[a,a], [a,b], [b,a], [b,b]]
@@ -1,4 +1,5 @@
1
1
  import { LinearAlgebra, Random } from '../helpers/math.helpers';
2
+ import { isLearnable } from './learnable.model';
2
3
  /**
3
4
  * Active Inference agent implementing the Free Energy Principle.
4
5
  *
@@ -68,12 +69,17 @@ export class Agent {
68
69
  this.transitionModel = transitionModel;
69
70
  this.observationModel = observationModel;
70
71
  this.preferences = preferences;
72
+ this._previousBelief = null;
73
+ this._previousAction = null;
71
74
  this._belief = belief.copy();
72
75
  this._random = random ?? new Random();
73
76
  this._planningHorizon = Math.max(1, Math.floor(planningHorizon));
74
77
  this._precision = Math.max(0, precision);
75
78
  this._habits = habits;
76
79
  }
80
+ get resolvedPreferences() {
81
+ return this.preferences;
82
+ }
77
83
  /**
78
84
  * Most likely hidden state (Maximum A Posteriori estimate).
79
85
  */
@@ -168,7 +174,11 @@ export class Agent {
168
174
  */
169
175
  step(observation) {
170
176
  this.observe(observation);
171
- return this.act();
177
+ this.updateModels(observation);
178
+ const action = this.act();
179
+ this._previousBelief = this._belief;
180
+ this._previousAction = action;
181
+ return action;
172
182
  }
173
183
  /**
174
184
  * Export current belief as a plain object for serialization.
@@ -217,6 +227,30 @@ export class Agent {
217
227
  get freeEnergy() {
218
228
  return -this._belief.entropy() + this.computeAmbiguity(this._belief);
219
229
  }
230
+ /**
231
+ * Update learnable models from the current observation and belief.
232
+ *
233
+ * Called after observe() so the posterior belief is available.
234
+ * - A-learning: update observation model with (observation, posterior)
235
+ * - B-learning: update transition model with (previous_action, previous_belief, posterior)
236
+ */
237
+ updateModels(observation) {
238
+ const posteriorDist = this.exportBelief();
239
+ // A-matrix: P(o|s)
240
+ if (isLearnable(this.observationModel)) {
241
+ this.observationModel.learn(observation, posteriorDist);
242
+ }
243
+ // B-matrix: P(s'|s,a) — only after at least one action
244
+ if (isLearnable(this.transitionModel) &&
245
+ this._previousAction !== null &&
246
+ this._previousBelief !== null) {
247
+ const prevDist = {};
248
+ for (const state of this._previousBelief.states) {
249
+ prevDist[state] = this._previousBelief.probability(state);
250
+ }
251
+ this.transitionModel.learn(this._previousAction, prevDist, posteriorDist);
252
+ }
253
+ }
220
254
  /**
221
255
  * Generate all possible policies (action sequences) of given depth.
222
256
  * For depth=2 with actions [a,b]: [[a,a], [a,b], [b,a], [b,b]]
@@ -288,6 +322,7 @@ export class Agent {
288
322
  */
289
323
  computeRisk(predictedBelief) {
290
324
  let risk = 0;
325
+ const prefs = this.resolvedPreferences;
291
326
  for (const obs of this.observationModel.observations) {
292
327
  let expectedObsProb = 0;
293
328
  for (const state of predictedBelief.states) {
@@ -295,7 +330,7 @@ export class Agent {
295
330
  this.observationModel.probability(obs, state) *
296
331
  predictedBelief.probability(state);
297
332
  }
298
- const preferredLogProb = this.preferences[obs] ?? -10;
333
+ const preferredLogProb = prefs[obs] ?? -10;
299
334
  if (expectedObsProb > 0) {
300
335
  risk -= expectedObsProb * preferredLogProb;
301
336
  }
@@ -0,0 +1,15 @@
1
+ /**
2
+ * Interface for models that support Dirichlet concentration-based learning.
3
+ *
4
+ * Models implementing this interface maintain Dirichlet concentration
5
+ * parameters (pseudo-counts) that are updated from experience.
6
+ * The underlying probability matrices are derived by normalizing
7
+ * concentrations.
8
+ */
9
+ export interface ILearnable {
10
+ readonly learnable: true;
11
+ }
12
+ /**
13
+ * Type guard to check if a model supports learning.
14
+ */
15
+ export declare function isLearnable(obj: unknown): obj is ILearnable;
@@ -0,0 +1,9 @@
1
+ /**
2
+ * Type guard to check if a model supports learning.
3
+ */
4
+ export function isLearnable(obj) {
5
+ return (typeof obj === 'object' &&
6
+ obj !== null &&
7
+ 'learnable' in obj &&
8
+ obj.learnable === true);
9
+ }
@@ -0,0 +1,83 @@
1
+ import { Distribution } from '../models/belief.model';
2
+ import { IObservationModel, ObservationMatrix } from '../models/observation.model';
3
+ import { ILearnable } from '../models/learnable.model';
4
+ /**
5
+ * Dirichlet concentration parameters for observation model.
6
+ * Same structure as ObservationMatrix: observation → state → concentration.
7
+ *
8
+ * Each value is a positive pseudo-count. Higher values encode stronger
9
+ * prior beliefs about the observation-state mapping.
10
+ */
11
+ export type ObservationConcentrations<O extends string = string, S extends string = string> = Record<O, Record<S, number>>;
12
+ /**
13
+ * Learnable observation model using Dirichlet concentrations.
14
+ *
15
+ * Instead of a fixed A matrix, this model maintains Dirichlet pseudo-counts
16
+ * from which the probability matrix P(o|s) is derived by normalization:
17
+ *
18
+ * P(o|s) = a[o][s] / Σ_o' a[o'][s]
19
+ *
20
+ * After each observation, concentrations are updated using the posterior
21
+ * belief about states (Dirichlet-categorical conjugate update):
22
+ *
23
+ * a[o*][s] += Q(s) for the observed o*
24
+ *
25
+ * @typeParam O - Union type of possible observation names
26
+ * @typeParam S - Union type of possible state names
27
+ *
28
+ * @example
29
+ * ```typescript
30
+ * // Weak prior: agent is uncertain about observation-state mapping
31
+ * const obs = new DirichletObservation({
32
+ * see_safe: { safe: 2, danger: 1 },
33
+ * see_danger: { safe: 1, danger: 2 },
34
+ * });
35
+ *
36
+ * // Strong prior: agent has confident beliefs (equivalent to scale * probabilities)
37
+ * const obs = new DirichletObservation({
38
+ * see_safe: { safe: 90, danger: 10 },
39
+ * see_danger: { safe: 10, danger: 90 },
40
+ * });
41
+ * ```
42
+ */
43
+ export declare class DirichletObservation<O extends string = string, S extends string = string> implements IObservationModel<O, S>, ILearnable {
44
+ concentrations: ObservationConcentrations<O, S>;
45
+ readonly learnable: true;
46
+ private _matrix;
47
+ /**
48
+ * @param concentrations - Dirichlet pseudo-counts a[o][s].
49
+ * Each value must be > 0. Higher values encode stronger prior beliefs.
50
+ */
51
+ constructor(concentrations: ObservationConcentrations<O, S>);
52
+ get observations(): O[];
53
+ get states(): S[];
54
+ /**
55
+ * Normalized probability matrix derived from concentrations.
56
+ * Lazily computed and cached; invalidated on learn().
57
+ *
58
+ * Column-wise normalization (per state s):
59
+ * P(o|s) = a[o][s] / Σ_o' a[o'][s]
60
+ */
61
+ get matrix(): ObservationMatrix<O, S>;
62
+ getLikelihood(observation: O): Distribution<S>;
63
+ probability(observation: O, state: S): number;
64
+ /**
65
+ * Update concentrations from an observation and posterior belief.
66
+ *
67
+ * Dirichlet-categorical conjugate update:
68
+ * a[o*][s] += posteriorBelief[s] for the observed o*
69
+ *
70
+ * The belief-weighting handles state uncertainty: if the agent
71
+ * is 80% sure it's in state A, the count for (o*, A) increases
72
+ * by 0.8 and (o*, B) by 0.2.
73
+ *
74
+ * @param observation - The observation that was received
75
+ * @param posteriorBelief - Posterior belief distribution over states
76
+ */
77
+ learn(observation: O, posteriorBelief: Distribution<S>): void;
78
+ /**
79
+ * Column-wise normalization of concentrations.
80
+ * For each state s: P(o|s) = a[o][s] / Σ_o' a[o'][s]
81
+ */
82
+ private normalize;
83
+ }
@@ -0,0 +1,119 @@
1
+ /**
2
+ * Learnable observation model using Dirichlet concentrations.
3
+ *
4
+ * Instead of a fixed A matrix, this model maintains Dirichlet pseudo-counts
5
+ * from which the probability matrix P(o|s) is derived by normalization:
6
+ *
7
+ * P(o|s) = a[o][s] / Σ_o' a[o'][s]
8
+ *
9
+ * After each observation, concentrations are updated using the posterior
10
+ * belief about states (Dirichlet-categorical conjugate update):
11
+ *
12
+ * a[o*][s] += Q(s) for the observed o*
13
+ *
14
+ * @typeParam O - Union type of possible observation names
15
+ * @typeParam S - Union type of possible state names
16
+ *
17
+ * @example
18
+ * ```typescript
19
+ * // Weak prior: agent is uncertain about observation-state mapping
20
+ * const obs = new DirichletObservation({
21
+ * see_safe: { safe: 2, danger: 1 },
22
+ * see_danger: { safe: 1, danger: 2 },
23
+ * });
24
+ *
25
+ * // Strong prior: agent has confident beliefs (equivalent to scale * probabilities)
26
+ * const obs = new DirichletObservation({
27
+ * see_safe: { safe: 90, danger: 10 },
28
+ * see_danger: { safe: 10, danger: 90 },
29
+ * });
30
+ * ```
31
+ */
32
+ export class DirichletObservation {
33
+ /**
34
+ * @param concentrations - Dirichlet pseudo-counts a[o][s].
35
+ * Each value must be > 0. Higher values encode stronger prior beliefs.
36
+ */
37
+ constructor(concentrations) {
38
+ this.concentrations = concentrations;
39
+ this.learnable = true;
40
+ this._matrix = null;
41
+ // Deep copy to avoid aliasing
42
+ this.concentrations = {};
43
+ for (const obs of Object.keys(concentrations)) {
44
+ this.concentrations[obs] = { ...concentrations[obs] };
45
+ }
46
+ }
47
+ get observations() {
48
+ return Object.keys(this.concentrations);
49
+ }
50
+ get states() {
51
+ const firstObs = this.observations[0];
52
+ return Object.keys(this.concentrations[firstObs] || {});
53
+ }
54
+ /**
55
+ * Normalized probability matrix derived from concentrations.
56
+ * Lazily computed and cached; invalidated on learn().
57
+ *
58
+ * Column-wise normalization (per state s):
59
+ * P(o|s) = a[o][s] / Σ_o' a[o'][s]
60
+ */
61
+ get matrix() {
62
+ if (this._matrix === null) {
63
+ this._matrix = this.normalize();
64
+ }
65
+ return this._matrix;
66
+ }
67
+ getLikelihood(observation) {
68
+ return this.matrix[observation] ?? {};
69
+ }
70
+ probability(observation, state) {
71
+ return this.matrix[observation]?.[state] ?? 0;
72
+ }
73
+ /**
74
+ * Update concentrations from an observation and posterior belief.
75
+ *
76
+ * Dirichlet-categorical conjugate update:
77
+ * a[o*][s] += posteriorBelief[s] for the observed o*
78
+ *
79
+ * The belief-weighting handles state uncertainty: if the agent
80
+ * is 80% sure it's in state A, the count for (o*, A) increases
81
+ * by 0.8 and (o*, B) by 0.2.
82
+ *
83
+ * @param observation - The observation that was received
84
+ * @param posteriorBelief - Posterior belief distribution over states
85
+ */
86
+ learn(observation, posteriorBelief) {
87
+ for (const state of this.states) {
88
+ this.concentrations[observation][state] +=
89
+ posteriorBelief[state] ?? 0;
90
+ }
91
+ this._matrix = null;
92
+ }
93
+ /**
94
+ * Column-wise normalization of concentrations.
95
+ * For each state s: P(o|s) = a[o][s] / Σ_o' a[o'][s]
96
+ */
97
+ normalize() {
98
+ const matrix = {};
99
+ const observations = this.observations;
100
+ const states = this.states;
101
+ const colSums = {};
102
+ for (const s of states) {
103
+ colSums[s] = 0;
104
+ for (const o of observations) {
105
+ colSums[s] += this.concentrations[o][s];
106
+ }
107
+ }
108
+ for (const o of observations) {
109
+ matrix[o] = {};
110
+ for (const s of states) {
111
+ matrix[o][s] =
112
+ colSums[s] > 0
113
+ ? this.concentrations[o][s] / colSums[s]
114
+ : 0;
115
+ }
116
+ }
117
+ return matrix;
118
+ }
119
+ }
@@ -0,0 +1,62 @@
1
+ import { ILearnable } from '../models/learnable.model';
2
+ /**
3
+ * Dirichlet concentration parameters for preferences.
4
+ * Maps each observation to a positive pseudo-count.
5
+ */
6
+ export type PreferenceConcentrations<O extends string = string> = Record<O, number>;
7
+ /**
8
+ * Learnable preferences using Dirichlet concentrations.
9
+ *
10
+ * Instead of fixed log-probability preferences, this class maintains
11
+ * Dirichlet pseudo-counts from which log-preferences are derived:
12
+ *
13
+ * P_preferred(o) = c[o] / Σ_o' c[o']
14
+ * preferences[o] = log(P_preferred(o))
15
+ *
16
+ * Higher concentration → higher preferred probability → less negative log → more preferred.
17
+ *
18
+ * After each observation, concentrations can be reinforced:
19
+ * c[o*] += 1
20
+ *
21
+ * This allows preferences to drift based on experience.
22
+ *
23
+ * @typeParam O - Union type of possible observation names
24
+ *
25
+ * @example
26
+ * ```typescript
27
+ * // Strongly prefer reward over no_reward
28
+ * const prefs = new DirichletPreferences({
29
+ * reward: 10, // log(10/11) ≈ -0.095
30
+ * no_reward: 1, // log(1/11) ≈ -2.398
31
+ * });
32
+ * ```
33
+ */
34
+ export declare class DirichletPreferences<O extends string = string> implements ILearnable {
35
+ concentrations: PreferenceConcentrations<O>;
36
+ readonly learnable: true;
37
+ private _preferences;
38
+ /**
39
+ * @param concentrations - Pseudo-counts c[o] for each observation.
40
+ * Higher values indicate more preferred observations.
41
+ * All values must be > 0.
42
+ */
43
+ constructor(concentrations: PreferenceConcentrations<O>);
44
+ get observations(): O[];
45
+ /**
46
+ * Log-preference values derived from concentrations.
47
+ *
48
+ * P(o) = c[o] / Σ_o' c[o']
49
+ * preferences[o] = log(P(o))
50
+ *
51
+ * Returns a plain Record<O, number> compatible with Preferences<O>.
52
+ */
53
+ get preferences(): Record<O, number>;
54
+ /**
55
+ * Reinforce an observation's preference.
56
+ *
57
+ * @param observation - The observation to reinforce
58
+ * @param amount - Amount to add to the concentration (default: 1)
59
+ */
60
+ learn(observation: O, amount?: number): void;
61
+ private normalize;
62
+ }
@@ -0,0 +1,79 @@
1
+ /**
2
+ * Learnable preferences using Dirichlet concentrations.
3
+ *
4
+ * Instead of fixed log-probability preferences, this class maintains
5
+ * Dirichlet pseudo-counts from which log-preferences are derived:
6
+ *
7
+ * P_preferred(o) = c[o] / Σ_o' c[o']
8
+ * preferences[o] = log(P_preferred(o))
9
+ *
10
+ * Higher concentration → higher preferred probability → less negative log → more preferred.
11
+ *
12
+ * After each observation, concentrations can be reinforced:
13
+ * c[o*] += 1
14
+ *
15
+ * This allows preferences to drift based on experience.
16
+ *
17
+ * @typeParam O - Union type of possible observation names
18
+ *
19
+ * @example
20
+ * ```typescript
21
+ * // Strongly prefer reward over no_reward
22
+ * const prefs = new DirichletPreferences({
23
+ * reward: 10, // log(10/11) ≈ -0.095
24
+ * no_reward: 1, // log(1/11) ≈ -2.398
25
+ * });
26
+ * ```
27
+ */
28
+ export class DirichletPreferences {
29
+ /**
30
+ * @param concentrations - Pseudo-counts c[o] for each observation.
31
+ * Higher values indicate more preferred observations.
32
+ * All values must be > 0.
33
+ */
34
+ constructor(concentrations) {
35
+ this.concentrations = concentrations;
36
+ this.learnable = true;
37
+ this._preferences = null;
38
+ this.concentrations = { ...concentrations };
39
+ }
40
+ get observations() {
41
+ return Object.keys(this.concentrations);
42
+ }
43
+ /**
44
+ * Log-preference values derived from concentrations.
45
+ *
46
+ * P(o) = c[o] / Σ_o' c[o']
47
+ * preferences[o] = log(P(o))
48
+ *
49
+ * Returns a plain Record<O, number> compatible with Preferences<O>.
50
+ */
51
+ get preferences() {
52
+ if (this._preferences === null) {
53
+ this._preferences = this.normalize();
54
+ }
55
+ return this._preferences;
56
+ }
57
+ /**
58
+ * Reinforce an observation's preference.
59
+ *
60
+ * @param observation - The observation to reinforce
61
+ * @param amount - Amount to add to the concentration (default: 1)
62
+ */
63
+ learn(observation, amount = 1) {
64
+ this.concentrations[observation] += amount;
65
+ this._preferences = null;
66
+ }
67
+ normalize() {
68
+ const result = {};
69
+ let sum = 0;
70
+ for (const o of this.observations) {
71
+ sum += this.concentrations[o];
72
+ }
73
+ for (const o of this.observations) {
74
+ const p = this.concentrations[o] / sum;
75
+ result[o] = Math.log(Math.max(p, 1e-16));
76
+ }
77
+ return result;
78
+ }
79
+ }
@@ -0,0 +1,82 @@
1
+ import { Distribution } from '../models/belief.model';
2
+ import type { Belief } from '../models/belief.model';
3
+ import { ITransitionModel, TransitionMatrix } from '../models/transition.model';
4
+ import { ILearnable } from '../models/learnable.model';
5
+ /**
6
+ * Dirichlet concentration parameters for transition model.
7
+ * Same structure as TransitionMatrix: action → current_state → next_state → concentration.
8
+ */
9
+ export type TransitionConcentrations<A extends string = string, S extends string = string> = Record<A, Record<S, Record<S, number>>>;
10
+ /**
11
+ * Learnable transition model using Dirichlet concentrations.
12
+ *
13
+ * Instead of a fixed B matrix, this model maintains Dirichlet pseudo-counts
14
+ * from which the probability matrix P(s'|s,a) is derived by normalization:
15
+ *
16
+ * P(s'|s,a) = b[a][s][s'] / Σ_s'' b[a][s][s'']
17
+ *
18
+ * After each state transition, concentrations are updated using the
19
+ * outer product of prior and posterior beliefs:
20
+ *
21
+ * b[a][s][s'] += Q_prior(s) × Q_posterior(s')
22
+ *
23
+ * @typeParam A - Union type of possible action names
24
+ * @typeParam S - Union type of possible state names
25
+ *
26
+ * @example
27
+ * ```typescript
28
+ * const transition = new DirichletTransition({
29
+ * move: {
30
+ * here: { here: 1, there: 5 }, // move from here → likely end up there
31
+ * there: { here: 1, there: 5 },
32
+ * },
33
+ * stay: {
34
+ * here: { here: 5, there: 1 },
35
+ * there: { here: 1, there: 5 },
36
+ * },
37
+ * });
38
+ * ```
39
+ */
40
+ export declare class DirichletTransition<A extends string = string, S extends string = string> implements ITransitionModel<A, S>, ILearnable {
41
+ concentrations: TransitionConcentrations<A, S>;
42
+ readonly learnable: true;
43
+ private _matrix;
44
+ /**
45
+ * @param concentrations - Dirichlet pseudo-counts b[a][s][s'].
46
+ * Structure: action → current_state → next_state → count.
47
+ * Each value must be > 0.
48
+ */
49
+ constructor(concentrations: TransitionConcentrations<A, S>);
50
+ get actions(): A[];
51
+ get states(): S[];
52
+ /**
53
+ * Normalized probability matrix derived from concentrations.
54
+ * Lazily computed and cached; invalidated on learn().
55
+ *
56
+ * Row-wise normalization (per action a, per current state s):
57
+ * P(s'|s,a) = b[a][s][s'] / Σ_s'' b[a][s][s'']
58
+ */
59
+ get matrix(): TransitionMatrix<A, S>;
60
+ getTransition(state: S, action: A): Distribution<S>;
61
+ predict(belief: Belief<S>, action: A): Belief<S>;
62
+ /**
63
+ * Update concentrations from a state transition.
64
+ *
65
+ * Uses outer product of prior and posterior beliefs:
66
+ * b[a][s][s'] += Q_prior(s) × Q_posterior(s')
67
+ *
68
+ * This encodes the agent's best estimate of the transition
69
+ * that occurred, weighted by uncertainty in both states.
70
+ *
71
+ * @param action - The action that was taken
72
+ * @param priorBelief - Belief distribution before the action
73
+ * @param posteriorBelief - Belief distribution after observing the outcome
74
+ */
75
+ learn(action: A, priorBelief: Distribution<S>, posteriorBelief: Distribution<S>): void;
76
+ /**
77
+ * Row-wise normalization of concentrations.
78
+ * For each (action, current_state):
79
+ * P(s'|s,a) = b[a][s][s'] / Σ_s'' b[a][s][s'']
80
+ */
81
+ private normalize;
82
+ }
@@ -0,0 +1,135 @@
1
+ import { DiscreteBelief } from '../beliefs/discrete.belief';
2
+ /**
3
+ * Learnable transition model using Dirichlet concentrations.
4
+ *
5
+ * Instead of a fixed B matrix, this model maintains Dirichlet pseudo-counts
6
+ * from which the probability matrix P(s'|s,a) is derived by normalization:
7
+ *
8
+ * P(s'|s,a) = b[a][s][s'] / Σ_s'' b[a][s][s'']
9
+ *
10
+ * After each state transition, concentrations are updated using the
11
+ * outer product of prior and posterior beliefs:
12
+ *
13
+ * b[a][s][s'] += Q_prior(s) × Q_posterior(s')
14
+ *
15
+ * @typeParam A - Union type of possible action names
16
+ * @typeParam S - Union type of possible state names
17
+ *
18
+ * @example
19
+ * ```typescript
20
+ * const transition = new DirichletTransition({
21
+ * move: {
22
+ * here: { here: 1, there: 5 }, // move from here → likely end up there
23
+ * there: { here: 1, there: 5 },
24
+ * },
25
+ * stay: {
26
+ * here: { here: 5, there: 1 },
27
+ * there: { here: 1, there: 5 },
28
+ * },
29
+ * });
30
+ * ```
31
+ */
32
+ export class DirichletTransition {
33
+ /**
34
+ * @param concentrations - Dirichlet pseudo-counts b[a][s][s'].
35
+ * Structure: action → current_state → next_state → count.
36
+ * Each value must be > 0.
37
+ */
38
+ constructor(concentrations) {
39
+ this.concentrations = concentrations;
40
+ this.learnable = true;
41
+ this._matrix = null;
42
+ // Deep copy to avoid aliasing
43
+ this.concentrations = {};
44
+ for (const a of Object.keys(concentrations)) {
45
+ this.concentrations[a] = {};
46
+ for (const s of Object.keys(concentrations[a])) {
47
+ this.concentrations[a][s] = { ...concentrations[a][s] };
48
+ }
49
+ }
50
+ }
51
+ get actions() {
52
+ return Object.keys(this.concentrations);
53
+ }
54
+ get states() {
55
+ const firstAction = this.actions[0];
56
+ return Object.keys(this.concentrations[firstAction] || {});
57
+ }
58
+ /**
59
+ * Normalized probability matrix derived from concentrations.
60
+ * Lazily computed and cached; invalidated on learn().
61
+ *
62
+ * Row-wise normalization (per action a, per current state s):
63
+ * P(s'|s,a) = b[a][s][s'] / Σ_s'' b[a][s][s'']
64
+ */
65
+ get matrix() {
66
+ if (this._matrix === null) {
67
+ this._matrix = this.normalize();
68
+ }
69
+ return this._matrix;
70
+ }
71
+ getTransition(state, action) {
72
+ return this.matrix[action]?.[state] ?? {};
73
+ }
74
+ predict(belief, action) {
75
+ const newDist = {};
76
+ for (const state of this.states) {
77
+ newDist[state] = 0;
78
+ }
79
+ for (const currentState of belief.states) {
80
+ const transition = this.getTransition(currentState, action);
81
+ const currentProb = belief.probability(currentState);
82
+ for (const nextState of Object.keys(transition)) {
83
+ newDist[nextState] += transition[nextState] * currentProb;
84
+ }
85
+ }
86
+ return new DiscreteBelief(newDist);
87
+ }
88
+ /**
89
+ * Update concentrations from a state transition.
90
+ *
91
+ * Uses outer product of prior and posterior beliefs:
92
+ * b[a][s][s'] += Q_prior(s) × Q_posterior(s')
93
+ *
94
+ * This encodes the agent's best estimate of the transition
95
+ * that occurred, weighted by uncertainty in both states.
96
+ *
97
+ * @param action - The action that was taken
98
+ * @param priorBelief - Belief distribution before the action
99
+ * @param posteriorBelief - Belief distribution after observing the outcome
100
+ */
101
+ learn(action, priorBelief, posteriorBelief) {
102
+ for (const s of this.states) {
103
+ for (const sPrime of this.states) {
104
+ this.concentrations[action][s][sPrime] +=
105
+ (priorBelief[s] ?? 0) * (posteriorBelief[sPrime] ?? 0);
106
+ }
107
+ }
108
+ this._matrix = null;
109
+ }
110
+ /**
111
+ * Row-wise normalization of concentrations.
112
+ * For each (action, current_state):
113
+ * P(s'|s,a) = b[a][s][s'] / Σ_s'' b[a][s][s'']
114
+ */
115
+ normalize() {
116
+ const matrix = {};
117
+ for (const a of this.actions) {
118
+ matrix[a] = {};
119
+ for (const s of this.states) {
120
+ matrix[a][s] = {};
121
+ let rowSum = 0;
122
+ for (const sPrime of this.states) {
123
+ rowSum += this.concentrations[a][s][sPrime];
124
+ }
125
+ for (const sPrime of this.states) {
126
+ matrix[a][s][sPrime] =
127
+ rowSum > 0
128
+ ? this.concentrations[a][s][sPrime] / rowSum
129
+ : 0;
130
+ }
131
+ }
132
+ }
133
+ return matrix;
134
+ }
135
+ }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "active-inference",
3
- "version": "0.0.1",
3
+ "version": "0.1.0",
4
4
  "description": "Active Inference Framework implementation for JavaScript/TypeScript",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",
@@ -11,7 +11,8 @@
11
11
  "test": "vitest run",
12
12
  "test:watch": "vitest",
13
13
  "build": "tsc",
14
- "prepublishOnly": "npm run build"
14
+ "prepublish": "npm run build",
15
+ "build:examples": "npx esbuild src/index.ts --bundle --format=esm --outfile=examples/cart-pole/active-inference.bundle.js"
15
16
  },
16
17
  "keywords": [
17
18
  "active-inference",