finite_mdp 0.2.0 → 0.3.0
- checksums.yaml +4 -4
- data/README.rdoc +8 -12
- data/lib/finite_mdp/array_model.rb +226 -0
- data/lib/finite_mdp/hash_model.rb +10 -9
- data/lib/finite_mdp/model.rb +19 -18
- data/lib/finite_mdp/solver.rb +96 -83
- data/lib/finite_mdp/table_model.rb +28 -19
- data/lib/finite_mdp/vector_valued.rb +5 -5
- data/lib/finite_mdp/version.rb +2 -1
- data/lib/finite_mdp.rb +3 -2
- data/test/finite_mdp/finite_mdp_test.rb +151 -98
- metadata +33 -4
data/lib/finite_mdp/solver.rb
CHANGED
@@ -1,3 +1,9 @@
+# frozen_string_literal: true
+
+# We use A to denote a matrix, which rubocop does not like.
+# rubocop:disable Style/MethodName
+# rubocop:disable Style/VariableName
+
 require 'narray'
 
 #
@@ -29,59 +35,54 @@ class FiniteMDP::Solver
   # @param [Hash<state, Float>] value initial value for each state; defaults to
   #   zero for every state
   #
-  def initialize
-    @model = model
+  def initialize(model, discount, policy: nil, value: Hash.new(0))
     @discount = discount
 
     # get the model data into a more compact form for calculation; this means
     # that we number the states and actions for faster lookups (avoid most of
-    # the hashing)
-
-
-
-
-
-
-          pr = model.transition_probability(state, action, next_state)
-          [state_to_num[next_state], pr,
-            model.reward(state, action, next_state)] if pr > 0
-        }.compact
-      }
-    }
+    # the hashing)
+    @model =
+      if model.is_a?(FiniteMDP::ArrayModel)
+        model
+      else
+        FiniteMDP::ArrayModel.from_model(model)
+      end
 
     # convert initial values and policies to compact form
-    @array_value =
-
-
-
-
-
-
-
-      # default to the first action, arbitrarily
-      @array_policy = [0]*model_states.size
-    end
+    @array_value = @model.states.map { |state| value[state] }
+    @array_policy =
+      if policy
+        @model.states.map do |state|
+          @model.actions(state).index(policy[state])
+        end
+      else
+        [0] * @model.num_states
+      end
 
     raise 'some initial values are missing' if
-      @array_value.any?
+      @array_value.any?(&:nil?)
     raise 'some initial policy actions are missing' if
-      @array_policy.any?
+      @array_policy.any?(&:nil?)
 
     @policy_A = nil
   end
 
   #
-  # @return [
-  #   while it is being solved
+  # @return [ArrayModel] the model being solved; read only; do not change the
+  #   model while it is being solved
   #
   attr_reader :model
 
-  #
+  #
+  # @return [Float] discount factor, in (0, 1]
+  #
+  attr_reader :discount
+
+  #
   # Current value estimate for each state.
   #
   # The result is converted from the solver's internal representation, so you
-  # cannot affect the solver by changing the result.
+  # cannot affect the solver by changing the result.
   #
   # @return [Hash<state, Float>] from states to values; read only; any changes
   #   made to the returned object will not affect the solver
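In 0.3.0 the constructor takes the optional policy and initial value as keyword arguments and converts the model to an ArrayModel internally. A minimal usage sketch; the two-state weather model is hypothetical, invented for illustration:

    require 'finite_mdp'

    # Each row is [state, action, next_state, probability, reward].
    rows = [
      [:sunny, :stay, :sunny, 0.9,  1.0],
      [:sunny, :stay, :rainy, 0.1,  1.0],
      [:rainy, :stay, :rainy, 0.8, -1.0],
      [:rainy, :stay, :sunny, 0.2, -1.0]
    ]
    model = FiniteMDP::TableModel.new(rows)

    # policy: and value: are keyword arguments; these are their defaults.
    solver = FiniteMDP::Solver.new(model, 0.95, policy: nil, value: Hash.new(0))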
@@ -98,14 +99,12 @@ class FiniteMDP::Solver
   #
   def state_action_value
     q = {}
-    states
-
-
-
-
-
-          pr * (r + @discount * @array_value[next_state_n])}.inject(:+)
-        q[[state, state_actions[action_n]]] = q_sa
+    model.states.each_with_index do |state, state_n|
+      model.actions(state).each_with_index do |action, action_n|
+        q_sa = model.array[state_n][action_n].map do |next_state_n, pr, r|
+          pr * (r + @discount * @array_value[next_state_n])
+        end.inject(:+)
+        q[[state, action]] = q_sa
       end
     end
     q
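Continuing the sketch above: state_action_value returns a hash keyed by [state, action] pairs, so a Q-value lookup reads:

    q = solver.state_action_value
    q[[:sunny, :stay]] # => Float (current estimate of Q(sunny, stay))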
@@ -118,8 +117,9 @@ class FiniteMDP::Solver
   #   made to the returned object will not affect the solver
   #
   def policy
-    Hash[model.states.zip(@array_policy).map
-      [state, model.actions(state)[action_n]]
+    Hash[model.states.zip(@array_policy).map do |state, action_n|
+      [state, model.actions(state)[action_n]]
+    end]
   end
 
   #
@@ -135,7 +135,7 @@ class FiniteMDP::Solver
   #
   def evaluate_policy
     delta = 0.0
-
+    model.array.each_with_index do |actions, state_n|
       next_state_ns = actions[@array_policy[state_n]]
       new_value = backup(next_state_ns)
       delta = [delta, (@array_value[state_n] - new_value).abs].max
@@ -162,16 +162,16 @@ class FiniteMDP::Solver
   def evaluate_policy_exact
     if @policy_A
       # update only those rows for which the policy has changed
-      @policy_A_action.zip(@array_policy)
-        each_with_index do |(old_action_n, new_action_n), state_n|
+      @policy_A_action.zip(@array_policy)
+        .each_with_index do |(old_action_n, new_action_n), state_n|
         next if old_action_n == new_action_n
         update_policy_Ab state_n, new_action_n
       end
     else
       # initialise the A and the b for Ax = b
-      num_states =
+      num_states = model.num_states
       @policy_A = NMatrix.float(num_states, num_states)
-      @policy_A_action = [-1]*num_states
+      @policy_A_action = [-1] * num_states
       @policy_b = NVector.float(num_states)
 
       @array_policy.each_with_index do |action_n, state_n|
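For reference, exact policy evaluation solves the usual linear system for the value of the current policy \pi:

    (I - \gamma P_\pi)\, v = r_\pi,
    \qquad
    (P_\pi)_{ss'} = p(s' \mid s, \pi(s)),
    \qquad
    (r_\pi)_s = \sum_{s'} p(s' \mid s, \pi(s))\, r(s, \pi(s), s')

This is what update_policy_Ab (later in this diff) builds row by row: -discount * probability for each successor entry, += 1 on the diagonal, and the probability-weighted reward in b.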
@@ -189,26 +189,30 @@ class FiniteMDP::Solver
   #
   # This is the 'policy improvement' step in Figure 4.3 of Sutton and Barto
   # (1998).
-  #
-  # @return [Boolean] false iff the policy changed for any state
   #
-
-
-
+  # @param [Float] tolerance non-negative tolerance; for the policy to change,
+  #   the action must be at least this much better than the current
+  #   action
+  #
+  # @return [Integer] number of states that changed
+  #
+  def improve_policy(tolerance: Float::EPSILON)
+    changed = 0
+    model.array.each_with_index do |actions, state_n|
       a_max = nil
       v_max = -Float::MAX
       actions.each_with_index do |next_state_ns, action_n|
         v = backup(next_state_ns)
-        if v > v_max
+        if v > v_max + tolerance
           a_max = action_n
           v_max = v
         end
       end
       raise "no feasible actions in state #{state_n}" unless a_max
-
+      changed += 1 if @array_policy[state_n] != a_max
       @array_policy[state_n] = a_max
     end
-
+    changed
   end
 
   #
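improve_policy now reports how many states changed instead of a boolean, and the tolerance keeps near-identical actions from registering as improvements on every pass (the convergence problem the new @param documents). A sketch:

    # Returns the number of states whose greedy action changed.
    num_changed = solver.improve_policy(tolerance: 1e-9)
    puts "greedy action changed in #{num_changed} states"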
@@ -223,7 +227,7 @@ class FiniteMDP::Solver
   #
   def value_iteration_single
     delta = 0.0
-
+    model.array.each_with_index do |actions, state_n|
       a_max = nil
       v_max = -Float::MAX
       actions.each_with_index do |next_state_ns, action_n|
@@ -247,7 +251,7 @@ class FiniteMDP::Solver
   #
   # @param [Float] tolerance small positive number
   #
-  # @param [Integer, nil] max_iters terminate after this many iterations, even
+  # @param [Integer, nil] max_iters terminate after this many iterations, even
   #   if the value function has not converged; nil means that there is
   #   no limit on the number of iterations
   #
@@ -260,7 +264,7 @@ class FiniteMDP::Solver
   # @yieldparam [Float] delta largest change in the value function in the last
   #   iteration
   #
-  def value_iteration
+  def value_iteration(tolerance:, max_iters: nil)
     delta = Float::MAX
     num_iters = 0
     loop do
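The tolerance and iteration cap are now keyword arguments. A sketch, assuming the block receives the iteration count and delta in the order the @yieldparam tags suggest:

    solver.value_iteration(tolerance: 1e-6, max_iters: 1000) do |num_iters, delta|
      puts format('value iteration %d: delta %.2e', num_iters, delta)
    end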
@@ -281,6 +285,11 @@ class FiniteMDP::Solver
   #   phase ends if the largest change in the value function
   #   (<tt>delta</tt>) is below this tolerance
   #
+  # @param [Float] policy_tolerance small positive number; when comparing
+  #   actions during policy improvement, ignore value function differences
+  #   smaller than this tolerance; this helps with convergence when there
+  #   are several equivalent or extremely similar actions
+  #
   # @param [Integer, nil] max_value_iters terminate the policy evaluation
   #   phase after this many iterations, even if the value function has not
   #   converged; nil means that there is no limit on the number of
@@ -298,16 +307,20 @@ class FiniteMDP::Solver
   # @yieldparam [Integer] num_policy_iters policy improvement iterations done so
   #   far
   #
+  # @yieldparam [Integer?] actions_changed number of actions that changed in
+  #   the policy improvement phase, if any
+  #
   # @yieldparam [Integer] num_value_iters policy evaluation iterations done so
   #   far for the current policy improvement iteration
   #
   # @yieldparam [Float] delta largest change in the value function in the last
   #   policy evaluation iteration
   #
-  def policy_iteration
-
+  def policy_iteration(value_tolerance:,
+      policy_tolerance: value_tolerance / 2.0, max_value_iters: nil,
+      max_policy_iters: nil)
 
-
+    num_actions_changed = nil
     num_policy_iters = 0
     loop do
       # policy evaluation
@@ -315,19 +328,19 @@ class FiniteMDP::Solver
       loop do
         value_delta = evaluate_policy
         num_value_iters += 1
+        yield(num_policy_iters, num_actions_changed, num_value_iters,
+          value_delta) if block_given?
 
         break if value_delta < value_tolerance
        break if max_value_iters && num_value_iters >= max_value_iters
-        yield num_policy_iters, num_value_iters, value_delta if block_given?
       end
 
       # policy improvement
-
+      num_actions_changed = improve_policy(tolerance: policy_tolerance)
       num_policy_iters += 1
-
-
+      return true if num_actions_changed == 0
+      return false if max_policy_iters && num_policy_iters >= max_policy_iters
     end
-    stable
   end
 
   #
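A sketch of the new calling convention; note that actions_changed is nil for the evaluation sweeps before the first improvement step, and the return value is true when the policy is stable:

    stable = solver.policy_iteration(value_tolerance: 1e-6) do |pi_n, changed, vi_n, delta|
      puts "improvement #{pi_n} (#{changed.inspect} changed): " \
           "eval #{vi_n}, delta #{delta}"
    end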
@@ -343,18 +356,19 @@ class FiniteMDP::Solver
   #
   # @yieldparam [Integer] num_iters policy improvement iterations done so far
   #
-
-
+  # @yieldparam [Integer] num_actions_changed number of actions that changed in
+  #   the last policy improvement phase
+  #
+  def policy_iteration_exact(max_iters: nil)
     num_iters = 0
     loop do
       evaluate_policy_exact
-
+      num_actions_changed = improve_policy
      num_iters += 1
-
-
-
+      yield num_iters, num_actions_changed if block_given?
+      return true if num_actions_changed == 0
+      return false if max_iters && num_iters >= max_iters
     end
-    stable
   end
 
   private
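The exact variant follows the same pattern but needs no value tolerance, since each policy is evaluated by solving the linear system directly:

    stable = solver.policy_iteration_exact(max_iters: 20) do |n, changed|
      puts "exact policy iteration #{n}: #{changed} actions changed"
    end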
@@ -362,30 +376,29 @@ class FiniteMDP::Solver
   #
   # Updated value estimate for a state with the given successor states.
   #
-  def backup
-    next_state_ns.map
-      probability*(reward + @discount
-
+  def backup(next_state_ns)
+    next_state_ns.map do |next_state_n, probability, reward|
+      probability * (reward + @discount * @array_value[next_state_n])
+    end.inject(:+)
   end
 
   #
   # Update the row in A the entry in b (in Ax=b) for the given state; see
   # {#evaluate_policy_exact}.
   #
-  def update_policy_Ab
+  def update_policy_Ab(state_n, action_n)
     # clear out the old values for state_n's row
     @policy_A[true, state_n] = 0.0
 
     # set new values according to state_n's successors under the current policy
     b_n = 0
-    next_state_ns =
+    next_state_ns = model.array[state_n][action_n]
     next_state_ns.each do |next_state_n, probability, reward|
-      @policy_A[next_state_n, state_n] = -@discount*probability
-      b_n += probability*reward
+      @policy_A[next_state_n, state_n] = -@discount * probability
+      b_n += probability * reward
     end
     @policy_A[state_n, state_n] += 1
     @policy_A_action[state_n] = action_n
     @policy_b[state_n] = b_n
   end
 end
-
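backup is the standard one-state Bellman backup: over the stored successor triples (next state, probability, reward) it computes

    V(s) \leftarrow \sum_{s'} p(s' \mid s, a)\,\bigl(r(s, a, s') + \gamma V(s')\bigr)

with \gamma the discount factor.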
data/lib/finite_mdp/table_model.rb
CHANGED
@@ -1,3 +1,4 @@
+# frozen_string_literal: true
 #
 # A finite markov decision process model for which the states, actions,
 # transition probabilities and rewards are specified as a table. This is a
@@ -12,7 +13,7 @@ class FiniteMDP::TableModel
   # @param [Array<[state, action, state, Float, Float]>] rows each row is
   #   [state, action, next state, probability, reward]
   #
-  def initialize
+  def initialize(rows)
     @rows = rows
   end
 
@@ -28,7 +29,7 @@ class FiniteMDP::TableModel
   # @return [Array<state>] not empty; no duplicate states
   #
   def states
-    @rows.map{|row| row[0]}.uniq
+    @rows.map { |row| row[0] }.uniq
   end
 
   #
@@ -38,23 +39,23 @@ class FiniteMDP::TableModel
   #
   # @return [Array<action>] not empty; no duplicate actions
   #
-  def actions
-    @rows.map{|row| row[1] if row[0] == state}.compact.uniq
+  def actions(state)
+    @rows.map { |row| row[1] if row[0] == state }.compact.uniq
   end
 
   #
   # Possible successor states after taking the given action in the given state;
   # see {Model#next_states}.
-  #
+  #
   # @param [state] state
   #
   # @param [action] action
   #
   # @return [Array<state>] not empty; no duplicate states
   #
-  def next_states
-    @rows.map{|row| row[2] if row[0] == state && row[1] == action}.compact
-  end
+  def next_states(state, action)
+    @rows.map { |row| row[2] if row[0] == state && row[1] == action }.compact
+  end
 
   #
   # Probability of the given transition; see {Model#transition_probability}.
@@ -66,10 +67,10 @@ class FiniteMDP::TableModel
   # @param [state] next_state
   #
   # @return [Float] in [0, 1]; zero if the transition is not in the table
-  #
-  def transition_probability
-
-
+  #
+  def transition_probability(state, action, next_state)
+    row = find_row(state, action, next_state)
+    row ? row[3] : 0
   end
 
   #
|
|
83
84
|
#
|
84
85
|
# @return [Float, nil] nil if the transition is not in the table
|
85
86
|
#
|
86
|
-
def reward
|
87
|
-
|
88
|
-
|
87
|
+
def reward(state, action, next_state)
|
88
|
+
row = find_row(state, action, next_state)
|
89
|
+
row[4] if row
|
89
90
|
end
|
90
91
|
|
91
92
|
#
|
@@ -105,18 +106,26 @@ class FiniteMDP::TableModel
|
|
105
106
|
#
|
106
107
|
# @return [TableModel]
|
107
108
|
#
|
108
|
-
def self.from_model
|
109
|
+
def self.from_model(model, sparse = true)
|
109
110
|
rows = []
|
110
111
|
model.states.each do |state|
|
111
112
|
model.actions(state).each do |action|
|
112
113
|
model.next_states(state, action).each do |next_state|
|
113
114
|
pr = model.transition_probability(state, action, next_state)
|
114
|
-
|
115
|
-
|
115
|
+
next unless pr > 0 || !sparse
|
116
|
+
reward = model.reward(state, action, next_state)
|
117
|
+
rows << [state, action, next_state, pr, reward]
|
116
118
|
end
|
117
119
|
end
|
118
120
|
end
|
119
121
|
FiniteMDP::TableModel.new(rows)
|
120
122
|
end
|
121
|
-
end
|
122
123
|
|
124
|
+
private
|
125
|
+
|
126
|
+
def find_row(state, action, next_state)
|
127
|
+
@rows.find do |row|
|
128
|
+
row[0] == state && row[1] == action && row[2] == next_state
|
129
|
+
end
|
130
|
+
end
|
131
|
+
end
|
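from_model gains a sparse flag: zero-probability transitions are skipped by default, and passing sparse = false keeps them. Reusing the model from the solver sketch:

    sparse_table = FiniteMDP::TableModel.from_model(model)
    dense_table  = FiniteMDP::TableModel.from_model(model, false)
    sparse_table.states          # => [:sunny, :rainy]
    sparse_table.actions(:sunny) # => [:stay]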
data/lib/finite_mdp/vector_valued.rb
CHANGED
@@ -1,3 +1,4 @@
+# frozen_string_literal: true
 #
 # Define an object's hash code and equality (in the sense of <tt>eql?</tt>)
 # according to its array representation (<tt>to_a</tt>). See notes for {Model}
@@ -7,7 +8,7 @@
 #
 # @example
 #
-#   class MyPoint
+#   class MyPoint
 #     include FiniteMDP::VectorValued
 #
 #     def initialize x, y
@@ -31,7 +32,7 @@ module FiniteMDP::VectorValued
   # @return [Integer]
   #
   def hash
-
+    to_a.hash
   end
 
   #
@@ -39,8 +40,7 @@ module FiniteMDP::VectorValued
   #
   # @return [Boolean]
   #
-  def eql?
-
+  def eql?(other)
+    to_a.eql? other.to_a
   end
 end
-
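A runnable sketch based on the module's own @example: identity is delegated to to_a, so two points with equal coordinates behave as the same hash key:

    class MyPoint
      include FiniteMDP::VectorValued

      def initialize(x, y)
        @x = x
        @y = y
      end

      attr_reader :x, :y

      # VectorValued defines hash and eql? in terms of this array.
      def to_a
        [x, y]
      end
    end

    MyPoint.new(0, 0).eql?(MyPoint.new(0, 0))     # => true
    { MyPoint.new(0, 0) => 1 }[MyPoint.new(0, 0)] # => 1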
data/lib/finite_mdp/version.rb
CHANGED
data/lib/finite_mdp.rb
CHANGED
@@ -1,14 +1,15 @@
+# frozen_string_literal: true
 require 'enumerator'
 
 require 'finite_mdp/version'
 require 'finite_mdp/vector_valued'
 require 'finite_mdp/model'
+require 'finite_mdp/array_model'
 require 'finite_mdp/hash_model'
 require 'finite_mdp/table_model'
 require 'finite_mdp/solver'
 
-# TODO maybe for efficiency it would be worth including a special case for
+# TODO: maybe for efficiency it would be worth including a special case for
 #   models in which rewards depend only on the state -- a few minor
 #   simplifications are possible in the solver, but it won't make a huge
 #   difference.
-