pyrlutils 0.0.1__tar.gz → 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {pyrlutils-0.0.1/pyrlutils.egg-info → pyrlutils-0.0.2}/PKG-INFO +2 -2
  2. pyrlutils-0.0.2/pyrlutils/policy.py +84 -0
  3. pyrlutils-0.0.1/pyrlutils/values.py → pyrlutils-0.0.2/pyrlutils/reward.py +2 -0
  4. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils/state.py +28 -2
  5. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils/transition.py +66 -26
  6. pyrlutils-0.0.2/pyrlutils/valuefcns.py +144 -0
  7. {pyrlutils-0.0.1 → pyrlutils-0.0.2/pyrlutils.egg-info}/PKG-INFO +2 -2
  8. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils.egg-info/SOURCES.txt +5 -1
  9. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils.egg-info/requires.txt +1 -0
  10. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/setup.py +2 -2
  11. pyrlutils-0.0.2/test/test_2ddiscrete.py +20 -0
  12. pyrlutils-0.0.2/test/test_2dmaze.py +341 -0
  13. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/test/test_action.py +0 -2
  14. pyrlutils-0.0.2/test/test_frozenlake.py +29 -0
  15. pyrlutils-0.0.1/pyrlutils/policy.py +0 -34
  16. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/LICENSE +0 -0
  17. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/MANIFEST.in +0 -0
  18. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/README.md +0 -0
  19. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils/__init__.py +0 -0
  20. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils/action.py +0 -0
  21. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils.egg-info/dependency_links.txt +0 -0
  22. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils.egg-info/not-zip-safe +0 -0
  23. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils.egg-info/top_level.txt +0 -0
  24. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/setup.cfg +0 -0
  25. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/test/test_continous_state_actions.py +0 -0
  26. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/test/test_state.py +0 -0
  27. {pyrlutils-0.0.1 → pyrlutils-0.0.2}/test/test_transprobs.py +0 -0
{pyrlutils-0.0.1/pyrlutils.egg-info → pyrlutils-0.0.2}/PKG-INFO
@@ -1,11 +1,11 @@
 Metadata-Version: 2.1
 Name: pyrlutils
-Version: 0.0.1
+Version: 0.0.2
 Summary: Utility and Helpers for Reinformcement Learning
 Home-page: https://github.com/stephenhky/PyRLUtils
 Author: Kwan-Yuet Ho
 Author-email: stephenhky@yahoo.com.hk
-License: LGPL
+License: MIT
 Keywords: machine learning,reinforcement leaning,artifiial intelligence
 Platform: UNKNOWN
 Classifier: Topic :: Scientific/Engineering :: Mathematics
pyrlutils-0.0.2/pyrlutils/policy.py (new file)
@@ -0,0 +1,84 @@
+
+from abc import ABC, abstractmethod
+from typing import Union, Dict
+from warnings import warn
+
+import numpy as np
+
+from .state import State, DiscreteState, DiscreteStateValueType
+from .action import Action, DiscreteActionValueType
+
+
+class Policy(ABC):
+    @abstractmethod
+    def get_action(self, state: State) -> Action:
+        pass
+
+    def __call__(self, state: State) -> Action:
+        return self.get_action(state)
+
+    @property
+    def is_stochastic(self) -> bool:
+        raise NotImplemented()
+
+
+class DeterministicPolicy(Policy):
+    @abstractmethod
+    def add_deterministic_rule(self, *args, **kwargs):
+        pass
+
+    @property
+    def is_stochastic(self) -> bool:
+        return False
+
+
+class DiscreteDeterminsticPolicy(DeterministicPolicy):
+    def __init__(self, actions_dict: Dict[DiscreteActionValueType, Action]):
+        self._state_to_action = {}
+        self._actions_dict = actions_dict
+
+    def add_deterministic_rule(self, state_value: DiscreteStateValueType, action_value: DiscreteActionValueType):
+        if state_value in self._state_to_action:
+            warn('State value {} exists in rule; it will be replaced.'.format(state_value))
+        self._state_to_action[state_value] = action_value
+
+    def get_action_value(self, state_value: DiscreteStateValueType) -> DiscreteActionValueType:
+        return self._state_to_action.get(state_value)
+
+    def get_action(self, state: DiscreteState) -> Action:
+        return self._actions_dict[self.get_action_value(state.state_value)]
+
+    def __eq__(self, other) -> bool:
+        if len(self._state_to_action) != len(set(self._state_to_action.keys()).union(other._state_to_action.keys())):
+            return False
+        if len(self._actions_dict) != len(set(self._actions_dict.keys()).union(other._actions_dict.keys())):
+            return False
+        for action in self._actions_dict.keys():
+            if self._actions_dict[action] != other._actions_dict[action]:
+                return False
+        for state in self._state_to_action.keys():
+            if self._state_to_action[state] != other._state_to_action[state]:
+                return False
+        return True
+
+
+class StochasticPolicy(Policy):
+    @abstractmethod
+    def get_probability(self, *args, **kwargs) -> float:
+        pass
+
+    @property
+    def is_stochastic(self) -> bool:
+        return True
+
+
+class DiscreteStochasticPolicy(StochasticPolicy):
+    @abstractmethod
+    def get_probability(self, state_value: DiscreteStateValueType, action_value: DiscreteActionValueType) -> float:
+        pass
+
+
+class ContinuousStochasticPolicy(StochasticPolicy):
+    @abstractmethod
+    def get_probability(self, state_value: Union[float, np.ndarray], action_value: DiscreteActionValueType, value: Union[float, np.ndarray]) -> float:
+        pass
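
A minimal usage sketch of the new deterministic-policy API, based only on the code added above; the state values, action values, and the placeholder action functions are hypothetical (wrapping a callable in Action follows the pattern TransitionProbabilityFactory uses in transition.py below):

from pyrlutils.action import Action
from pyrlutils.policy import DiscreteDeterminsticPolicy

# hypothetical actions; each Action wraps a callable acting on a state
actions_dict = {
    'left': Action(lambda state: state),
    'right': Action(lambda state: state),
}

policy = DiscreteDeterminsticPolicy(actions_dict)
policy.add_deterministic_rule(0, 'right')   # in state value 0, choose 'right'
policy.add_deterministic_rule(1, 'left')    # in state value 1, choose 'left'

assert policy.get_action_value(0) == 'right'
assert policy.is_stochastic is False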
pyrlutils-0.0.1/pyrlutils/values.py → pyrlutils-0.0.2/pyrlutils/reward.py
@@ -33,3 +33,5 @@ class RewardFunction(ABC):
 
     def __call__(self, state_value, action_value) -> float:
         return self.total_reward(state_value, action_value)
+
+
{pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils/state.py
@@ -1,6 +1,6 @@
 
 from abc import ABC, abstractmethod
-from typing import List, Optional, Union
+from typing import Tuple, List, Optional, Union
 
 import numpy as np
 
@@ -23,7 +23,7 @@ class State(ABC):
         self.set_state_value(new_state_value)
 
 
-DiscreteStateValueType = Union[float, str]
+DiscreteStateValueType = Union[float, str, Tuple[int]]
 
 
 class DiscreteState(State):
@@ -52,6 +52,10 @@ class DiscreteState(State):
     def state_value(self, new_state_value: DiscreteStateValueType):
         self.set_state_value(new_state_value)
 
+    @property
+    def state_space_size(self):
+        return len(self._all_state_values)
+
 
 class InvalidRangeError(Exception):
     def __init__(self, message=None):
@@ -168,3 +172,25 @@ class ContinuousState(State):
         return self._nbdims
 
 
+class Discrete2DCartesianState(DiscreteState):
+    def __init__(self, x_lowlim: int, x_hilim: int, y_lowlim: int, y_hilim: int, initial_coordinate: List[int]=None):
+        self._x_lowlim = x_lowlim
+        self._x_hilim = x_hilim
+        self._y_lowlim = y_lowlim
+        self._y_hilim = y_hilim
+        self._countx = self._x_hilim - self._x_lowlim + 1
+        self._county = self._y_hilim - self._y_lowlim + 1
+        if initial_coordinate is None:
+            initial_coordinate = [self._x_lowlim, self._y_lowlim]
+        initial_value = (initial_coordinate[1] - self._y_lowlim) * self._countx + (initial_coordinate[0] - self._x_lowlim)
+        super().__init__(list(range(self._countx*self._county)), initial_values=initial_value)
+
+    def _encode_coordinates(self, x, y) -> int:
+        return (y - self._y_lowlim) * self._countx + (x - self._x_lowlim)
+
+    def encode_coordinates(self, coordinates: List[int]) -> int:
+        assert len(coordinates) == 2
+        return self._encode_coordinates(coordinates[0], coordinates[1])
+
+    def decode_coordinates(self, hashcode) -> List[int]:
+        return [hashcode % self._countx, hashcode // self._countx]
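
A quick sketch of how the new Discrete2DCartesianState encoding behaves, derived directly from the arithmetic in the hunk above; the grid bounds and coordinates are just an illustration:

from pyrlutils.state import Discrete2DCartesianState

# a 6x5 grid: x in [0, 5], y in [0, 4]; cells are hashed row-major as y * 6 + x
grid = Discrete2DCartesianState(0, 5, 0, 4, initial_coordinate=[0, 0])

code = grid.encode_coordinates([2, 3])          # (3 - 0) * 6 + (2 - 0) = 20
assert grid.decode_coordinates(code) == [2, 3]  # round-trips back to the coordinates
assert grid.state_space_size == 30              # 6 * 5 cells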
{pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils/transition.py
@@ -3,9 +3,10 @@ from types import LambdaType
 from typing import Tuple, Dict
 
 import numpy as np
+import gym
 
 from .state import DiscreteState, DiscreteStateValueType
-from .values import IndividualRewardFunction
+from .reward import IndividualRewardFunction
 from .action import Action, DiscreteActionValueType
 
 
@@ -35,22 +36,22 @@ class NextStateTuple:
 
 class TransitionProbabilityFactory:
     def __init__(self):
-        self.transprobs = {}
-        self.all_state_values = []
-        self.all_action_values = []
-        self.objects_generated = False
+        self._transprobs = {}
+        self._all_state_values = []
+        self._all_action_values = []
+        self._objects_generated = False
 
     def add_state_transitions(self, state_value: DiscreteStateValueType, action_values_to_next_state: dict):
-        if state_value not in self.all_state_values:
-            self.all_state_values.append(state_value)
+        if state_value not in self._all_state_values:
+            self._all_state_values.append(state_value)
 
         this_state_transition_dict = {}
 
         for action_value, next_state_tuples in action_values_to_next_state.items():
             this_state_transition_dict[action_value] = []
             for next_state_tuple in next_state_tuples:
-                if action_value not in self.all_action_values:
-                    self.all_action_values.append(action_value)
+                if action_value not in self._all_action_values:
+                    self._all_action_values.append(action_value)
                 if not isinstance(next_state_tuple, NextStateTuple):
                     if isinstance(next_state_tuple, dict):
                         next_state_tuple = NextStateTuple(
@@ -62,16 +63,16 @@ class TransitionProbabilityFactory:
                     else:
                         raise TypeError('"action_values_to_next_state" has to be a dictionary or NextStateTuple instance.')
 
-                if next_state_tuple.next_state_value not in self.all_state_values:
-                    self.all_state_values.append(next_state_tuple.next_state_value)
+                if next_state_tuple.next_state_value not in self._all_state_values:
+                    self._all_state_values.append(next_state_tuple.next_state_value)
 
                 this_state_transition_dict[action_value].append(next_state_tuple)
 
-        self.transprobs[state_value] = this_state_transition_dict
+        self._transprobs[state_value] = this_state_transition_dict
 
     def _get_probs_for_eachstate(self, action_value: DiscreteActionValueType) -> Dict[DiscreteStateValueType, NextStateTuple]:
         state_nexttuples = {}
-        for state_value, action_nexttuples_pair in self.transprobs.items():
+        for state_value, action_nexttuples_pair in self._transprobs.items():
             for this_action_value, nexttuples in action_nexttuples_pair.items():
                 if this_action_value == action_value:
                     state_nexttuples[state_value] = nexttuples
@@ -92,16 +93,17 @@ class TransitionProbabilityFactory:
     def _generate_individual_reward_function(self) -> IndividualRewardFunction:
 
         def _individual_reward_function(state_value, action_value, next_state_value) -> float:
-            if state_value in self.transprobs.keys():
-                if action_value in self.transprobs[state_value].keys():
-                    for next_tuple in self.transprobs[state_value][action_value]:
-                        if next_tuple.next_state_value == next_state_value:
-                            return next_tuple.reward
-                    return 0.0
-                else:
-                    return 0.0
-            else:
-                return 0.0
+            if state_value not in self._transprobs.keys():
+                return 0.
+
+            if action_value not in self._transprobs[state_value].keys():
+                return 0.
+
+            reward = 0.
+            for next_tuple in self._transprobs[state_value][action_value]:
+                if next_tuple.next_state_value == next_state_value:
+                    reward += next_tuple.reward
+            return reward
 
         class ThisIndividualRewardFunction(IndividualRewardFunction):
             def __init__(self):
@@ -112,10 +114,27 @@ class TransitionProbabilityFactory:
 
         return ThisIndividualRewardFunction()
 
-    def generate_mdp_objects(self) -> Tuple[DiscreteState, dict, IndividualRewardFunction]:
-        state = DiscreteState(self.all_state_values)
+    def get_probability(self, state_value, action_value, new_state_value) -> float:
+        if state_value not in self._transprobs.keys():
+            return 0.
+
+        if action_value not in self._transprobs[state_value]:
+            return 0.
+
+        probs = 0.
+        for next_state_tuple in self._transprobs[state_value][action_value]:
+            if next_state_tuple.next_state_value == new_state_value:
+                probs += next_state_tuple.probability
+        return probs
+
+    @property
+    def transition_probabilities(self) -> dict:
+        return self._transprobs
+
+    def generate_mdp_objects(self) -> Tuple[DiscreteState, Dict[DiscreteActionValueType, Action], IndividualRewardFunction]:
+        state = DiscreteState(self._all_state_values)
         actions_dict = {}
-        for action_value in self.all_action_values:
+        for action_value in self._all_action_values:
             state_nexttuple = self._get_probs_for_eachstate(action_value)
             actions_dict[action_value] = Action(self._generate_action_function(state_nexttuple))
 
@@ -123,3 +142,24 @@
 
         return state, actions_dict, individual_reward_fcn
 
+    @property
+    def objects_generated(self) -> bool:
+        return self._objects_generated
+
+
+class OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory(TransitionProbabilityFactory):
+    def __init__(self, envname):
+        super().__init__()
+        self.gymenv = gym.make(envname)
+        self._convert_openai_gymenv_to_transprob()
+
+    def _convert_openai_gymenv_to_transprob(self):
+        P = self.gymenv.env.P
+        for state_value, trans_dict in P.items():
+            new_trans_dict = {}
+            for action_value, next_state_list in trans_dict.items():
+                new_trans_dict[action_value] = [
+                    NextStateTuple(next_state[1], next_state[0], next_state[2], next_state[3])
+                    for next_state in next_state_list
+                ]
+            self.add_state_transitions(state_value, new_trans_dict)
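
A minimal sketch of the factory workflow added above, using hypothetical state and action values; NextStateTuple is constructed as (next_state_value, probability, reward, terminal), matching the maze and FrozenLake tests further down:

from pyrlutils.transition import TransitionProbabilityFactory, NextStateTuple

factory = TransitionProbabilityFactory()
# action 'go' moves 's0' to the terminal state 's1' with probability 1 and reward 1
factory.add_state_transitions('s0', {'go': [NextStateTuple('s1', 1., 1., True)]})
factory.add_state_transitions('s1', {'go': [NextStateTuple('s1', 1., 0., True)]})

state, actions_dict, reward_fcn = factory.generate_mdp_objects()
assert factory.get_probability('s0', 'go', 's1') == 1.0
assert reward_fcn('s0', 'go', 's1') == 1.0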
pyrlutils-0.0.2/pyrlutils/valuefcns.py (new file)
@@ -0,0 +1,144 @@
+
+import random
+from copy import copy
+from typing import Tuple, Dict
+from itertools import product
+
+import numpy as np
+
+from .state import DiscreteStateValueType
+from .transition import TransitionProbabilityFactory
+from .policy import DiscreteDeterminsticPolicy
+
+
+class OptimalPolicyOnValueFunctions:
+    def __init__(self, discount_factor: float, transprobfac: TransitionProbabilityFactory):
+        try:
+            assert discount_factor >= 0. and discount_factor <= 1.
+        except AssertionError:
+            raise ValueError('Discount factor must be between 0 and 1.')
+        self._gamma = discount_factor
+        self._transprobfac = transprobfac
+        self._states, self._actions_dict, self._indrewardfcn = self._transprobfac.generate_mdp_objects()
+        self._state_names = self._states.get_all_possible_state_values()
+        self._states_to_indices = {state: idx for idx, state in enumerate(self._state_names)}
+        self._action_names = list(self._actions_dict.keys())
+        self._actions_to_indices = {action_value: idx for idx, action_value in enumerate(self._action_names)}
+
+        self._evaluated = False
+        self._improved = False
+
+        self._theta = 1e-10
+        self._policy_evaluation_maxiter = 10000
+
+    def _policy_evaluation(self, policy: DiscreteDeterminsticPolicy) -> np.ndarray:
+        prev_V = np.zeros(len(self._states_to_indices))
+
+        for _ in range(self._policy_evaluation_maxiter):
+            V = np.zeros(len(self._states_to_indices))
+            for state_value in self._state_names:
+                state_index = self._states_to_indices[state_value]
+                action_value = policy.get_action_value(state_value)
+                for next_state_tuple in self._transprobfac.transition_probabilities[state_value][action_value]:
+                    prob = next_state_tuple.probability
+                    reward = next_state_tuple.reward
+                    next_state_value = next_state_tuple.next_state_value
+                    next_state_index = self._states_to_indices[next_state_value]
+                    terminal = next_state_tuple.terminal
+
+                    V[state_index] += prob * (reward + (self._gamma*prev_V[next_state_index] if not terminal else 0.))
+
+            if np.max(np.abs(V-prev_V)) < self._theta:
+                break
+
+            prev_V = V.copy()
+
+        return V
+
+    def _policy_improvement(self, V: np.ndarray) -> DiscreteDeterminsticPolicy:
+        Q = np.zeros((len(self._states_to_indices), len(self._actions_to_indices)))
+
+        for state_value in self._state_names:
+            state_index = self._states_to_indices[state_value]
+            for action_value in self._action_names:
+                action_index = self._actions_to_indices[action_value]
+                for next_state_tuple in self._transprobfac.transition_probabilities[state_value][action_value]:
+                    prob = next_state_tuple.probability
+                    reward = next_state_tuple.reward
+                    next_state_value = next_state_tuple.next_state_value
+                    next_state_index = self._states_to_indices[next_state_value]
+                    terminal = next_state_tuple.terminal
+
+                    Q[state_index, action_index] += prob * (reward + (self._gamma*V[next_state_index] if not terminal else 0.))
+
+        optimal_policy = DiscreteDeterminsticPolicy(self._actions_dict)
+        optimal_action_indices = np.argmax(Q, axis=1)
+        for state_value, action_index in zip(self._state_names, optimal_action_indices):
+            action_value = self._action_names[action_index]
+            optimal_policy.add_deterministic_rule(state_value, action_value)
+        return optimal_policy
+
+    def _policy_iteration(self) -> Tuple[np.ndarray, DiscreteDeterminsticPolicy]:
+        policy = DiscreteDeterminsticPolicy(self._actions_dict)
+        for state_value in self._state_names:
+            policy.add_deterministic_rule(state_value, random.choice(self._action_names))
+        V = None
+
+        done = False
+        while not done:
+            old_policy = copy(policy)
+
+            V = self._policy_evaluation(policy)
+            policy = self._policy_improvement(V)
+
+            if policy == old_policy:
+                done = True
+
+        return V, policy
+
+
+    def _value_iteration(self) -> Tuple[np.ndarray, DiscreteDeterminsticPolicy]:
+        V = np.zeros(len(self._state_names))
+
+        for _ in range(self._policy_evaluation_maxiter):
+            Q = np.zeros((len(self._state_names), len(self._action_names)))
+            for state_value, action_value in product(self._state_names, self._action_names):
+                state_index = self._states_to_indices[state_value]
+                action_index = self._actions_to_indices[action_value]
+                for next_state_tuple in self._transprobfac.transition_probabilities[state_value][action_value]:
+                    prob = next_state_tuple.probability
+                    reward = next_state_tuple.reward
+                    next_state_value = next_state_tuple.next_state_value
+                    next_state_index = self._states_to_indices[next_state_value]
+                    terminal = next_state_tuple.terminal
+
+                    Q[state_index, action_index] += prob * (reward + (self._gamma * V[next_state_index] if not terminal else 0.))
+
+            if np.max(np.abs(V-np.max(Q, axis=1))) < self._theta:
+                break
+
+            V = np.max(Q, axis=1)
+
+        Qmaxj = np.argmax(Q, axis=1)
+
+        policy = DiscreteDeterminsticPolicy(self._actions_dict)
+        for state_value, action_index in zip(self._state_names, Qmaxj):
+            policy.add_deterministic_rule(state_value, self._action_names[action_index])
+
+        return V, policy
+
+    def policy_iteration(self) -> Tuple[Dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
+        V, policy = self._policy_iteration()
+        state_values_dict = {
+            self._state_names[i]: V[i]
+            for i in range(V.shape[0])
+        }
+        return state_values_dict, policy
+
+    def value_iteration(self) -> Tuple[Dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
+        V, policy = self._value_iteration()
+        state_values_dict = {
+            self._state_names[i]: V[i]
+            for i in range(V.shape[0])
+        }
+        return state_values_dict, policy
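
A compact sketch of the new solver on a two-state MDP, a smaller variant of what test/test_2dmaze.py does further down; the state names, action names, and rewards are hypothetical:

from pyrlutils.transition import TransitionProbabilityFactory, NextStateTuple
from pyrlutils.valuefcns import OptimalPolicyOnValueFunctions

factory = TransitionProbabilityFactory()
# 'stay' loops on 's0' with no reward; 'go' reaches the terminal state 's1' with reward 1
factory.add_state_transitions('s0', {
    'stay': [NextStateTuple('s0', 1., 0., False)],
    'go': [NextStateTuple('s1', 1., 1., True)],
})
factory.add_state_transitions('s1', {
    'stay': [NextStateTuple('s1', 1., 0., True)],
    'go': [NextStateTuple('s1', 1., 0., True)],
})

finder = OptimalPolicyOnValueFunctions(0.9, factory)
values_dict, policy = finder.value_iteration()
assert policy.get_action_value('s0') == 'go'   # leaving immediately beats looping for any discount factor below 1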
{pyrlutils-0.0.1 → pyrlutils-0.0.2/pyrlutils.egg-info}/PKG-INFO
@@ -1,11 +1,11 @@
 Metadata-Version: 2.1
 Name: pyrlutils
-Version: 0.0.1
+Version: 0.0.2
 Summary: Utility and Helpers for Reinformcement Learning
 Home-page: https://github.com/stephenhky/PyRLUtils
 Author: Kwan-Yuet Ho
 Author-email: stephenhky@yahoo.com.hk
-License: LGPL
+License: MIT
 Keywords: machine learning,reinforcement leaning,artifiial intelligence
 Platform: UNKNOWN
 Classifier: Topic :: Scientific/Engineering :: Mathematics
{pyrlutils-0.0.1 → pyrlutils-0.0.2}/pyrlutils.egg-info/SOURCES.txt
@@ -5,16 +5,20 @@ setup.py
 pyrlutils/__init__.py
 pyrlutils/action.py
 pyrlutils/policy.py
+pyrlutils/reward.py
 pyrlutils/state.py
 pyrlutils/transition.py
-pyrlutils/values.py
+pyrlutils/valuefcns.py
 pyrlutils.egg-info/PKG-INFO
 pyrlutils.egg-info/SOURCES.txt
 pyrlutils.egg-info/dependency_links.txt
 pyrlutils.egg-info/not-zip-safe
 pyrlutils.egg-info/requires.txt
 pyrlutils.egg-info/top_level.txt
+test/test_2ddiscrete.py
+test/test_2dmaze.py
 test/test_action.py
 test/test_continous_state_actions.py
+test/test_frozenlake.py
 test/test_state.py
 test/test_transprobs.py
{pyrlutils-0.0.1 → pyrlutils-0.0.2}/setup.py
@@ -18,7 +18,7 @@ def package_description():
 
 setup(
     name='pyrlutils',
-    version="0.0.1",
+    version="0.0.2",
     description="Utility and Helpers for Reinformcement Learning",
     long_description=package_description(),
     long_description_content_type='text/markdown',
@@ -38,7 +38,7 @@ setup(
     url="https://github.com/stephenhky/PyRLUtils",
     author="Kwan-Yuet Ho",
    author_email="stephenhky@yahoo.com.hk",
-    license='LGPL',
+    license='MIT',
     packages=[
        'pyrlutils'
     ],
pyrlutils-0.0.2/test/test_2ddiscrete.py (new file)
@@ -0,0 +1,20 @@
+
+import unittest
+
+from pyrlutils.state import Discrete2DCartesianState
+
+
+class Test2DDiscreteState(unittest.TestCase):
+    def test_twobythree(self):
+        state = Discrete2DCartesianState(0, 1, 0, 2)
+
+        assert state.state_space_size == 6
+        assert state.state_value == 0
+
+        state.set_state_value(5)
+        assert state.decode_coordinates(state.state_value) == [1, 2]
+
+
+
+if __name__ == '__main__':
+    unittest.main()
pyrlutils-0.0.2/test/test_2dmaze.py (new file)
@@ -0,0 +1,341 @@
+
+import unittest
+
+from pyrlutils.transition import TransitionProbabilityFactory, NextStateTuple
+from pyrlutils.valuefcns import OptimalPolicyOnValueFunctions
+from pyrlutils.state import Discrete2DCartesianState
+
+
+class Test2DMaze(unittest.TestCase):
+    def setUp(self):
+        maze_state = Discrete2DCartesianState(0, 5, 0, 4, initial_coordinate=[0, 0])
+
+        transprobfactory = TransitionProbabilityFactory()
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([0, 0]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 0]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([0, 0]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([0, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([0, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([0, 1]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 1]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([1, 1]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([0, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([0, 2]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([0, 2]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 2]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([0, 2]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([0, 1]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([0, 3]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([0, 3]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 3]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([1, 3]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([0, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([0, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([0, 4]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 4]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([1, 4]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([0, 3]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([0, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([1, 0]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([1, 0]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([2, 0]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([1, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([1, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([1, 1]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 1]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([1, 1]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([1, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([1, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([1, 2]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([1, 2]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([2, 2]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([1, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([1, 3]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([1, 3]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 3]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([1, 3]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([1, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([1, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([1, 4]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([0, 4]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([1, 4]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([1, 4]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([1, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([2, 0]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([1, 0]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([3, 0]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([2, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([2, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([2, 1]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([2, 1]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([2, 1]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([2, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([2, 2]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([2, 2]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([1, 2]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([2, 2]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([2, 1]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([2, 3]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([2, 3]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([2, 3]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([3, 3]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([2, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([2, 3]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([2, 4]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([2, 4]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([3, 4]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([2, 4]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([2, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([3, 0]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([2, 0]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([4, 0]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([3, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([3, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([3, 1]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([3, 1]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([3, 1]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([3, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([3, 2]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([3, 2]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([3, 2]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([4, 2]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([3, 1]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([3, 3]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([3, 3]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([2, 3]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([4, 3]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([3, 3]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([3, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([3, 4]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([2, 4]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([3, 4]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([3, 3]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([3, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([4, 0]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([3, 0]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 0]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([4, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([4, 0]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([4, 1]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([4, 1]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 1]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([4, 1]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([4, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([4, 2]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([3, 2]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([4, 2]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([4, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([4, 3]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([4, 3]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([3, 3]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([4, 3]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([4, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([4, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([4, 4]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([4, 4]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 4]), 1., 1., True)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([4, 3]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([4, 4]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([5, 0]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([4, 0]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 0]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([5, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([5, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([5, 1]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([4, 1]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 1]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([5, 0]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([5, 1]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([5, 2]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([5, 2]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 2]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([5, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([5, 3]), 1., 0., False)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([5, 3]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([5, 3]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 3]), 1., 0., False)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([5, 2]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([5, 4]), 1., 1., True)]
+            }
+        )
+        transprobfactory.add_state_transitions(
+            maze_state.encode_coordinates([5, 4]),
+            {
+                'up': [NextStateTuple(maze_state.encode_coordinates([4, 4]), 1., 0., False)],
+                'down': [NextStateTuple(maze_state.encode_coordinates([5, 4]), 1., 1., True)],
+                'left': [NextStateTuple(maze_state.encode_coordinates([5, 3]), 1., 0., False)],
+                'right': [NextStateTuple(maze_state.encode_coordinates([5, 4]), 1., 1., True)]
+            }
+        )
+
+        self.transprobfactory = transprobfactory
+        self.maze_state = maze_state
+
+    def test_policy_iteration(self):
+        policy_finder = OptimalPolicyOnValueFunctions(0.85, self.transprobfactory)
+        values_dict, policy = policy_finder.policy_iteration()
+
+        for state_value, value in values_dict.items():
+            [x, y] = self.maze_state.decode_coordinates(state_value)
+            print('({}, {}): {}'.format(x, y, value))
+
+        state, actions_dict, _ = self.transprobfactory.generate_mdp_objects()
+
+        arrived_destination = False
+        for _ in range(state.state_space_size*2):
+            action_value = policy.get_action_value(state)
+            print('Action value: {}'.format(action_value))
+            action = policy.get_action(state)
+            state = action(state)
+
+            coordinates = self.maze_state.decode_coordinates(state.state_value)
+            print('at: {}, {}'.format(coordinates[0], coordinates[1]))
+            if coordinates[0] == 5 and coordinates[1] == 4:
+                arrived_destination = True
+                break
+
+        assert arrived_destination
+
+    def test_value_iteration(self):
+        policy_finder = OptimalPolicyOnValueFunctions(0.85, self.transprobfactory)
+        values_dict, policy = policy_finder.value_iteration()
+
+        for state_value, value in values_dict.items():
+            [x, y] = self.maze_state.decode_coordinates(state_value)
+            print('({}, {}): {}'.format(x, y, value))
+
+        state, actions_dict, _ = self.transprobfactory.generate_mdp_objects()
+
+        arrived_destination = False
+        for _ in range(state.state_space_size*2):
+            action_value = policy.get_action_value(state)
+            print('Action value: {}'.format(action_value))
+            action = policy.get_action(state)
+            state = action(state)
+
+            coordinates = self.maze_state.decode_coordinates(state.state_value)
+            print('at: {}, {}'.format(coordinates[0], coordinates[1]))
+            if coordinates[0] == 5 and coordinates[1] == 4:
+                arrived_destination = True
+                break
+
+        assert arrived_destination
+
+
+
+if __name__ == '__main__':
+    unittest.main()
{pyrlutils-0.0.1 → pyrlutils-0.0.2}/test/test_action.py
@@ -1,8 +1,6 @@
 
 import unittest
 
-import numpy as np
-
 from pyrlutils.state import DiscreteState
 from pyrlutils.action import Action
 
pyrlutils-0.0.2/test/test_frozenlake.py (new file)
@@ -0,0 +1,29 @@
+
+import unittest
+
+from pyrlutils.transition import OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory
+
+class TestFrozenLake(unittest.TestCase):
+    def test_factory(self):
+        tranprobfactory = OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory('FrozenLake-v1')
+        state, actions_dict, ind_reward_fcn = tranprobfactory.generate_mdp_objects()
+
+        assert len(state.get_all_possible_state_values()) == 16
+        assert state.state_value == 0
+
+        actions_dict[0](state)
+        assert state.state_value in {0, 4}
+
+        state.state_value = 15
+        actions_dict[2](state)
+        assert state.state_value == 15
+
+        assert ind_reward_fcn(0, 0, 0) == 0.0
+        assert ind_reward_fcn(14, 3, 15) == 1.0
+
+        assert abs(tranprobfactory.get_probability(0, 0, 0) - 0.66667) < 1e-4
+        assert abs(tranprobfactory.get_probability(14, 3, 15) - 0.33333) < 1e-4
+
+
+if __name__ == '__main__':
+    unittest.main()
pyrlutils-0.0.1/pyrlutils/policy.py (removed)
@@ -1,34 +0,0 @@
-
-from abc import ABC, abstractmethod
-
-from .state import State
-from .action import Action
-
-
-class Policy(ABC):
-    @abstractmethod
-    def get_action(self, state: State) -> Action:
-        pass
-
-    def __call__(self, state: State) -> Action:
-        return self.get_action(state)
-
-    @property
-    def is_stochastic(self) -> bool:
-        pass
-
-
-class DeterministicPolicy(Policy):
-    @property
-    def is_stochastic(self) -> bool:
-        return False
-
-
-class StochasticPolicy(Policy):
-    @abstractmethod
-    def get_probability(self, state: State, action: Action) -> float:
-        pass
-
-    @property
-    def is_stochastic(self) -> bool:
-        return True