pyrlutils 0.0.1__tar.gz → 0.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyrlutils-0.0.1/pyrlutils.egg-info → pyrlutils-0.0.2}/PKG-INFO +2 -2
- pyrlutils-0.0.2/pyrlutils/policy.py +84 -0
- pyrlutils-0.0.1/pyrlutils/values.py → pyrlutils-0.0.2/pyrlutils/reward.py +2 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils/state.py +28 -2
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils/transition.py +66 -26
- pyrlutils-0.0.2/pyrlutils/valuefcns.py +144 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2/pyrlutils.egg-info}/PKG-INFO +2 -2
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils.egg-info/SOURCES.txt +5 -1
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils.egg-info/requires.txt +1 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/setup.py +2 -2
- pyrlutils-0.0.2/test/test_2ddiscrete.py +20 -0
- pyrlutils-0.0.2/test/test_2dmaze.py +341 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/test/test_action.py +0 -2
- pyrlutils-0.0.2/test/test_frozenlake.py +29 -0
- pyrlutils-0.0.1/pyrlutils/policy.py +0 -34
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/LICENSE +0 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/MANIFEST.in +0 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/README.md +0 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils/__init__.py +0 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils/action.py +0 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils.egg-info/dependency_links.txt +0 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils.egg-info/not-zip-safe +0 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils.egg-info/top_level.txt +0 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/setup.cfg +0 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/test/test_continous_state_actions.py +0 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/test/test_state.py +0 -0
- {pyrlutils-0.0.1 → pyrlutils-0.0.2}/test/test_transprobs.py +0 -0
{pyrlutils-0.0.1/pyrlutils.egg-info → pyrlutils-0.0.2}/PKG-INFO
@@ -1,11 +1,11 @@
 Metadata-Version: 2.1
 Name: pyrlutils
-Version: 0.0.1
+Version: 0.0.2
 Summary: Utility and Helpers for Reinformcement Learning
 Home-page: https://github.com/stephenhky/PyRLUtils
 Author: Kwan-Yuet Ho
 Author-email: stephenhky@yahoo.com.hk
-License:
+License: MIT
 Keywords: machine learning,reinforcement leaning,artifiial intelligence
 Platform: UNKNOWN
 Classifier: Topic :: Scientific/Engineering :: Mathematics
pyrlutils-0.0.2/pyrlutils/policy.py (new file)
@@ -0,0 +1,84 @@
+
+from abc import ABC, abstractmethod
+from typing import Union, Dict
+from warnings import warn
+
+import numpy as np
+
+from .state import State, DiscreteState, DiscreteStateValueType
+from .action import Action, DiscreteActionValueType
+
+
+class Policy(ABC):
+    @abstractmethod
+    def get_action(self, state: State) -> Action:
+        pass
+
+    def __call__(self, state: State) -> Action:
+        return self.get_action(state)
+
+    @property
+    def is_stochastic(self) -> bool:
+        raise NotImplemented()
+
+
+class DeterministicPolicy(Policy):
+    @abstractmethod
+    def add_deterministic_rule(self, *args, **kwargs):
+        pass
+
+    @property
+    def is_stochastic(self) -> bool:
+        return False
+
+
+class DiscreteDeterminsticPolicy(DeterministicPolicy):
+    def __init__(self, actions_dict: Dict[DiscreteActionValueType, Action]):
+        self._state_to_action = {}
+        self._actions_dict = actions_dict
+
+    def add_deterministic_rule(self, state_value: DiscreteStateValueType, action_value: DiscreteActionValueType):
+        if state_value in self._state_to_action:
+            warn('State value {} exists in rule; it will be replaced.'.format(state_value))
+        self._state_to_action[state_value] = action_value
+
+    def get_action_value(self, state_value: DiscreteStateValueType) -> DiscreteActionValueType:
+        return self._state_to_action.get(state_value)
+
+    def get_action(self, state: DiscreteState) -> Action:
+        return self._actions_dict[self.get_action_value(state.state_value)]
+
+    def __eq__(self, other) -> bool:
+        if len(self._state_to_action) != len(set(self._state_to_action.keys()).union(other._state_to_action.keys())):
+            return False
+        if len(self._actions_dict) != len(set(self._actions_dict.keys()).union(other._actions_dict.keys())):
+            return False
+        for action in self._actions_dict.keys():
+            if self._actions_dict[action] != other._actions_dict[action]:
+                return False
+        for state in self._state_to_action.keys():
+            if self._state_to_action[state] != other._state_to_action[state]:
+                return False
+        return True
+
+
+class StochasticPolicy(Policy):
+    @abstractmethod
+    def get_probability(self, *args, **kwargs) -> float:
+        pass
+
+    @property
+    def is_stochastic(self) -> bool:
+        return True
+
+
+class DiscreteStochasticPolicy(StochasticPolicy):
+    @abstractmethod
+    def get_probability(self, state_value: DiscreteStateValueType, action_value: DiscreteActionValueType) -> float:
+        pass
+
+
+class ContinuousStochasticPolicy(StochasticPolicy):
+    @abstractmethod
+    def get_probability(self, state_value: Union[float, np.ndarray], action_value: DiscreteActionValueType, value: Union[float, np.ndarray]) -> float:
+        pass
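
For orientation, the new DiscreteDeterminsticPolicy is a plain lookup table from state values to action values; its __eq__ is what the new value-function solver (valuefcns.py, below) uses to detect that policy iteration has converged. The following is a minimal, hypothetical usage sketch, not part of the diff; it assumes a populated TransitionProbabilityFactory named transprobfactory and the string action values ('right', etc.) used in test_2dmaze.py further down.

    from pyrlutils.policy import DiscreteDeterminsticPolicy

    # generate_mdp_objects() returns (DiscreteState, {action_value: Action}, IndividualRewardFunction)
    state, actions_dict, _ = transprobfactory.generate_mdp_objects()   # factory assumed to exist

    policy = DiscreteDeterminsticPolicy(actions_dict)
    policy.add_deterministic_rule(state.state_value, 'right')   # always go 'right' in the current state

    action = policy.get_action(state)   # looks up the Action registered under 'right'
    state = action(state)               # Action objects are callable on a state (see test_2dmaze.py below)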
{pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils/state.py
@@ -1,6 +1,6 @@
 
 from abc import ABC, abstractmethod
-from typing import List, Optional, Union
+from typing import Tuple, List, Optional, Union
 
 import numpy as np
 
@@ -23,7 +23,7 @@ class State(ABC):
         self.set_state_value(new_state_value)
 
 
-DiscreteStateValueType = Union[float, str]
+DiscreteStateValueType = Union[float, str, Tuple[int]]
 
 
 class DiscreteState(State):
@@ -52,6 +52,10 @@ class DiscreteState(State):
     def state_value(self, new_state_value: DiscreteStateValueType):
         self.set_state_value(new_state_value)
 
+    @property
+    def state_space_size(self):
+        return len(self._all_state_values)
+
 
 class InvalidRangeError(Exception):
     def __init__(self, message=None):
@@ -168,3 +172,25 @@ class ContinuousState(State):
         return self._nbdims
 
 
+class Discrete2DCartesianState(DiscreteState):
+    def __init__(self, x_lowlim: int, x_hilim: int, y_lowlim: int, y_hilim: int, initial_coordinate: List[int]=None):
+        self._x_lowlim = x_lowlim
+        self._x_hilim = x_hilim
+        self._y_lowlim = y_lowlim
+        self._y_hilim = y_hilim
+        self._countx = self._x_hilim - self._x_lowlim + 1
+        self._county = self._y_hilim - self._y_lowlim + 1
+        if initial_coordinate is None:
+            initial_coordinate = [self._x_lowlim, self._y_lowlim]
+        initial_value = (initial_coordinate[1] - self._y_lowlim) * self._countx + (initial_coordinate[0] - self._x_lowlim)
+        super().__init__(list(range(self._countx*self._county)), initial_values=initial_value)
+
+    def _encode_coordinates(self, x, y) -> int:
+        return (y - self._y_lowlim) * self._countx + (x - self._x_lowlim)
+
+    def encode_coordinates(self, coordinates: List[int]) -> int:
+        assert len(coordinates) == 2
+        return self._encode_coordinates(coordinates[0], coordinates[1])
+
+    def decode_coordinates(self, hashcode) -> List[int]:
+        return [hashcode % self._countx, hashcode // self._countx]
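
The new Discrete2DCartesianState flattens an (x, y) grid cell into a single integer hash, (y - y_lowlim) * countx + (x - x_lowlim), and decodes it back with modulo and integer division. A quick illustrative check (not part of the diff), mirroring test_2ddiscrete.py below:

    from pyrlutils.state import Discrete2DCartesianState

    # a 2x3 grid: x in {0, 1}, y in {0, 1, 2}, so countx = 2 and there are 6 states
    state = Discrete2DCartesianState(0, 1, 0, 2)
    print(state.state_space_size)            # 6
    print(state.encode_coordinates([1, 2]))  # (2 - 0)*2 + (1 - 0) = 5
    print(state.decode_coordinates(5))       # [1, 2]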
{pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils/transition.py
@@ -3,9 +3,10 @@ from types import LambdaType
 from typing import Tuple, Dict
 
 import numpy as np
+import gym
 
 from .state import DiscreteState, DiscreteStateValueType
-from .
+from .reward import IndividualRewardFunction
 from .action import Action, DiscreteActionValueType
 
 
@@ -35,22 +36,22 @@ class NextStateTuple:
 
 class TransitionProbabilityFactory:
     def __init__(self):
-        self.
-        self.
-        self.
-        self.
+        self._transprobs = {}
+        self._all_state_values = []
+        self._all_action_values = []
+        self._objects_generated = False
 
     def add_state_transitions(self, state_value: DiscreteStateValueType, action_values_to_next_state: dict):
-        if state_value not in self.
-            self.
+        if state_value not in self._all_state_values:
+            self._all_state_values.append(state_value)
 
         this_state_transition_dict = {}
 
         for action_value, next_state_tuples in action_values_to_next_state.items():
             this_state_transition_dict[action_value] = []
             for next_state_tuple in next_state_tuples:
-                if action_value not in self.
-                    self.
+                if action_value not in self._all_action_values:
+                    self._all_action_values.append(action_value)
                 if not isinstance(next_state_tuple, NextStateTuple):
                     if isinstance(next_state_tuple, dict):
                         next_state_tuple = NextStateTuple(
@@ -62,16 +63,16 @@ class TransitionProbabilityFactory:
                     else:
                         raise TypeError('"action_values_to_next_state" has to be a dictionary or NextStateTuple instance.')
 
-                if next_state_tuple.next_state_value not in self.
-                    self.
+                if next_state_tuple.next_state_value not in self._all_state_values:
+                    self._all_state_values.append(next_state_tuple.next_state_value)
 
                 this_state_transition_dict[action_value].append(next_state_tuple)
 
-        self.
+        self._transprobs[state_value] = this_state_transition_dict
 
     def _get_probs_for_eachstate(self, action_value: DiscreteActionValueType) -> Dict[DiscreteStateValueType, NextStateTuple]:
         state_nexttuples = {}
-        for state_value, action_nexttuples_pair in self.
+        for state_value, action_nexttuples_pair in self._transprobs.items():
             for this_action_value, nexttuples in action_nexttuples_pair.items():
                 if this_action_value == action_value:
                     state_nexttuples[state_value] = nexttuples
@@ -92,16 +93,17 @@ class TransitionProbabilityFactory:
     def _generate_individual_reward_function(self) -> IndividualRewardFunction:
 
         def _individual_reward_function(state_value, action_value, next_state_value) -> float:
-            if state_value in self.
-
-
-
-
-
-
-
-
-
+            if state_value not in self._transprobs.keys():
+                return 0.
+
+            if action_value not in self._transprobs[state_value].keys():
+                return 0.
+
+            reward = 0.
+            for next_tuple in self._transprobs[state_value][action_value]:
+                if next_tuple.next_state_value == next_state_value:
+                    reward += next_tuple.reward
+            return reward
 
         class ThisIndividualRewardFunction(IndividualRewardFunction):
             def __init__(self):
@@ -112,10 +114,27 @@ class TransitionProbabilityFactory:
 
         return ThisIndividualRewardFunction()
 
-    def 
-
+    def get_probability(self, state_value, action_value, new_state_value) -> float:
+        if state_value not in self._transprobs.keys():
+            return 0.
+
+        if action_value not in self._transprobs[state_value]:
+            return 0.
+
+        probs = 0.
+        for next_state_tuple in self._transprobs[state_value][action_value]:
+            if next_state_tuple.next_state_value == new_state_value:
+                probs += next_state_tuple.probability
+        return probs
+
+    @property
+    def transition_probabilities(self) -> dict:
+        return self._transprobs
+
+    def generate_mdp_objects(self) -> Tuple[DiscreteState, Dict[DiscreteActionValueType, Action], IndividualRewardFunction]:
+        state = DiscreteState(self._all_state_values)
         actions_dict = {}
-        for action_value in self.
+        for action_value in self._all_action_values:
             state_nexttuple = self._get_probs_for_eachstate(action_value)
             actions_dict[action_value] = Action(self._generate_action_function(state_nexttuple))
 
@@ -123,3 +142,24 @@ class TransitionProbabilityFactory:
 
         return state, actions_dict, individual_reward_fcn
 
+    @property
+    def objects_generated(self) -> bool:
+        return self._objects_generated
+
+
+class OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory(TransitionProbabilityFactory):
+    def __init__(self, envname):
+        super().__init__()
+        self.gymenv = gym.make(envname)
+        self._convert_openai_gymenv_to_transprob()
+
+    def _convert_openai_gymenv_to_transprob(self):
+        P = self.gymenv.env.P
+        for state_value, trans_dict in P.items():
+            new_trans_dict = {}
+            for action_value, next_state_list in trans_dict.items():
+                new_trans_dict[action_value] = [
+                    NextStateTuple(next_state[1], next_state[0], next_state[2], next_state[3])
+                    for next_state in next_state_list
+                ]
+            self.add_state_transitions(state_value, new_trans_dict)
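
For orientation, the reworked TransitionProbabilityFactory can be exercised on a toy two-state MDP. The sketch below is illustrative and not part of the diff; the NextStateTuple(next_state_value, probability, reward, terminal) argument order is inferred from the gym conversion above and from test_2dmaze.py further down.

    from pyrlutils.transition import TransitionProbabilityFactory, NextStateTuple

    factory = TransitionProbabilityFactory()
    # From state 0, action 'go' lands in state 1 with probability 1 and reward 1, ending the episode;
    # 'stay' keeps the agent in state 0 with no reward.
    factory.add_state_transitions(0, {'go': [NextStateTuple(1, 1., 1., True)],
                                      'stay': [NextStateTuple(0, 1., 0., False)]})
    factory.add_state_transitions(1, {'go': [NextStateTuple(1, 1., 0., True)],
                                      'stay': [NextStateTuple(1, 1., 0., True)]})

    state, actions_dict, reward_fcn = factory.generate_mdp_objects()
    print(factory.get_probability(0, 'go', 1))   # 1.0
    print(reward_fcn(0, 'go', 1))                # 1.0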
pyrlutils-0.0.2/pyrlutils/valuefcns.py (new file)
@@ -0,0 +1,144 @@
+
+import random
+from copy import copy
+from typing import Tuple, Dict
+from itertools import product
+
+import numpy as np
+
+from .state import DiscreteStateValueType
+from .transition import TransitionProbabilityFactory
+from .policy import DiscreteDeterminsticPolicy
+
+
+class OptimalPolicyOnValueFunctions:
+    def __init__(self, discount_factor: float, transprobfac: TransitionProbabilityFactory):
+        try:
+            assert discount_factor >= 0. and discount_factor <= 1.
+        except AssertionError:
+            raise ValueError('Discount factor must be between 0 and 1.')
+        self._gamma = discount_factor
+        self._transprobfac = transprobfac
+        self._states, self._actions_dict, self._indrewardfcn = self._transprobfac.generate_mdp_objects()
+        self._state_names = self._states.get_all_possible_state_values()
+        self._states_to_indices = {state: idx for idx, state in enumerate(self._state_names)}
+        self._action_names = list(self._actions_dict.keys())
+        self._actions_to_indices = {action_value: idx for idx, action_value in enumerate(self._action_names)}
+
+        self._evaluated = False
+        self._improved = False
+
+        self._theta = 1e-10
+        self._policy_evaluation_maxiter = 10000
+
+    def _policy_evaluation(self, policy: DiscreteDeterminsticPolicy) -> np.ndarray:
+        prev_V = np.zeros(len(self._states_to_indices))
+
+        for _ in range(self._policy_evaluation_maxiter):
+            V = np.zeros(len(self._states_to_indices))
+            for state_value in self._state_names:
+                state_index = self._states_to_indices[state_value]
+                action_value = policy.get_action_value(state_value)
+                for next_state_tuple in self._transprobfac.transition_probabilities[state_value][action_value]:
+                    prob = next_state_tuple.probability
+                    reward = next_state_tuple.reward
+                    next_state_value = next_state_tuple.next_state_value
+                    next_state_index = self._states_to_indices[next_state_value]
+                    terminal = next_state_tuple.terminal
+
+                    V[state_index] += prob * (reward + (self._gamma*prev_V[next_state_index] if not terminal else 0.))
+
+            if np.max(np.abs(V-prev_V)) < self._theta:
+                break
+
+            prev_V = V.copy()
+
+        return V
+
+    def _policy_improvement(self, V: np.ndarray) -> DiscreteDeterminsticPolicy:
+        Q = np.zeros((len(self._states_to_indices), len(self._actions_to_indices)))
+
+        for state_value in self._state_names:
+            state_index = self._states_to_indices[state_value]
+            for action_value in self._action_names:
+                action_index = self._actions_to_indices[action_value]
+                for next_state_tuple in self._transprobfac.transition_probabilities[state_value][action_value]:
+                    prob = next_state_tuple.probability
+                    reward = next_state_tuple.reward
+                    next_state_value = next_state_tuple.next_state_value
+                    next_state_index = self._states_to_indices[next_state_value]
+                    terminal = next_state_tuple.terminal
+
+                    Q[state_index, action_index] += prob * (reward + (self._gamma*V[next_state_index] if not terminal else 0.))
+
+        optimal_policy = DiscreteDeterminsticPolicy(self._actions_dict)
+        optimal_action_indices = np.argmax(Q, axis=1)
+        for state_value, action_index in zip(self._state_names, optimal_action_indices):
+            action_value = self._action_names[action_index]
+            optimal_policy.add_deterministic_rule(state_value, action_value)
+        return optimal_policy
+
+    def _policy_iteration(self) -> Tuple[np.ndarray, DiscreteDeterminsticPolicy]:
+        policy = DiscreteDeterminsticPolicy(self._actions_dict)
+        for state_value in self._state_names:
+            policy.add_deterministic_rule(state_value, random.choice(self._action_names))
+        V = None
+
+        done = False
+        while not done:
+            old_policy = copy(policy)
+
+            V = self._policy_evaluation(policy)
+            policy = self._policy_improvement(V)
+
+            if policy == old_policy:
+                done = True
+
+        return V, policy
+
+
+    def _value_iteration(self) -> Tuple[np.ndarray, DiscreteDeterminsticPolicy]:
+        V = np.zeros(len(self._state_names))
+
+        for _ in range(self._policy_evaluation_maxiter):
+            Q = np.zeros((len(self._state_names), len(self._action_names)))
+            for state_value, action_value in product(self._state_names, self._action_names):
+                state_index = self._states_to_indices[state_value]
+                action_index = self._actions_to_indices[action_value]
+                for next_state_tuple in self._transprobfac.transition_probabilities[state_value][action_value]:
+                    prob = next_state_tuple.probability
+                    reward = next_state_tuple.reward
+                    next_state_value = next_state_tuple.next_state_value
+                    next_state_index = self._states_to_indices[next_state_value]
+                    terminal = next_state_tuple.terminal
+
+                    Q[state_index, action_index] += prob * (reward + (self._gamma * V[next_state_index] if not terminal else 0.))
+
+            if np.max(np.abs(V-np.max(Q, axis=1))) < self._theta:
+                break
+
+            V = np.max(Q, axis=1)
+
+        Qmaxj = np.argmax(Q, axis=1)
+
+        policy = DiscreteDeterminsticPolicy(self._actions_dict)
+        for state_value, action_index in zip(self._state_names, Qmaxj):
+            policy.add_deterministic_rule(state_value, self._action_names[action_index])
+
+        return V, policy
+
+    def policy_iteration(self) -> Tuple[Dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
+        V, policy = self._policy_iteration()
+        state_values_dict = {
+            self._state_names[i]: V[i]
+            for i in range(V.shape[0])
+        }
+        return state_values_dict, policy
+
+    def value_iteration(self) -> Tuple[Dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
+        V, policy = self._value_iteration()
+        state_values_dict = {
+            self._state_names[i]: V[i]
+            for i in range(V.shape[0])
+        }
+        return state_values_dict, policy
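
Continuing the toy factory from the sketch after the transition.py hunk (illustrative only, not part of the diff), the new OptimalPolicyOnValueFunctions ties a factory and a discount factor to either solver, just as the maze test below does:

    from pyrlutils.valuefcns import OptimalPolicyOnValueFunctions

    solver = OptimalPolicyOnValueFunctions(0.9, factory)   # discount factor must lie in [0, 1]
    state_values, policy = solver.value_iteration()        # ({state_value: V(s)}, DiscreteDeterminsticPolicy)
    print(state_values)                                    # e.g. {0: 1.0, 1: 0.0}
    print(policy.get_action_value(0))                      # 'go'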
{pyrlutils-0.0.1 → pyrlutils-0.0.2/pyrlutils.egg-info}/PKG-INFO
@@ -1,11 +1,11 @@
 Metadata-Version: 2.1
 Name: pyrlutils
-Version: 0.0.1
+Version: 0.0.2
 Summary: Utility and Helpers for Reinformcement Learning
 Home-page: https://github.com/stephenhky/PyRLUtils
 Author: Kwan-Yuet Ho
 Author-email: stephenhky@yahoo.com.hk
-License:
+License: MIT
 Keywords: machine learning,reinforcement leaning,artifiial intelligence
 Platform: UNKNOWN
 Classifier: Topic :: Scientific/Engineering :: Mathematics
{pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils.egg-info/SOURCES.txt
@@ -5,16 +5,20 @@ setup.py
 pyrlutils/__init__.py
 pyrlutils/action.py
 pyrlutils/policy.py
+pyrlutils/reward.py
 pyrlutils/state.py
 pyrlutils/transition.py
-pyrlutils/
+pyrlutils/valuefcns.py
 pyrlutils.egg-info/PKG-INFO
 pyrlutils.egg-info/SOURCES.txt
 pyrlutils.egg-info/dependency_links.txt
 pyrlutils.egg-info/not-zip-safe
 pyrlutils.egg-info/requires.txt
 pyrlutils.egg-info/top_level.txt
+test/test_2ddiscrete.py
+test/test_2dmaze.py
 test/test_action.py
 test/test_continous_state_actions.py
+test/test_frozenlake.py
 test/test_state.py
 test/test_transprobs.py
{pyrlutils-0.0.1 → pyrlutils-0.0.2}/setup.py
@@ -18,7 +18,7 @@ def package_description():
 
 setup(
     name='pyrlutils',
-    version="0.0.1",
+    version="0.0.2",
     description="Utility and Helpers for Reinformcement Learning",
     long_description=package_description(),
     long_description_content_type='text/markdown',
@@ -38,7 +38,7 @@ setup(
     url="https://github.com/stephenhky/PyRLUtils",
     author="Kwan-Yuet Ho",
     author_email="stephenhky@yahoo.com.hk",
-    license='
+    license='MIT',
     packages=[
         'pyrlutils'
     ],
pyrlutils-0.0.2/test/test_2ddiscrete.py (new file)
@@ -0,0 +1,20 @@
+
+import unittest
+
+from pyrlutils.state import Discrete2DCartesianState
+
+
+class Test2DDiscreteState(unittest.TestCase):
+    def test_twobythree(self):
+        state = Discrete2DCartesianState(0, 1, 0, 2)
+
+        assert state.state_space_size == 6
+        assert state.state_value == 0
+
+        state.set_state_value(5)
+        assert state.decode_coordinates(state.state_value) == [1, 2]
+
+
+
+if __name__ == '__main__':
+    unittest.main()
pyrlutils-0.0.2/test/test_2dmaze.py (new file)
@@ -0,0 +1,341 @@
+
+import unittest
+
+from pyrlutils.transition import TransitionProbabilityFactory, NextStateTuple
+from pyrlutils.valuefcns import OptimalPolicyOnValueFunctions
+from pyrlutils.state import Discrete2DCartesianState
+
+
+class Test2DMaze(unittest.TestCase):
+    def setUp(self):
+        maze_state = Discrete2DCartesianState(0, 5, 0, 4, initial_coordinate=[0, 0])
+
+        transprobfactory = TransitionProbabilityFactory()
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([0, 0]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 0]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([0, 0]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([0, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([0, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([0, 1]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 1]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([1, 1]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([0, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([0, 2]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([0, 2]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 2]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([0, 2]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([0, 1]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([0, 3]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([0, 3]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 3]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([1, 3]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([0, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([0, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([0, 4]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 4]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([1, 4]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([0, 3]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([0, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([1, 0]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([1, 0]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([2, 0]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([1, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([1, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([1, 1]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 1]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([1, 1]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([1, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([1, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([1, 2]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([1, 2]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([2, 2]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([1, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([1, 3]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([1, 3]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 3]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([1, 3]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([1, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([1, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([1, 4]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 4]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([1, 4]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([1, 4]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([1, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([2, 0]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([1, 0]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([3, 0]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([2, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([2, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([2, 1]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([2, 1]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([2, 1]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([2, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([2, 2]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([2, 2]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([1, 2]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([2, 2]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([2, 1]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([2, 3]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([2, 3]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([2, 3]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([3, 3]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([2, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([2, 3]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([2, 4]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([2, 4]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([3, 4]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([2, 4]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([2, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([3, 0]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([2, 0]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([4, 0]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([3, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([3, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([3, 1]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([3, 1]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([3, 1]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([3, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([3, 2]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([3, 2]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([3, 2]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([4, 2]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([3, 1]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([3, 3]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([3, 3]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([2, 3]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([4, 3]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([3, 3]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([3, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([3, 4]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([2, 4]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([3, 4]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([3, 3]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([3, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([4, 0]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([3, 0]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 0]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([4, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([4, 0]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([4, 1]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([4, 1]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 1]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([4, 1]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([4, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([4, 2]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([3, 2]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([4, 2]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([4, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([4, 3]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([4, 3]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([3, 3]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([4, 3]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([4, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([4, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([4, 4]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([4, 4]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 4]), 1., 1., True)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([4, 3]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([4, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([5, 0]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([4, 0]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 0]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([5, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([5, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([5, 1]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([4, 1]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 1]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([5, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([5, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([5, 2]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([5, 2]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 2]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([5, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([5, 3]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([5, 3]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([5, 3]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 3]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([5, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([5, 4]), 1., 1., True)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([5, 4]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([4, 4]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 4]), 1., 1., True)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([5, 3]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([5, 4]), 1., 1., True)]
+            }
+        )
+
+        self.transprobfactory = transprobfactory
+        self.maze_state = maze_state
+
+    def test_policy_iteration(self):
+        policy_finder = OptimalPolicyOnValueFunctions(0.85, self.transprobfactory)
+        values_dict, policy = policy_finder.policy_iteration()
+
+        for state_value, value in values_dict.items():
+            [x, y] = self.maze_state.decode_coordinates(state_value)
+            print('({}, {}): {}'.format(x, y, value))
+
+        state, actions_dict, _ = self.transprobfactory.generate_mdp_objects()
+
+        arrived_destination = False
+        for _ in range(state.state_space_size*2):
+            action_value = policy.get_action_value(state)
+            print('Action value: {}'.format(action_value))
+            action = policy.get_action(state)
+            state = action(state)
+
+            coordinates = self.maze_state.decode_coordinates(state.state_value)
+            print('at: {}, {}'.format(coordinates[0], coordinates[1]))
+            if coordinates[0] == 5 and coordinates[1] == 4:
+                arrived_destination = True
+                break
+
+        assert arrived_destination
+
+    def test_value_iteration(self):
+        policy_finder = OptimalPolicyOnValueFunctions(0.85, self.transprobfactory)
+        values_dict, policy = policy_finder.value_iteration()
+
+        for state_value, value in values_dict.items():
+            [x, y] = self.maze_state.decode_coordinates(state_value)
+            print('({}, {}): {}'.format(x, y, value))
+
+        state, actions_dict, _ = self.transprobfactory.generate_mdp_objects()
+
+        arrived_destination = False
+        for _ in range(state.state_space_size*2):
+            action_value = policy.get_action_value(state)
+            print('Action value: {}'.format(action_value))
+            action = policy.get_action(state)
+            state = action(state)
+
+            coordinates = self.maze_state.decode_coordinates(state.state_value)
+            print('at: {}, {}'.format(coordinates[0], coordinates[1]))
+            if coordinates[0] == 5 and coordinates[1] == 4:
+                arrived_destination = True
+                break
+
+        assert arrived_destination
+
+
+
+if __name__ == '__main__':
+    unittest.main()
pyrlutils-0.0.2/test/test_frozenlake.py (new file)
@@ -0,0 +1,29 @@
+
+import unittest
+
+from pyrlutils.transition import OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory
+
+class TestFrozenLake(unittest.TestCase):
+    def test_factory(self):
+        tranprobfactory = OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory('FrozenLake-v1')
+        state, actions_dict, ind_reward_fcn = tranprobfactory.generate_mdp_objects()
+
+        assert len(state.get_all_possible_state_values()) == 16
+        assert state.state_value == 0
+
+        actions_dict[0](state)
+        assert state.state_value in {0, 4}
+
+        state.state_value = 15
+        actions_dict[2](state)
+        assert state.state_value == 15
+
+        assert ind_reward_fcn(0, 0, 0) == 0.0
+        assert ind_reward_fcn(14, 3, 15) == 1.0
+
+        assert abs(tranprobfactory.get_probability(0, 0, 0) - 0.66667) < 1e-4
+        assert abs(tranprobfactory.get_probability(14, 3, 15) - 0.33333) < 1e-4
+
+
+if __name__ == '__main__':
+    unittest.main()
pyrlutils-0.0.1/pyrlutils/policy.py (deleted file)
@@ -1,34 +0,0 @@
-
-from abc import ABC, abstractmethod
-
-from .state import State
-from .action import Action
-
-
-class Policy(ABC):
-    @abstractmethod
-    def get_action(self, state: State) -> Action:
-        pass
-
-    def __call__(self, state: State) -> Action:
-        return self.get_action(state)
-
-    @property
-    def is_stochastic(self) -> bool:
-        pass
-
-
-class DeterministicPolicy(Policy):
-    @property
-    def is_stochastic(self) -> bool:
-        return False
-
-
-class StochasticPolicy(Policy):
-    @abstractmethod
-    def get_probability(self, state: State, action: Action) -> float:
-        pass
-
-    @property
-    def is_stochastic(self) -> bool:
-        return True
The remaining files are unchanged between 0.0.1 and 0.0.2: LICENSE, MANIFEST.in, README.md, pyrlutils/__init__.py, pyrlutils/action.py, pyrlutils.egg-info/dependency_links.txt, pyrlutils.egg-info/not-zip-safe, pyrlutils.egg-info/top_level.txt, setup.cfg, test/test_continous_state_actions.py, test/test_state.py, and test/test_transprobs.py.