finite_mdp 0.0.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,347 @@
1
+ # uncomment for coverage in ruby 1.9
2
+ #require 'simplecov'
3
+ #SimpleCov.start
4
+
5
+ require 'test/unit'
6
+ require 'finite_mdp'
7
+ require 'set'
8
+
9
+ class TestFiniteMDP < Test::Unit::TestCase
10
+ include FiniteMDP
11
+
12
+ # check that we get the same model back; model parameters must be set before
13
+ # calling; see test_recycling_robot
14
+ def check_recycling_robot_model model, sparse
15
+ model.check_transition_probabilities_sum
16
+
17
+ assert_equal Set[:high, :low], Set[*model.states]
18
+ assert_equal Set[:search, :wait], Set[*model.actions(:high)]
19
+ assert_equal Set[:search, :wait, :recharge], Set[*model.actions(:low)]
20
+
21
+ if sparse
22
+ assert_equal [:low], model.next_states(:low, :wait)
23
+ assert_equal [:high], model.next_states(:low, :recharge)
24
+ assert_equal [:high], model.next_states(:high, :wait)
25
+ else
26
+ assert_equal Set[:high, :low], Set[*model.next_states(:low, :wait)]
27
+ assert_equal Set[:high, :low], Set[*model.next_states(:low, :recharge)]
28
+ assert_equal Set[:high, :low], Set[*model.next_states(:high, :wait)]
29
+ end
30
+ assert_equal Set[:high, :low], Set[*model.next_states(:low, :search)]
31
+ assert_equal Set[:high, :low], Set[*model.next_states(:high, :search)]
32
+
33
+ assert_equal 1-@beta, model.transition_probability(:low, :search, :high)
34
+ assert_equal @beta, model.transition_probability(:low, :search, :low)
35
+ assert_equal 0, model.transition_probability(:low, :wait, :high)
36
+ assert_equal 1, model.transition_probability(:low, :wait, :low)
37
+ assert_equal 1, model.transition_probability(:low, :recharge, :high)
38
+ assert_equal 0, model.transition_probability(:low, :recharge, :low)
39
+
40
+ assert_equal @alpha, model.transition_probability(:high, :search, :high)
41
+ assert_equal 1-@alpha, model.transition_probability(:high, :search, :low)
42
+ assert_equal 1, model.transition_probability(:high, :wait, :high)
43
+ assert_equal 0, model.transition_probability(:high, :wait, :low)
44
+
45
+ assert_equal @r_rescue, model.reward(:low, :search, :high)
46
+ assert_equal @r_search, model.reward(:low, :search, :low)
47
+ assert_equal @r_wait, model.reward(:low, :wait, :low)
48
+ assert_equal 0, model.reward(:low, :recharge, :high)
49
+
50
+ assert_equal @r_search, model.reward(:high, :search, :high)
51
+ assert_equal @r_search, model.reward(:high, :search, :low)
52
+ assert_equal @r_wait, model.reward(:high, :wait, :high)
53
+
54
+ if sparse
55
+ assert_equal nil, model.reward(:low, :wait, :high)
56
+ assert_equal nil, model.reward(:low, :recharge, :low)
57
+ assert_equal nil, model.reward(:high, :wait, :low)
58
+ else
59
+ assert_equal @r_wait, model.reward(:low, :wait, :high)
60
+ assert_equal 0, model.reward(:low, :recharge, :low)
61
+ assert_equal @r_wait, model.reward(:high, :wait, :low)
62
+ end
63
+ end
64
+
65
+ #
66
+ # Example 3.7 from Sutton and Barto (1998).
67
+ #
68
+ def test_recycling_robot
69
+ @alpha = 0.1
70
+ @beta = 0.1
71
+ @r_search = 2
72
+ @r_wait = 1
73
+ @r_rescue = -3
74
+
75
+ table_model = TableModel.new [
76
+ [:high, :search, :high, @alpha, @r_search],
77
+ [:high, :search, :low, 1-@alpha, @r_search],
78
+ [:low, :search, :high, 1-@beta, @r_rescue],
79
+ [:low, :search, :low, @beta, @r_search],
80
+ [:high, :wait, :high, 1, @r_wait],
81
+ [:high, :wait, :low, 0, @r_wait],
82
+ [:low, :wait, :high, 0, @r_wait],
83
+ [:low, :wait, :low, 1, @r_wait],
84
+ [:low, :recharge, :high, 1, 0],
85
+ [:low, :recharge, :low, 0, 0]]
86
+
87
+ assert_equal 10, table_model.rows.size
88
+
89
+ # check round trips for different model formats; don't sparsify yet
90
+ check_recycling_robot_model table_model, false
91
+ check_recycling_robot_model TableModel.from_model(table_model, false), false
92
+
93
+ hash_model = HashModel.from_model(table_model, false)
94
+ check_recycling_robot_model hash_model, false
95
+ check_recycling_robot_model TableModel.from_model(hash_model, false), false
96
+
97
+ # if we sparsify, we should lose some rows
98
+ sparse_table_model = TableModel.from_model(table_model)
99
+ assert_equal 7, sparse_table_model.rows.size
100
+ check_recycling_robot_model sparse_table_model, true
101
+
102
+ sparse_hash_model = HashModel.from_model(table_model)
103
+ check_recycling_robot_model sparse_hash_model, true
104
+
105
+ # once they're gone, they don't come back
106
+ sparse_hash_model = HashModel.from_model(sparse_table_model, false)
107
+ check_recycling_robot_model sparse_hash_model, true
108
+
109
+ # try solving with value iteration
110
+ solver = Solver.new(table_model, 0.95, Hash.new {:wait})
111
+ assert solver.value_iteration(1e-4, 200), "did not converge"
112
+ assert_equal({:high => :search, :low => :recharge}, solver.policy)
113
+
114
+ # try solving with policy iteration using iterative policy evaluation
115
+ solver = Solver.new(table_model, 0.95, Hash.new {:wait})
116
+ assert solver.policy_iteration(1e-4, 2, 20), "did not find stable policy"
117
+ assert_equal({:high => :search, :low => :recharge}, solver.policy)
118
+
119
+ # try solving with policy iteration using exact policy evaluation
120
+ solver = Solver.new(table_model, 0.95, Hash.new {:wait})
121
+ assert solver.policy_iteration_exact(20), "did not find stable policy"
122
+ assert_equal({:high => :search, :low => :recharge}, solver.policy)
123
+ end
124
+
125
+ #
126
+ # An example model for testing; taken from Russell and Norvig (2003), Artificial
127
+ # Intelligence: A Modern Approach, Chapter 17.
128
+ #
129
+ # See http://aima.cs.berkeley.edu/python/mdp.html for a Python implementation.
130
+ #
131
+ class AIMAGridModel
132
+ include FiniteMDP::Model
133
+
134
+ #
135
+ # @param [Array<Array<Float, nil>>] grid rewards at each point, or nil if a
136
+ # grid square is an obstacle
137
+ #
138
+ # @param [Array<[i, j]>] terminii coordinates of the terminal states
139
+ #
140
+ def initialize grid, terminii
141
+ @grid, @terminii = grid, terminii
142
+ end
143
+
144
+ attr_reader :grid, :terminii
145
+
146
+ # every position on the grid is a state, except for obstacles, which are
147
+ # indicated by a nil in the grid
148
+ def states
149
+ is, js = (0...grid.size).to_a, (0...grid.first.size).to_a
150
+ is.product(js).select {|i, j| grid[i][j]} + [:stop]
151
+ end
152
+
153
+ # can move north, east, south or west on the grid
154
+ MOVES = {
155
+ '^' => [-1, 0],
156
+ '>' => [ 0, 1],
157
+ 'v' => [ 1, 0],
158
+ '<' => [ 0, -1]}
159
+
160
+ # agent can move north, south, east or west, unless it's in a terminal
160
+ # state or the :stop state, where :stop is the only action; if it tries to
161
+ # move off the grid or into an obstacle, it stays where it is
163
+ def actions state
164
+ if state == :stop || terminii.member?(state)
165
+ [:stop]
166
+ else
167
+ MOVES.keys
168
+ end
169
+ end
170
+
171
+ # define the transition model
172
+ def transition_probability state, action, next_state
173
+ if state == :stop || terminii.member?(state)
174
+ (action == :stop && next_state == :stop) ? 1 : 0
175
+ else
176
+ # agent usually succeeds in moving forward, but sometimes it ends up
177
+ # moving left or right
178
+ move = case action
179
+ when '^' then [['^', 0.8], ['<', 0.1], ['>', 0.1]]
180
+ when '>' then [['>', 0.8], ['^', 0.1], ['v', 0.1]]
181
+ when 'v' then [['v', 0.8], ['<', 0.1], ['>', 0.1]]
182
+ when '<' then [['<', 0.8], ['^', 0.1], ['v', 0.1]]
183
+ end
184
+ move.map {|m, pr|
185
+ m_state = [state[0] + MOVES[m][0], state[1] + MOVES[m][1]]
186
+ m_state = state unless states.member?(m_state) # stay in bounds
187
+ pr if m_state == next_state
188
+ }.compact.inject(:+) || 0
189
+ end
190
+ end
191
+
192
+ # reward is given by the grid cells; zero reward for the :stop state
193
+ def reward state, action, next_state
194
+ state == :stop ? 0 : grid[state[0]][state[1]]
195
+ end
196
+
197
+ # helper for functions below
198
+ def hash_to_grid hash
199
+ 0.upto(grid.size-1).map{|i| 0.upto(grid[i].size-1).map{|j| hash[[i,j]]}}
200
+ end
201
+
202
+ # format the values as a grid, returning one string per row
203
+ def pretty_value value
204
+ hash_to_grid(Hash[value.map {|s, v| [s, "%+.3f" % v]}]).map{|row|
205
+ row.map{|cell| cell || ' '}.join(' ')}
206
+ end
207
+
208
+ # format the policy as a grid of ASCII arrows, returning one string per row
209
+ def pretty_policy policy
210
+ hash_to_grid(policy).map{|row| row.map{|cell|
211
+ (cell.nil? || cell == :stop) ? ' ' : cell}.join(' ')}
212
+ end
213
+ end
214
+
215
+ def check_grid_solutions model, pretty_policy
216
+ # solve with policy iteration (approximate policy evaluation)
217
+ solver = Solver.new(model, 1)
218
+ assert solver.policy_iteration(1e-5, 10, 50), "did not converge"
219
+ assert_equal pretty_policy, model.pretty_policy(solver.policy)
220
+
221
+ # solve with policy iteration (exact policy evaluation)
222
+ solver = Solver.new(model, 0.9999) # discount 1 gives singular matrix
223
+ assert solver.policy_iteration_exact(20), "did not converge"
224
+ assert_equal pretty_policy, model.pretty_policy(solver.policy)
225
+
226
+ # solve with value iteration
227
+ solver = Solver.new(model, 1)
228
+ assert solver.value_iteration(1e-5, 100), "did not converge"
229
+ assert_equal pretty_policy, model.pretty_policy(solver.policy)
230
+
231
+ solver
232
+ end
233
+
234
+ def test_aima_grid_1
235
+ # the grid from Figures 17.1, 17.2(a) and 17.3
236
+ model = AIMAGridModel.new(
237
+ [[-0.04, -0.04, -0.04, +1],
238
+ [-0.04, nil, -0.04, -1],
239
+ [-0.04, -0.04, -0.04, -0.04]],
240
+ [[0, 3], [1, 3]]) # terminals (the +1 and -1 states)
241
+ model.check_transition_probabilities_sum
242
+
243
+ assert_equal Set[
244
+ [0, 0], [0, 1], [0, 2], [0, 3],
245
+ [1, 0], [1, 2], [1, 3],
246
+ [2, 0], [2, 1], [2, 2], [2, 3], :stop], Set[*model.states]
247
+
248
+ assert_equal Set[*%w(^ > v <)], Set[*model.actions([0, 0])]
249
+ assert_equal [:stop], model.actions([1, 3])
250
+ assert_equal [:stop], model.actions(:stop)
251
+
252
+ # check policy against Figure 17.2(a)
253
+ solver = check_grid_solutions model,
254
+ ["> > > ",
255
+ "^ ^ ",
256
+ "^ < < <"]
257
+
258
+ # check the actual (non-pretty) policy
259
+ assert_equal [
260
+ ['>', '>', '>', :stop],
261
+ ['^', nil, '^', :stop],
262
+ ['^', '<', '<', '<']], model.hash_to_grid(solver.policy)
263
+
264
+ # check values against Figure 17.3
265
+ assert [[0.812, 0.868, 0.918, 1],
266
+ [0.762, nil, 0.660, -1],
267
+ [0.705, 0.655, 0.611, 0.388]].flatten.
268
+ zip(model.hash_to_grid(solver.value).flatten).
269
+ all? {|x,y| (x.nil? && y.nil?) || (x-y).abs < 5e-4}
270
+ end
271
+
272
+ def test_aima_grid_2
273
+ # a grid from Figure 17.2(b)
274
+ r = -1.7
275
+ model = AIMAGridModel.new(
276
+ [[ r, r, r, +1],
277
+ [ r, nil, r, -1],
278
+ [ r, r, r, r]],
279
+ [[0, 3], [1, 3]]) # terminals (the +1 and -1 states)
280
+ model.check_transition_probabilities_sum
281
+
282
+ check_grid_solutions model,
283
+ ["> > > ",
284
+ "^ > ",
285
+ "> > > ^"]
286
+ end
287
+
288
+ def test_aima_grid_3
289
+ # a grid from Figure 17.2(b)
290
+ r = -0.3
291
+ model = AIMAGridModel.new(
292
+ [[ r, r, r, +1],
293
+ [ r, nil, r, -1],
294
+ [ r, r, r, r]],
295
+ [[0, 3], [1, 3]]) # terminals (the +1 and -1 states)
296
+ model.check_transition_probabilities_sum
297
+
298
+ check_grid_solutions model,
299
+ ["> > > ",
300
+ "^ ^ ",
301
+ "^ > ^ <"]
302
+ end
303
+
304
+ def test_aima_grid_4
305
+ # a grid from Figure 17.2(b)
306
+ r = -0.01
307
+ model = AIMAGridModel.new(
308
+ [[ r, r, r, +1],
309
+ [ r, nil, r, -1],
310
+ [ r, r, r, r]],
311
+ [[0, 3], [1, 3]]) # terminals (the +1 and -1 states)
312
+ model.check_transition_probabilities_sum
313
+
314
+ check_grid_solutions model,
315
+ ["> > > ",
316
+ "^ < ",
317
+ "^ < < v"]
318
+ end
319
+
320
+ class MyPoint
321
+ include FiniteMDP::VectorValued
322
+
323
+ def initialize x, y
324
+ @x, @y = x, y
325
+ end
326
+
327
+ attr_accessor :x, :y
328
+
329
+ # must implement to_a to make VectorValued work
330
+ def to_a
331
+ [x, y]
332
+ end
333
+ end
334
+
335
+ def test_vector_valued
336
+ p1 = MyPoint.new(0, 0)
337
+ p2 = MyPoint.new(0, 1)
338
+ p3 = MyPoint.new(0, 0)
339
+
340
+ assert !p1.eql?(p2)
341
+ assert !p3.eql?(p2)
342
+ assert p1.eql?(p1)
343
+ assert p1.eql?(p3)
344
+ assert_equal p1.hash, p3.hash
345
+ end
346
+ end
347
+
metadata ADDED
@@ -0,0 +1,94 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: finite_mdp
3
+ version: !ruby/object:Gem::Version
4
+ prerelease:
5
+ version: 0.0.1
6
+ platform: ruby
7
+ authors:
8
+ - John Lees-Miller
9
+ autorequire:
10
+ bindir: bin
11
+ cert_chain: []
12
+
13
+ date: 2011-04-17 00:00:00 Z
14
+ dependencies:
15
+ - !ruby/object:Gem::Dependency
16
+ name: narray
17
+ prerelease: false
18
+ requirement: &id001 !ruby/object:Gem::Requirement
19
+ none: false
20
+ requirements:
21
+ - - ">="
22
+ - !ruby/object:Gem::Version
23
+ version: 0.5.9
24
+ - - ~>
25
+ - !ruby/object:Gem::Version
26
+ version: "0"
27
+ type: :runtime
28
+ version_requirements: *id001
29
+ - !ruby/object:Gem::Dependency
30
+ name: gemma
31
+ prerelease: false
32
+ requirement: &id002 !ruby/object:Gem::Requirement
33
+ none: false
34
+ requirements:
35
+ - - ">="
36
+ - !ruby/object:Gem::Version
37
+ version: 1.0.1
38
+ - - ~>
39
+ - !ruby/object:Gem::Version
40
+ version: "1.0"
41
+ type: :development
42
+ version_requirements: *id002
43
+ description: Solve small finite Markov Decision Process models.
44
+ email:
45
+ - jdleesmiller@gmail.com
46
+ executables: []
47
+
48
+ extensions: []
49
+
50
+ extra_rdoc_files:
51
+ - README.rdoc
52
+ files:
53
+ - lib/finite_mdp/hash_model.rb
54
+ - lib/finite_mdp/vector_valued.rb
55
+ - lib/finite_mdp/model.rb
56
+ - lib/finite_mdp/version.rb
57
+ - lib/finite_mdp/solver.rb
58
+ - lib/finite_mdp/table_model.rb
59
+ - lib/finite_mdp.rb
60
+ - README.rdoc
61
+ - test/finite_mdp_test.rb
62
+ homepage: http://github.com/jdleesmiller/finite_mdp
63
+ licenses: []
64
+
65
+ post_install_message:
66
+ rdoc_options:
67
+ - --main
68
+ - README.rdoc
69
+ - --title
70
+ - finite_mdp-0.0.1 Documentation
71
+ require_paths:
72
+ - lib
73
+ required_ruby_version: !ruby/object:Gem::Requirement
74
+ none: false
75
+ requirements:
76
+ - - ">="
77
+ - !ruby/object:Gem::Version
78
+ version: "0"
79
+ required_rubygems_version: !ruby/object:Gem::Requirement
80
+ none: false
81
+ requirements:
82
+ - - ">="
83
+ - !ruby/object:Gem::Version
84
+ version: "0"
85
+ requirements: []
86
+
87
+ rubyforge_project: finite_mdp
88
+ rubygems_version: 1.7.2
89
+ signing_key:
90
+ specification_version: 3
91
+ summary: Solve small finite Markov Decision Process models.
92
+ test_files:
93
+ - test/finite_mdp_test.rb
94
+ has_rdoc: