finite_mdp 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.rdoc +8 -12
- data/lib/finite_mdp/array_model.rb +226 -0
- data/lib/finite_mdp/hash_model.rb +10 -9
- data/lib/finite_mdp/model.rb +19 -18
- data/lib/finite_mdp/solver.rb +96 -83
- data/lib/finite_mdp/table_model.rb +28 -19
- data/lib/finite_mdp/vector_valued.rb +5 -5
- data/lib/finite_mdp/version.rb +2 -1
- data/lib/finite_mdp.rb +3 -2
- data/test/finite_mdp/finite_mdp_test.rb +151 -98
- metadata +33 -4
    
data/lib/finite_mdp/solver.rb
CHANGED

@@ -1,3 +1,9 @@
+# frozen_string_literal: true
+
+# We use A to denote a matrix, which rubocop does not like.
+# rubocop:disable Style/MethodName
+# rubocop:disable Style/VariableName
+
 require 'narray'
 
 #
@@ -29,59 +35,54 @@ class FiniteMDP::Solver
   # @param [Hash<state, Float>] value initial value for each state; defaults to
   #        zero for every state
   #
-  def initialize
-    @model = model
+  def initialize(model, discount, policy: nil, value: Hash.new(0))
     @discount = discount
 
     # get the model data into a more compact form for calculation; this means
     # that we number the states and actions for faster lookups (avoid most of
-    # the hashing)
-          pr = model.transition_probability(state, action, next_state)
-          [state_to_num[next_state], pr,
-            model.reward(state, action, next_state)] if pr > 0
-        }.compact
-      }
-    }
+    # the hashing)
+    @model =
+      if model.is_a?(FiniteMDP::ArrayModel)
+        model
+      else
+        FiniteMDP::ArrayModel.from_model(model)
+      end
 
     # convert initial values and policies to compact form
-    @array_value =
-      # default to the first action, arbitrarily
-      @array_policy = [0]*model_states.size
-    end
+    @array_value = @model.states.map { |state| value[state] }
+    @array_policy =
+      if policy
+        @model.states.map do |state|
+          @model.actions(state).index(policy[state])
+        end
+      else
+        [0] * @model.num_states
+      end
 
     raise 'some initial values are missing' if
-      @array_value.any?
+      @array_value.any?(&:nil?)
     raise 'some initial policy actions are missing' if
-      @array_policy.any?
+      @array_policy.any?(&:nil?)
 
     @policy_A = nil
   end
 
   #
-  # @return [
-  #         while it is being solved
+  # @return [ArrayModel] the model being solved; read only; do not change the
+  #         model while it is being solved
   #
   attr_reader :model
 
-  #
+  #
+  # @return [Float] discount factor, in (0, 1]
+  #
+  attr_reader :discount
+
+  #
   # Current value estimate for each state.
   #
   # The result is converted from the solver's internal representation, so you
-  # cannot affect the solver by changing the result.
+  # cannot affect the solver by changing the result.
   #
   # @return [Hash<state, Float>] from states to values; read only; any changes
   # made to the returned object will not affect the solver
@@ -98,14 +99,12 @@ class FiniteMDP::Solver
   #
   def state_action_value
     q = {}
-    states
-          pr * (r + @discount * @array_value[next_state_n])}.inject(:+)
-        q[[state, state_actions[action_n]]] = q_sa
+    model.states.each_with_index do |state, state_n|
+      model.actions(state).each_with_index do |action, action_n|
+        q_sa = model.array[state_n][action_n].map do |next_state_n, pr, r|
+          pr * (r + @discount * @array_value[next_state_n])
+        end.inject(:+)
+        q[[state, action]] = q_sa
       end
     end
     q
@@ -118,8 +117,9 @@ class FiniteMDP::Solver
   # made to the returned object will not affect the solver
   #
   def policy
-    Hash[model.states.zip(@array_policy).map
-      [state, model.actions(state)[action_n]]
+    Hash[model.states.zip(@array_policy).map do |state, action_n|
+      [state, model.actions(state)[action_n]]
+    end]
   end
 
   #
@@ -135,7 +135,7 @@ class FiniteMDP::Solver
   #
   def evaluate_policy
     delta = 0.0
+    model.array.each_with_index do |actions, state_n|
       next_state_ns = actions[@array_policy[state_n]]
       new_value = backup(next_state_ns)
       delta = [delta, (@array_value[state_n] - new_value).abs].max
@@ -162,16 +162,16 @@ class FiniteMDP::Solver
   def evaluate_policy_exact
     if @policy_A
       # update only those rows for which the policy has changed
-      @policy_A_action.zip(@array_policy)
-        each_with_index do |(old_action_n, new_action_n), state_n|
+      @policy_A_action.zip(@array_policy)
+        .each_with_index do |(old_action_n, new_action_n), state_n|
         next if old_action_n == new_action_n
         update_policy_Ab state_n, new_action_n
       end
     else
       # initialise the A and the b for Ax = b
-      num_states =
+      num_states = model.num_states
       @policy_A = NMatrix.float(num_states, num_states)
-      @policy_A_action = [-1]*num_states
+      @policy_A_action = [-1] * num_states
       @policy_b = NVector.float(num_states)
 
       @array_policy.each_with_index do |action_n, state_n|
@@ -189,26 +189,30 @@ class FiniteMDP::Solver
   #
   # This is the 'policy improvement' step in Figure 4.3 of Sutton and Barto
   # (1998).
-  #
-  # @return [Boolean] false iff the policy changed for any state
   #
+  # @param [Float] tolerance non-negative tolerance; for the policy to change,
+  #        the action must be at least this much better than the current
+  #        action
+  #
+  # @return [Integer] number of states that changed
+  #
+  def improve_policy(tolerance: Float::EPSILON)
+    changed = 0
+    model.array.each_with_index do |actions, state_n|
       a_max = nil
       v_max = -Float::MAX
       actions.each_with_index do |next_state_ns, action_n|
         v = backup(next_state_ns)
-        if v > v_max
+        if v > v_max + tolerance
           a_max = action_n
           v_max = v
         end
       end
       raise "no feasible actions in state #{state_n}" unless a_max
+      changed += 1 if @array_policy[state_n] != a_max
       @array_policy[state_n] = a_max
     end
+    changed
   end
 
   #
@@ -223,7 +227,7 @@ class FiniteMDP::Solver
   #
   def value_iteration_single
     delta = 0.0
+    model.array.each_with_index do |actions, state_n|
       a_max = nil
       v_max = -Float::MAX
       actions.each_with_index do |next_state_ns, action_n|
@@ -247,7 +251,7 @@ class FiniteMDP::Solver
   #
   # @param [Float] tolerance small positive number
   #
-  # @param [Integer, nil] max_iters terminate after this many iterations, even
+  # @param [Integer, nil] max_iters terminate after this many iterations, even
   #        if the value function has not converged; nil means that there is
   #        no limit on the number of iterations
   #
@@ -260,7 +264,7 @@ class FiniteMDP::Solver
   # @yieldparam [Float] delta largest change in the value function in the last
   #             iteration
   #
-  def value_iteration
+  def value_iteration(tolerance:, max_iters: nil)
     delta = Float::MAX
     num_iters = 0
     loop do
@@ -281,6 +285,11 @@ class FiniteMDP::Solver
   #        phase ends if the largest change in the value function
   #        (<tt>delta</tt>) is below this tolerance
   #
+  # @param [Float] policy_tolerance small positive number; when comparing
+  #        actions during policy improvement, ignore value function differences
+  #        smaller than this tolerance; this helps with convergence when there
+  #        are several equivalent or extremely similar actions
+  #
   # @param [Integer, nil] max_value_iters terminate the policy evaluation
   #        phase after this many iterations, even if the value function has not
   #        converged; nil means that there is no limit on the number of
@@ -298,16 +307,20 @@ class FiniteMDP::Solver
   # @yieldparam [Integer] num_policy_iters policy improvement iterations done so
   #             far
   #
+  # @yieldparam [Integer?] actions_changed number of actions that changed in
+  #             the policy improvement phase, if any
+  #
   # @yieldparam [Integer] num_value_iters policy evaluation iterations done so
   #             far for the current policy improvement iteration
   #
   # @yieldparam [Float] delta largest change in the value function in the last
   #             policy evaluation iteration
   #
-  def policy_iteration
+  def policy_iteration(value_tolerance:,
+    policy_tolerance: value_tolerance / 2.0, max_value_iters: nil,
+    max_policy_iters: nil)
 
+    num_actions_changed = nil
     num_policy_iters = 0
     loop do
       # policy evaluation
@@ -315,19 +328,19 @@ class FiniteMDP::Solver
       loop do
         value_delta = evaluate_policy
         num_value_iters += 1
+        yield(num_policy_iters, num_actions_changed, num_value_iters,
+          value_delta) if block_given?
 
         break if value_delta < value_tolerance
         break if max_value_iters && num_value_iters >= max_value_iters
-        yield num_policy_iters, num_value_iters, value_delta if block_given?
       end
 
       # policy improvement
+      num_actions_changed = improve_policy(tolerance: policy_tolerance)
      num_policy_iters += 1
+      return true if num_actions_changed == 0
+      return false if max_policy_iters && num_policy_iters >= max_policy_iters
     end
-    stable
   end
 
   #
@@ -343,18 +356,19 @@ class FiniteMDP::Solver
   #
   # @yieldparam [Integer] num_iters policy improvement iterations done so far
   #
+  # @yieldparam [Integer] num_actions_changed number of actions that changed in
+  #             the last policy improvement phase
+  #
+  def policy_iteration_exact(max_iters: nil)
     num_iters = 0
     loop do
       evaluate_policy_exact
+      num_actions_changed = improve_policy
       num_iters += 1
+      yield num_iters, num_actions_changed if block_given?
+      return true if num_actions_changed == 0
+      return false if max_iters && num_iters >= max_iters
     end
-    stable
   end
 
   private
@@ -362,30 +376,29 @@ class FiniteMDP::Solver
   #
   # Updated value estimate for a state with the given successor states.
   #
-  def backup
-    next_state_ns.map
-      probability*(reward + @discount
+  def backup(next_state_ns)
+    next_state_ns.map do |next_state_n, probability, reward|
+      probability * (reward + @discount * @array_value[next_state_n])
+    end.inject(:+)
   end
 
   #
   # Update the row in A the entry in b (in Ax=b) for the given state; see
   # {#evaluate_policy_exact}.
   #
-  def update_policy_Ab
+  def update_policy_Ab(state_n, action_n)
     # clear out the old values for state_n's row
     @policy_A[true, state_n] = 0.0
 
     # set new values according to state_n's successors under the current policy
     b_n = 0
-    next_state_ns =
+    next_state_ns = model.array[state_n][action_n]
     next_state_ns.each do |next_state_n, probability, reward|
-      @policy_A[next_state_n, state_n] = -@discount*probability
-      b_n += probability*reward
+      @policy_A[next_state_n, state_n] = -@discount * probability
+      b_n += probability * reward
     end
     @policy_A[state_n, state_n] += 1
     @policy_A_action[state_n] = action_n
     @policy_b[state_n] = b_n
   end
 end
-
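For reference, `update_policy_Ab` and `evaluate_policy_exact` together solve the standard linear system for exact policy evaluation. The source does not spell this out, but the rows built above (1 added on the diagonal, -@discount * probability off the diagonal, the expected one-step reward accumulated into b) correspond to

    (I - \gamma P_\pi) v = r_\pi

where P_pi is the transition matrix under the current policy's actions and r_pi the expected immediate reward, solved once per improvement step instead of sweeping to convergence.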
data/lib/finite_mdp/table_model.rb
CHANGED

@@ -1,3 +1,4 @@
+# frozen_string_literal: true
 #
 # A finite markov decision process model for which the states, actions,
 # transition probabilities and rewards are specified as a table. This is a
@@ -12,7 +13,7 @@ class FiniteMDP::TableModel
   # @param [Array<[state, action, state, Float, Float]>] rows each row is
   #  [state, action, next state, probability, reward]
   #
-  def initialize
+  def initialize(rows)
     @rows = rows
   end
 
@@ -28,7 +29,7 @@ class FiniteMDP::TableModel
   # @return [Array<state>] not empty; no duplicate states
   #
   def states
-    @rows.map{|row| row[0]}.uniq
+    @rows.map { |row| row[0] }.uniq
   end
 
   #
@@ -38,23 +39,23 @@ class FiniteMDP::TableModel
   #
   # @return [Array<action>] not empty; no duplicate actions
   #
-  def actions
-    @rows.map{|row| row[1] if row[0] == state}.compact.uniq
+  def actions(state)
+    @rows.map { |row| row[1] if row[0] == state }.compact.uniq
   end
 
   #
   # Possible successor states after taking the given action in the given state;
   # see {Model#next_states}.
-  #
+  #
   # @param [state] state
   #
   # @param [action] action
   #
   # @return [Array<state>] not empty; no duplicate states
   #
-  def next_states
-    @rows.map{|row| row[2] if row[0] == state && row[1] == action}.compact
-  end
+  def next_states(state, action)
+    @rows.map { |row| row[2] if row[0] == state && row[1] == action }.compact
+  end
 
   #
   # Probability of the given transition; see {Model#transition_probability}.
@@ -66,10 +67,10 @@ class FiniteMDP::TableModel
   # @param [state] next_state
   #
   # @return [Float] in [0, 1]; zero if the transition is not in the table
-  #
-  def transition_probability
+  #
+  def transition_probability(state, action, next_state)
+    row = find_row(state, action, next_state)
+    row ? row[3] : 0
   end
 
   #
@@ -83,9 +84,9 @@ class FiniteMDP::TableModel
   #
   # @return [Float, nil] nil if the transition is not in the table
   #
-  def reward
+  def reward(state, action, next_state)
+    row = find_row(state, action, next_state)
+    row[4] if row
   end
 
   #
@@ -105,18 +106,26 @@ class FiniteMDP::TableModel
   #
   # @return [TableModel]
   #
-  def self.from_model
+  def self.from_model(model, sparse = true)
     rows = []
     model.states.each do |state|
       model.actions(state).each do |action|
         model.next_states(state, action).each do |next_state|
           pr = model.transition_probability(state, action, next_state)
+          next unless pr > 0 || !sparse
+          reward = model.reward(state, action, next_state)
+          rows << [state, action, next_state, pr, reward]
         end
       end
     end
     FiniteMDP::TableModel.new(rows)
   end
-end
 
+  private
+
+  def find_row(state, action, next_state)
+    @rows.find do |row|
+      row[0] == state && row[1] == action && row[2] == next_state
+    end
+  end
+end
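`TableModel` keeps the same row format but now parenthesises its method arguments and routes `transition_probability` and `reward` through the private `find_row` helper; `from_model` also gains a `sparse` flag for keeping zero-probability rows. A small usage sketch, reusing the hypothetical model from the solver example above:

    table = FiniteMDP::TableModel.from_model(model)         # default: drops pr == 0 rows
    dense = FiniteMDP::TableModel.from_model(model, false)  # keeps pr == 0 rows
    table.transition_probability(:low, :work, :high)        # row[3], or 0 if absent
    table.reward(:low, :work, :high)                        # row[4], or nil if absent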
data/lib/finite_mdp/vector_valued.rb
CHANGED

@@ -1,3 +1,4 @@
+# frozen_string_literal: true
 #
 # Define an object's hash code and equality (in the sense of <tt>eql?</tt>)
 # according to its array representation (<tt>to_a</tt>). See notes for {Model}
@@ -7,7 +8,7 @@
 #
 # @example
 #
-#   class MyPoint
+#   class MyPoint
 #     include FiniteMDP::VectorValued
 #
 #     def initialize x, y
@@ -31,7 +32,7 @@ module FiniteMDP::VectorValued
   # @return [Integer]
   #
   def hash
+    to_a.hash
   end
 
   #
@@ -39,8 +40,7 @@ module FiniteMDP::VectorValued
   #
   # @return [Boolean]
   #
-  def eql?
+  def eql?(other)
+    to_a.eql? other.to_a
   end
 end
-
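`VectorValued` is unchanged apart from style: `hash` and `eql?` still delegate to `to_a`, so any state class that includes the module and defines `to_a` can be used as a Hash key or model state. A minimal sketch along the lines of the `MyPoint` example in the comments (the class name here is made up):

    class GridState
      include FiniteMDP::VectorValued

      def initialize(x, y)
        @x = x
        @y = y
      end

      attr_reader :x, :y

      # VectorValued derives #hash and #eql? from this array form.
      def to_a
        [x, y]
      end
    end

    GridState.new(1, 2).eql?(GridState.new(1, 2))  # => true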
    
data/lib/finite_mdp/version.rb
CHANGED

data/lib/finite_mdp.rb
CHANGED

@@ -1,14 +1,15 @@
+# frozen_string_literal: true
 require 'enumerator'
 
 require 'finite_mdp/version'
 require 'finite_mdp/vector_valued'
 require 'finite_mdp/model'
+require 'finite_mdp/array_model'
 require 'finite_mdp/hash_model'
 require 'finite_mdp/table_model'
 require 'finite_mdp/solver'
 
-# TODO maybe for efficiency it would be worth including a special case for
+# TODO: maybe for efficiency it would be worth including a special case for
 # models in which rewards depend only on the state -- a few minor
 # simplifications are possible in the solver, but it won't make a huge
 # difference.
-