finite_mdp 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,347 @@
1
+ # uncomment for coverage in ruby 1.9
2
+ #require 'simplecov'
3
+ #SimpleCov.start
4
+
5
+ require 'test/unit'
6
+ require 'finite_mdp'
7
+ require 'set'
8
+
9
+ class TestFiniteMDP < Test::Unit::TestCase
10
+ include FiniteMDP
11
+
12
+ # check that we get the same model back; model parameters must be set before
13
+ # calling; see test_recycling_robot
14
+ def check_recycling_robot_model model, sparse
15
+ model.check_transition_probabilities_sum
16
+
17
+ assert_equal Set[:high, :low], Set[*model.states]
18
+ assert_equal Set[:search, :wait], Set[*model.actions(:high)]
19
+ assert_equal Set[:search, :wait, :recharge], Set[*model.actions(:low)]
20
+
21
+ if sparse
22
+ assert_equal [:low], model.next_states(:low, :wait)
23
+ assert_equal [:high], model.next_states(:low, :recharge)
24
+ assert_equal [:high], model.next_states(:high, :wait)
25
+ else
26
+ assert_equal Set[:high, :low], Set[*model.next_states(:low, :wait)]
27
+ assert_equal Set[:high, :low], Set[*model.next_states(:low, :recharge)]
28
+ assert_equal Set[:high, :low], Set[*model.next_states(:high, :wait)]
29
+ end
30
+ assert_equal Set[:high, :low], Set[*model.next_states(:low, :search)]
31
+ assert_equal Set[:high, :low], Set[*model.next_states(:high, :search)]
32
+
33
+ assert_equal 1-@beta, model.transition_probability(:low, :search, :high)
34
+ assert_equal @beta, model.transition_probability(:low, :search, :low)
35
+ assert_equal 0, model.transition_probability(:low, :wait, :high)
36
+ assert_equal 1, model.transition_probability(:low, :wait, :low)
37
+ assert_equal 1, model.transition_probability(:low, :recharge, :high)
38
+ assert_equal 0, model.transition_probability(:low, :recharge, :low)
39
+
40
+ assert_equal @alpha, model.transition_probability(:high, :search, :high)
41
+ assert_equal 1-@alpha, model.transition_probability(:high, :search, :low)
42
+ assert_equal 1, model.transition_probability(:high, :wait, :high)
43
+ assert_equal 0, model.transition_probability(:high, :wait, :low)
44
+
45
+ assert_equal @r_rescue, model.reward(:low, :search, :high)
46
+ assert_equal @r_search, model.reward(:low, :search, :low)
47
+ assert_equal @r_wait, model.reward(:low, :wait, :low)
48
+ assert_equal 0, model.reward(:low, :recharge, :high)
49
+
50
+ assert_equal @r_search, model.reward(:high, :search, :high)
51
+ assert_equal @r_search, model.reward(:high, :search, :low)
52
+ assert_equal @r_wait, model.reward(:high, :wait, :high)
53
+
54
+ if sparse
55
+ assert_equal nil, model.reward(:low, :wait, :high)
56
+ assert_equal nil, model.reward(:low, :recharge, :low)
57
+ assert_equal nil, model.reward(:high, :wait, :low)
58
+ else
59
+ assert_equal @r_wait, model.reward(:low, :wait, :high)
60
+ assert_equal 0, model.reward(:low, :recharge, :low)
61
+ assert_equal @r_wait, model.reward(:high, :wait, :low)
62
+ end
63
+ end
64
+
65
+ #
66
+ # Example 3.7 from Sutton and Barto (1998).
67
+ #
68
+ def test_recycling_robot
69
+ @alpha = 0.1
70
+ @beta = 0.1
71
+ @r_search = 2
72
+ @r_wait = 1
73
+ @r_rescue = -3
74
+
75
+ table_model = TableModel.new [
76
+ [:high, :search, :high, @alpha, @r_search],
77
+ [:high, :search, :low, 1-@alpha, @r_search],
78
+ [:low, :search, :high, 1-@beta, @r_rescue],
79
+ [:low, :search, :low, @beta, @r_search],
80
+ [:high, :wait, :high, 1, @r_wait],
81
+ [:high, :wait, :low, 0, @r_wait],
82
+ [:low, :wait, :high, 0, @r_wait],
83
+ [:low, :wait, :low, 1, @r_wait],
84
+ [:low, :recharge, :high, 1, 0],
85
+ [:low, :recharge, :low, 0, 0]]
86
+
87
+ assert_equal 10, table_model.rows.size
88
+
89
+ # check round trips for different model formats; don't sparsify yet
90
+ check_recycling_robot_model table_model, false
91
+ check_recycling_robot_model TableModel.from_model(table_model, false), false
92
+
93
+ hash_model = HashModel.from_model(table_model, false)
94
+ check_recycling_robot_model hash_model, false
95
+ check_recycling_robot_model TableModel.from_model(hash_model, false), false
96
+
97
+ # if we sparsify, we should lose some rows
98
+ sparse_table_model = TableModel.from_model(table_model)
99
+ assert_equal 7, sparse_table_model.rows.size
100
+ check_recycling_robot_model sparse_table_model, true
101
+
102
+ sparse_hash_model = HashModel.from_model(table_model)
103
+ check_recycling_robot_model sparse_hash_model, true
104
+
105
+ # once they're gone, they don't come back
106
+ sparse_hash_model = HashModel.from_model(sparse_table_model, false)
107
+ check_recycling_robot_model sparse_hash_model, true
108
+
109
+ # try solving with value iteration
110
+ solver = Solver.new(table_model, 0.95, Hash.new {:wait})
111
+ assert solver.value_iteration(1e-4, 200), "did not converge"
112
+ assert_equal({:high => :search, :low => :recharge}, solver.policy)
113
+
114
+ # try solving with policy iteration using iterative policy evaluation
115
+ solver = Solver.new(table_model, 0.95, Hash.new {:wait})
116
+ assert solver.policy_iteration(1e-4, 2, 20), "did not find stable policy"
117
+ assert_equal({:high => :search, :low => :recharge}, solver.policy)
118
+
119
+ # try solving with policy iteration using exact policy evaluation
120
+ solver = Solver.new(table_model, 0.95, Hash.new {:wait})
121
+ assert solver.policy_iteration_exact(20), "did not find stable policy"
122
+ assert_equal({:high => :search, :low => :recharge}, solver.policy)
123
+ end
124
+
125
+ #
126
+ # An example model for testing; taken from Russell, Norvig (2003). Artificial
127
+ # Intelligence: A Modern Approach, Chapter 17.
128
+ #
129
+ # See http://aima.cs.berkeley.edu/python/mdp.html for a Python implementation.
130
+ #
131
+ class AIMAGridModel
132
+ include FiniteMDP::Model
133
+
134
+ #
135
+ # @param [Array<Array<Float, nil>>] grid rewards at each point, or nil if a
136
+ # grid square is an obstacle
137
+ #
138
+ # @param [Array<[i, j]>] terminii coordinates of the terminal states
139
+ #
140
+ def initialize grid, terminii
141
+ @grid, @terminii = grid, terminii
142
+ end
143
+
144
+ attr_reader :grid, :terminii
145
+
146
+ # every position on the grid is a state, except for obstacles, which are
147
+ # indicated by a nil in the grid
148
+ def states
149
+ is, js = (0...grid.size).to_a, (0...grid.first.size).to_a
150
+ is.product(js).select {|i, j| grid[i][j]} + [:stop]
151
+ end
152
+
153
+ # can move north, east, south or west on the grid
154
+ MOVES = {
155
+ '^' => [-1, 0],
156
+ '>' => [ 0, 1],
157
+ 'v' => [ 1, 0],
158
+ '<' => [ 0, -1]}
159
+
160
+ # agent can move north, south, east or west (unless it's in the :stop
161
+ # state); if it tries to move off the grid or into an obstacle, it stays
162
+ # where it is
163
+ def actions state
164
+ if state == :stop || terminii.member?(state)
165
+ [:stop]
166
+ else
167
+ MOVES.keys
168
+ end
169
+ end
170
+
171
+ # define the transition model
172
+ def transition_probability state, action, next_state
173
+ if state == :stop || terminii.member?(state)
174
+ (action == :stop && next_state == :stop) ? 1 : 0
175
+ else
176
+ # agent usually succeeds in moving forward, but sometimes it ends up
177
+ # moving left or right
178
+ move = case action
179
+ when '^' then [['^', 0.8], ['<', 0.1], ['>', 0.1]]
180
+ when '>' then [['>', 0.8], ['^', 0.1], ['v', 0.1]]
181
+ when 'v' then [['v', 0.8], ['<', 0.1], ['>', 0.1]]
182
+ when '<' then [['<', 0.8], ['^', 0.1], ['v', 0.1]]
183
+ end
184
+ move.map {|m, pr|
185
+ m_state = [state[0] + MOVES[m][0], state[1] + MOVES[m][1]]
186
+ m_state = state unless states.member?(m_state) # stay in bounds
187
+ pr if m_state == next_state
188
+ }.compact.inject(:+) || 0
189
+ end
190
+ end
191
+
192
+ # reward is given by the grid cells; zero reward for the :stop state
193
+ def reward state, action, next_state
194
+ state == :stop ? 0 : grid[state[0]][state[1]]
195
+ end
196
+
197
+ # helper for functions below
198
+ def hash_to_grid hash
199
+ 0.upto(grid.size-1).map{|i| 0.upto(grid[i].size-1).map{|j| hash[[i,j]]}}
200
+ end
201
+
202
+ # print the values in a grid
203
+ def pretty_value value
204
+ hash_to_grid(Hash[value.map {|s, v| [s, "%+.3f" % v]}]).map{|row|
205
+ row.map{|cell| cell || ' '}.join(' ')}
206
+ end
207
+
208
+ # print the policy using ASCII arrows
209
+ def pretty_policy policy
210
+ hash_to_grid(policy).map{|row| row.map{|cell|
211
+ (cell.nil? || cell == :stop) ? ' ' : cell}.join(' ')}
212
+ end
213
+ end
214
+
215
+ def check_grid_solutions model, pretty_policy
216
+ # solve with policy iteration (approximate policy evaluation)
217
+ solver = Solver.new(model, 1)
218
+ assert solver.policy_iteration(1e-5, 10, 50), "did not converge"
219
+ assert_equal pretty_policy, model.pretty_policy(solver.policy)
220
+
221
+ # solve with policy iteration (exact policy evaluation)
222
+ solver = Solver.new(model, 0.9999) # discount 1 gives singular matrix
223
+ assert solver.policy_iteration_exact(20), "did not converge"
224
+ assert_equal pretty_policy, model.pretty_policy(solver.policy)
225
+
226
+ # solve with value iteration
227
+ solver = Solver.new(model, 1)
228
+ assert solver.value_iteration(1e-5, 100), "did not converge"
229
+ assert_equal pretty_policy, model.pretty_policy(solver.policy)
230
+
231
+ solver
232
+ end
233
+
234
+ def test_aima_grid_1
235
+ # the grid from Figures 17.1, 17.2(a) and 17.3
236
+ model = AIMAGridModel.new(
237
+ [[-0.04, -0.04, -0.04, +1],
238
+ [-0.04, nil, -0.04, -1],
239
+ [-0.04, -0.04, -0.04, -0.04]],
240
+ [[0, 3], [1, 3]]) # terminals (the +1 and -1 states)
241
+ model.check_transition_probabilities_sum
242
+
243
+ assert_equal Set[
244
+ [0, 0], [0, 1], [0, 2], [0, 3],
245
+ [1, 0], [1, 2], [1, 3],
246
+ [2, 0], [2, 1], [2, 2], [2, 3], :stop], Set[*model.states]
247
+
248
+ assert_equal Set[*%w(^ > v <)], Set[*model.actions([0, 0])]
249
+ assert_equal [:stop], model.actions([1, 3])
250
+ assert_equal [:stop], model.actions(:stop)
251
+
252
+ # check policy against Figure 17.2(a)
253
+ solver = check_grid_solutions model,
254
+ ["> > > ",
255
+ "^ ^ ",
256
+ "^ < < <"]
257
+
258
+ # check the actual (non-pretty) policy
259
+ assert_equal [
260
+ ['>', '>', '>', :stop],
261
+ ['^', nil, '^', :stop],
262
+ ['^', '<', '<', '<']], model.hash_to_grid(solver.policy)
263
+
264
+ # check values against Figure 17.3
265
+ assert [[0.812, 0.868, 0.918, 1],
266
+ [0.762, nil, 0.660, -1],
267
+ [0.705, 0.655, 0.611, 0.388]].flatten.
268
+ zip(model.hash_to_grid(solver.value).flatten).
269
+ all? {|x,y| (x.nil? && y.nil?) || (x-y).abs < 5e-4}
270
+ end
271
+
272
+ def test_aima_grid_2
273
+ # a grid from Figure 17.2(b)
274
+ r = -1.7
275
+ model = AIMAGridModel.new(
276
+ [[ r, r, r, +1],
277
+ [ r, nil, r, -1],
278
+ [ r, r, r, r]],
279
+ [[0, 3], [1, 3]]) # terminals (the +1 and -1 states)
280
+ model.check_transition_probabilities_sum
281
+
282
+ check_grid_solutions model,
283
+ ["> > > ",
284
+ "^ > ",
285
+ "> > > ^"]
286
+ end
287
+
288
+ def test_aima_grid_3
289
+ # a grid from Figure 17.2(b)
290
+ r = -0.3
291
+ model = AIMAGridModel.new(
292
+ [[ r, r, r, +1],
293
+ [ r, nil, r, -1],
294
+ [ r, r, r, r]],
295
+ [[0, 3], [1, 3]]) # terminals (the +1 and -1 states)
296
+ model.check_transition_probabilities_sum
297
+
298
+ check_grid_solutions model,
299
+ ["> > > ",
300
+ "^ ^ ",
301
+ "^ > ^ <"]
302
+ end
303
+
304
+ def test_aima_grid_4
305
+ # a grid from Figure 17.2(b)
306
+ r = -0.01
307
+ model = AIMAGridModel.new(
308
+ [[ r, r, r, +1],
309
+ [ r, nil, r, -1],
310
+ [ r, r, r, r]],
311
+ [[0, 3], [1, 3]]) # terminals (the +1 and -1 states)
312
+ model.check_transition_probabilities_sum
313
+
314
+ check_grid_solutions model,
315
+ ["> > > ",
316
+ "^ < ",
317
+ "^ < < v"]
318
+ end
319
+
320
+ class MyPoint
321
+ include FiniteMDP::VectorValued
322
+
323
+ def initialize x, y
324
+ @x, @y = x, y
325
+ end
326
+
327
+ attr_accessor :x, :y
328
+
329
+ # must implement to_a to make VectorValued work
330
+ def to_a
331
+ [x, y]
332
+ end
333
+ end
334
+
335
+ def test_vector_valued
336
+ p1 = MyPoint.new(0, 0)
337
+ p2 = MyPoint.new(0, 1)
338
+ p3 = MyPoint.new(0, 0)
339
+
340
+ assert !p1.eql?(p2)
341
+ assert !p3.eql?(p2)
342
+ assert p1.eql?(p1)
343
+ assert p1.eql?(p3)
344
+ assert_equal p1.hash, p3.hash
345
+ end
346
+ end
347
+
metadata ADDED
@@ -0,0 +1,94 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: finite_mdp
3
+ version: !ruby/object:Gem::Version
4
+ prerelease:
5
+ version: 0.0.1
6
+ platform: ruby
7
+ authors:
8
+ - John Lees-Miller
9
+ autorequire:
10
+ bindir: bin
11
+ cert_chain: []
12
+
13
+ date: 2011-04-17 00:00:00 Z
14
+ dependencies:
15
+ - !ruby/object:Gem::Dependency
16
+ name: narray
17
+ prerelease: false
18
+ requirement: &id001 !ruby/object:Gem::Requirement
19
+ none: false
20
+ requirements:
21
+ - - ">="
22
+ - !ruby/object:Gem::Version
23
+ version: 0.5.9
24
+ - - ~>
25
+ - !ruby/object:Gem::Version
26
+ version: "0"
27
+ type: :runtime
28
+ version_requirements: *id001
29
+ - !ruby/object:Gem::Dependency
30
+ name: gemma
31
+ prerelease: false
32
+ requirement: &id002 !ruby/object:Gem::Requirement
33
+ none: false
34
+ requirements:
35
+ - - ">="
36
+ - !ruby/object:Gem::Version
37
+ version: 1.0.1
38
+ - - ~>
39
+ - !ruby/object:Gem::Version
40
+ version: "1.0"
41
+ type: :development
42
+ version_requirements: *id002
43
+ description: Solve small finite Markov Decision Process models.
44
+ email:
45
+ - jdleesmiller@gmail.com
46
+ executables: []
47
+
48
+ extensions: []
49
+
50
+ extra_rdoc_files:
51
+ - README.rdoc
52
+ files:
53
+ - lib/finite_mdp/hash_model.rb
54
+ - lib/finite_mdp/vector_valued.rb
55
+ - lib/finite_mdp/model.rb
56
+ - lib/finite_mdp/version.rb
57
+ - lib/finite_mdp/solver.rb
58
+ - lib/finite_mdp/table_model.rb
59
+ - lib/finite_mdp.rb
60
+ - README.rdoc
61
+ - test/finite_mdp_test.rb
62
+ homepage: http://github.com/jdleesmiller/finite_mdp
63
+ licenses: []
64
+
65
+ post_install_message:
66
+ rdoc_options:
67
+ - --main
68
+ - README.rdoc
69
+ - --title
70
+ - finite_mdp-0.0.1 Documentation
71
+ require_paths:
72
+ - lib
73
+ required_ruby_version: !ruby/object:Gem::Requirement
74
+ none: false
75
+ requirements:
76
+ - - ">="
77
+ - !ruby/object:Gem::Version
78
+ version: "0"
79
+ required_rubygems_version: !ruby/object:Gem::Requirement
80
+ none: false
81
+ requirements:
82
+ - - ">="
83
+ - !ruby/object:Gem::Version
84
+ version: "0"
85
+ requirements: []
86
+
87
+ rubyforge_project: finite_mdp
88
+ rubygems_version: 1.7.2
89
+ signing_key:
90
+ specification_version: 3
91
+ summary: Solve small finite Markov Decision Process models.
92
+ test_files:
93
+ - test/finite_mdp_test.rb
94
+ has_rdoc: