finite_mdp 0.0.1
- data/README.rdoc +229 -0
- data/lib/finite_mdp/hash_model.rb +123 -0
- data/lib/finite_mdp/model.rb +195 -0
- data/lib/finite_mdp/solver.rb +344 -0
- data/lib/finite_mdp/table_model.rb +122 -0
- data/lib/finite_mdp/vector_valued.rb +46 -0
- data/lib/finite_mdp/version.rb +3 -0
- data/lib/finite_mdp.rb +14 -0
- data/test/finite_mdp_test.rb +347 -0
- metadata +94 -0
data/lib/finite_mdp/solver.rb
ADDED
@@ -0,0 +1,344 @@
require 'narray'

#
# Find optimal values and policies using policy iteration and/or value
# iteration. The methods here are suitable for finding deterministic policies
# for infinite-horizon problems.
#
# The computations are carried out on an intermediate form of the given model,
# which is stored using nested arrays:
#
#   model[state_num][action_num] = [[next_state_num, probability, reward], ...]
#
# The solver assigns numbers to each state and each action automatically. Note
# that the successor state data are stored in sparse format: any transitions
# that are in the given model but have zero probability are not stored.
#
# TODO implement backward induction for finite horizon problems
#
# TODO maybe implement a 'dense' storage format for models with many successor
# states, probably as a different solver class
#
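# @example Constructing and solving a small model (an illustrative sketch;
#   the states, actions, probabilities and rewards here are made up)
#
#   model = FiniteMDP::TableModel.new([
#     [:sunny, :walk, :sunny, 0.9,  1],
#     [:sunny, :walk, :rainy, 0.1, -1],
#     [:rainy, :walk, :sunny, 0.5,  1],
#     [:rainy, :walk, :rainy, 0.5, -1]])
#
#   solver = FiniteMDP::Solver.new(model, 0.95)
#   solver.policy_iteration 1e-6
#   solver.value   # Hash from each state to its estimated optimal value
#   solver.policy  # Hash from each state to its estimated optimal action
#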
class FiniteMDP::Solver
  #
  # @param [Model] model
  #
  # @param [Float] discount in (0, 1]
  #
  # @param [Hash<state, action>, nil] policy initial policy; if nil, an
  #        arbitrary action is selected for each state
  #
  # @param [Hash<state, Float>] value initial value for each state; defaults
  #        to zero for every state
  #
  def initialize model, discount, policy=nil, value=Hash.new(0)
    @model = model
    @discount = discount

    # Get the model data into a more compact form for calculation: we number
    # the states and actions for faster lookups (avoiding most of the
    # hashing). The 'next states' data are still stored in sparse format, that
    # is, as arrays of [next_state_num, probability, reward] triples that omit
    # zero-probability transitions.
    model_states = model.states
    state_to_num = Hash[model_states.zip((0...model_states.size).to_a)]
    @array_model = model_states.map {|state|
      model.actions(state).map {|action|
        model.next_states(state, action).map {|next_state|
          pr = model.transition_probability(state, action, next_state)
          [state_to_num[next_state], pr,
            model.reward(state, action, next_state)] if pr > 0
        }.compact
      }
    }

    # convert initial values and policies to compact form
    @array_value = model_states.map {|state| value[state]}
    if policy
      action_to_num = model_states.map{|state|
        actions = model.actions(state)
        Hash[actions.zip((0...actions.size).to_a)]
      }
      @array_policy = action_to_num.zip(model_states).
        map {|a_to_n, state| a_to_n[policy[state]]}
    else
      # default to the first action, arbitrarily
      @array_policy = [0]*model_states.size
    end

    raise 'some initial values are missing' if
      @array_value.any? {|v| v.nil?}
    raise 'some initial policy actions are missing' if
      @array_policy.any? {|a| a.nil?}

    @policy_A = nil
  end

  #
  # @return [Model] the model being solved; read only; do not change the model
  #         while it is being solved
  #
  attr_reader :model

  #
  # Current value estimate for each state.
  #
  # The result is converted from the solver's internal representation, so you
  # cannot affect the solver by changing the result.
  #
  # @return [Hash<state, Float>] from states to values; read only; any changes
  #         made to the returned object will not affect the solver
  #
  def value
    Hash[model.states.zip(@array_value)]
  end

  #
  # Current estimate of the optimal action for each state.
  #
  # @return [Hash<state, action>] from states to actions; read only; any
  #         changes made to the returned object will not affect the solver
  #
  def policy
    Hash[model.states.zip(@array_policy).map{|state, action_n|
      [state, model.actions(state)[action_n]]}]
  end

  #
  # Refine the estimate of the value function for the current policy. This is
  # done by iterating the Bellman equations; see also {#evaluate_policy_exact}
  # for a different approach.
  #
  # This is the 'policy evaluation' step in Figure 4.3 of Sutton and Barto
  # (1998).
  #
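  # In symbols, the update applied to each state s is
  #
  #   V(s) <- sum over s' of P(s'|s,pi(s)) * (R(s,pi(s),s') + discount * V(s'))
  #
  # where pi is the current policy.
  #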
  # @return [Float] largest absolute change (over all states) in the value
  #         function
  #
  def evaluate_policy
    delta = 0.0
    @array_model.each_with_index do |actions, state_n|
      next_state_ns = actions[@array_policy[state_n]]
      new_value = backup(next_state_ns)
      delta = [delta, (@array_value[state_n] - new_value).abs].max
      @array_value[state_n] = new_value
    end
    delta
  end

  #
  # Evaluate the value function for the current policy by solving a linear
  # system of n equations in n unknowns, where n is the number of states in
  # the model.
  #
  # This routine currently uses dense linear algebra, so it requires that the
  # full n-by-n matrix be stored in memory. This may be a problem for
  # moderately large n.
  #
  # All of the coefficients (A and b in Ax = b) are computed on the first
  # call, but subsequent calls recompute only those rows for which the policy
  # has changed since the last call.
  #
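  # The system solved is (I - discount * P) v = r, where P is the transition
  # matrix and r is the vector of expected one-step rewards under the current
  # policy; its solution v is the exact value function for that policy.
  #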
  # @return [nil]
  #
  def evaluate_policy_exact
    if @policy_A
      # update only those rows for which the policy has changed
      @policy_A_action.zip(@array_policy).
        each_with_index do |(old_action_n, new_action_n), state_n|
        next if old_action_n == new_action_n
        update_policy_Ab state_n, new_action_n
      end
    else
      # initialise the A and the b for Ax = b
      num_states = @array_model.size
      @policy_A = NMatrix.float(num_states, num_states)
      @policy_A_action = [-1]*num_states
      @policy_b = NVector.float(num_states)

      @array_policy.each_with_index do |action_n, state_n|
        update_policy_Ab state_n, action_n
      end
    end

    value = @policy_b / @policy_A # solve linear system
    @array_value = value.to_a
    nil
  end

  #
  # Make the policy greedy with respect to the current value function.
  #
  # This is the 'policy improvement' step in Figure 4.3 of Sutton and Barto
  # (1998).
  #
  # @return [Boolean] true iff the policy is stable; that is, false iff the
  #         policy changed for any state
  #
  def improve_policy
    stable = true
    @array_model.each_with_index do |actions, state_n|
      a_max = nil
      v_max = -Float::MAX
      actions.each_with_index do |next_state_ns, action_n|
        v = backup(next_state_ns)
        if v > v_max
          a_max = action_n
          v_max = v
        end
      end
      raise "no feasible actions in state #{state_n}" unless a_max
      stable = false if @array_policy[state_n] != a_max
      @array_policy[state_n] = a_max
    end
    stable
  end

  #
  # A single iteration of value iteration.
  #
  # This is the algorithm from Figure 4.5 of Sutton and Barto (1998). It is
  # mostly equivalent to calling {#evaluate_policy} and then
  # {#improve_policy}, but it is somewhat more efficient.
  #
  # @return [Float] largest absolute change (over all states) in the value
  #         function
  #
  def value_iteration_single
    delta = 0.0
    @array_model.each_with_index do |actions, state_n|
      a_max = nil
      v_max = -Float::MAX
      actions.each_with_index do |next_state_ns, action_n|
        v = backup(next_state_ns)
        if v > v_max
          a_max = action_n
          v_max = v
        end
      end
      delta = [delta, (@array_value[state_n] - v_max).abs].max
      @array_value[state_n] = v_max
      @array_policy[state_n] = a_max
    end
    delta
  end

  #
  # Value iteration: call {#value_iteration_single} up to <tt>max_iters</tt>
  # times, stopping when the largest change in the value function
  # (<tt>delta</tt>) is less than <tt>tolerance</tt>.
  #
  # @param [Float] tolerance small positive number
  #
  # @param [Integer, nil] max_iters terminate after this many iterations, even
  #        if the value function has not converged; nil means that there is no
  #        limit on the number of iterations
  #
  # @return [Boolean] true iff iteration converged to within tolerance
  #
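  # @example (the tolerance and iteration limit here are illustrative)
  #   solver.value_iteration(1e-6, 100) #=> true if converged within 100 iterations
  #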
  def value_iteration tolerance, max_iters=nil
    delta = Float::MAX
    num_iters = 0
    loop do
      delta = value_iteration_single
      num_iters += 1

      break if delta < tolerance
      break if max_iters && num_iters >= max_iters
    end
    delta < tolerance
  end

  #
  # Solve with policy iteration using approximate (iterative) policy
  # evaluation.
  #
  # @param [Float] value_tolerance small positive number; the policy
  #        evaluation phase ends if the largest change in the value function
  #        (<tt>delta</tt>) is below this tolerance
  #
  # @param [Integer, nil] max_value_iters terminate the policy evaluation
  #        phase after this many iterations, even if the value function has
  #        not converged; nil means that there is no limit on the number of
  #        iterations in each policy evaluation phase
  #
  # @param [Integer, nil] max_policy_iters terminate after this many
  #        iterations, even if a stable policy has not been obtained; nil
  #        means that there is no limit on the number of iterations
  #
  # @return [Boolean] true iff a stable policy was obtained
  #
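  # @example (the tolerance here is illustrative)
  #   solver.policy_iteration(1e-6) #=> true if a stable policy was found
  #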
  def policy_iteration value_tolerance, max_value_iters=nil,
      max_policy_iters=nil

    stable = false
    num_policy_iters = 0
    loop do
      # policy evaluation
      num_value_iters = 0
      loop do
        value_delta = evaluate_policy
        num_value_iters += 1

        break if value_delta < value_tolerance
        break if max_value_iters && num_value_iters >= max_value_iters
      end

      # policy improvement
      stable = improve_policy
      num_policy_iters += 1
      break if stable
      break if max_policy_iters && num_policy_iters >= max_policy_iters
    end
    stable
  end

  #
  # Solve with policy iteration using exact policy evaluation.
  #
  # @param [Integer, nil] max_iters terminate after this many iterations, even
  #        if a stable policy has not been obtained; nil means that there is
  #        no limit on the number of iterations
  #
  # @return [Boolean] true iff a stable policy was obtained
  #
  def policy_iteration_exact max_iters=nil
    stable = false
    num_iters = 0
    loop do
      evaluate_policy_exact
      stable = improve_policy
      num_iters += 1
      break if stable
      break if max_iters && num_iters >= max_iters
    end
    stable
  end

  private

  #
  # Updated value estimate for a state with the given successor states.
  #
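  # In symbols, this computes
  #
  #   sum_i p_i * (r_i + discount * V(s_i))
  #
  # over the given [next_state_num, probability, reward] triples.
  #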
  def backup next_state_ns
    next_state_ns.map {|next_state_n, probability, reward|
      probability*(reward + @discount*@array_value[next_state_n])
    }.inject(:+)
  end

  #
  # Update the row in A and the entry in b (in Ax = b) for the given state;
  # see {#evaluate_policy_exact}.
  #
  def update_policy_Ab state_n, action_n
    # clear out the old values for state_n's row
    @policy_A[true, state_n] = 0.0

    # set new values according to state_n's successors under the current
    # policy
    b_n = 0
    next_state_ns = @array_model[state_n][action_n]
    next_state_ns.each do |next_state_n, probability, reward|
      @policy_A[next_state_n, state_n] = -@discount*probability
      b_n += probability*reward
    end
    @policy_A[state_n, state_n] += 1
    @policy_A_action[state_n] = action_n
    @policy_b[state_n] = b_n
  end
end

data/lib/finite_mdp/table_model.rb
ADDED
@@ -0,0 +1,122 @@
#
# A finite Markov decision process model for which the states, actions,
# transition probabilities and rewards are specified as a table. This is a
# common way of specifying small models.
#
# The states and actions can be arbitrary objects; see notes for {Model}.
#
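# @example A small two-state model (an illustrative sketch; the states,
#   action and numbers are made up)
#
#   model = FiniteMDP::TableModel.new([
#     [:up,   :flip, :up,   0.5,  1],
#     [:up,   :flip, :down, 0.5,  0],
#     [:down, :flip, :up,   0.5,  0],
#     [:down, :flip, :down, 0.5, -1]])
#
#   model.states                                     #=> [:up, :down]
#   model.actions(:up)                               #=> [:flip]
#   model.transition_probability(:up, :flip, :down)  #=> 0.5
#   model.reward(:up, :flip, :down)                  #=> 0
#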
class FiniteMDP::TableModel
  include FiniteMDP::Model

  #
  # @param [Array<[state, action, state, Float, Float]>] rows each row is
  #        [state, action, next state, probability, reward]
  #
  def initialize rows
    @rows = rows
  end

  #
  # @return [Array<[state, action, state, Float, Float]>] each row is [state,
  #         action, next state, probability, reward]
  #
  attr_accessor :rows

  #
  # States in this model; see {Model#states}.
  #
  # @return [Array<state>] not empty; no duplicate states
  #
  def states
    @rows.map{|row| row[0]}.uniq
  end

  #
  # Actions that are valid for the given state; see {Model#actions}.
  #
  # @param [state] state
  #
  # @return [Array<action>] not empty; no duplicate actions
  #
  def actions state
    @rows.map{|row| row[1] if row[0] == state}.compact.uniq
  end

  #
  # Possible successor states after taking the given action in the given
  # state; see {Model#next_states}.
  #
  # @param [state] state
  #
  # @param [action] action
  #
  # @return [Array<state>] not empty; no duplicate states
  #
  def next_states state, action
    @rows.map{|row| row[2] if row[0] == state && row[1] == action}.compact
  end

  #
  # Probability of the given transition; see {Model#transition_probability}.
  #
  # @param [state] state
  #
  # @param [action] action
  #
  # @param [state] next_state
  #
  # @return [Float] in [0, 1]; zero if the transition is not in the table
  #
  def transition_probability state, action, next_state
    @rows.map{|row| row[3] if row[0] == state &&
      row[1] == action && row[2] == next_state}.compact.first || 0
  end

  #
  # Reward for a given transition; see {Model#reward}.
  #
  # @param [state] state
  #
  # @param [action] action
  #
  # @param [state] next_state
  #
  # @return [Float, nil] nil if the transition is not in the table
  #
  def reward state, action, next_state
    @rows.map{|row| row[4] if row[0] == state &&
      row[1] == action && row[2] == next_state}.compact.first
  end

  #
  # @return [String] can be quite large
  #
  def inspect
    rows.map(&:inspect).join("\n")
  end

  #
  # Convert any model into a table model.
  #
  # @param [Model] model
  #
  # @param [Boolean] sparse if true, do not store rows for transitions with
  #        zero probability
  #
  # @return [TableModel]
  #
  def self.from_model model, sparse=true
    rows = []
    model.states.each do |state|
      model.actions(state).each do |action|
        model.next_states(state, action).each do |next_state|
          pr = model.transition_probability(state, action, next_state)
          rows << [state, action, next_state, pr,
            model.reward(state, action, next_state)] if pr > 0 || !sparse
        end
      end
    end
    FiniteMDP::TableModel.new(rows)
  end
end

data/lib/finite_mdp/vector_valued.rb
ADDED
@@ -0,0 +1,46 @@
#
# Define an object's hash code and equality (in the sense of <tt>eql?</tt>)
# according to its array representation (<tt>to_a</tt>). See notes for {Model}
# for why this might be useful.
#
# A class that includes this module must define <tt>to_a</tt>.
#
# @example
#
#   class MyPoint
#     include FiniteMDP::VectorValued
#
#     def initialize x, y
#       @x, @y = x, y
#     end
#
#     attr_accessor :x, :y
#
#     # must implement to_a to make VectorValued work
#     def to_a
#       [x, y]
#     end
#   end
#
#   MyPoint.new(0, 0).eql?(MyPoint.new(0, 0)) #=> true as expected
#
module FiniteMDP::VectorValued
  #
  # Redefine hashing based on <tt>to_a</tt>.
  #
  # @return [Integer]
  #
  def hash
    to_a.hash
  end

  #
  # Redefine equality based on <tt>to_a</tt>.
  #
  # @return [Boolean]
  #
  def eql? state
    to_a.eql? state.to_a
  end
end

data/lib/finite_mdp.rb
ADDED
@@ -0,0 +1,14 @@
require 'enumerator'

require 'finite_mdp/version'
require 'finite_mdp/vector_valued'
require 'finite_mdp/model'
require 'finite_mdp/hash_model'
require 'finite_mdp/table_model'
require 'finite_mdp/solver'

# TODO maybe for efficiency it would be worth including a special case for
# models in which rewards depend only on the state -- a few minor
# simplifications are possible in the solver, but it won't make a huge
# difference.