pyrlutils 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyrlutils/__init__.py +0 -0
- pyrlutils/action.py +27 -0
- pyrlutils/bandit/__init__.py +0 -0
- pyrlutils/bandit/algo.py +128 -0
- pyrlutils/bandit/reward.py +12 -0
- pyrlutils/dp/__init__.py +0 -0
- pyrlutils/dp/valuefcns.py +149 -0
- pyrlutils/helpers/__init__.py +0 -0
- pyrlutils/helpers/exceptions.py +5 -0
- pyrlutils/openai/__init__.py +0 -0
- pyrlutils/openai/utils.py +31 -0
- pyrlutils/policy.py +151 -0
- pyrlutils/reward.py +37 -0
- pyrlutils/state.py +320 -0
- pyrlutils/td/__init__.py +0 -0
- pyrlutils/td/doubleqlearn.py +110 -0
- pyrlutils/td/qlearn.py +86 -0
- pyrlutils/td/sarsa.py +86 -0
- pyrlutils/td/state_td.py +111 -0
- pyrlutils/td/utils.py +258 -0
- pyrlutils/transition.py +155 -0
- pyrlutils-0.1.2.dist-info/METADATA +43 -0
- pyrlutils-0.1.2.dist-info/RECORD +26 -0
- pyrlutils-0.1.2.dist-info/WHEEL +5 -0
- pyrlutils-0.1.2.dist-info/licenses/LICENSE +19 -0
- pyrlutils-0.1.2.dist-info/top_level.txt +1 -0
pyrlutils/td/state_td.py
ADDED
@@ -0,0 +1,111 @@

from typing import Annotated

import numpy as np
from npdict import NumpyNDArrayWrappedDict

from .utils import decay_schedule, TimeDifferencePathElements, AbstractStateValueFunctionTemporalDifferenceLearner


class SingleStepTemporalDifferenceLearner(AbstractStateValueFunctionTemporalDifferenceLearner):
    def learn(
            self,
            episodes: int
    ) -> tuple[Annotated[NumpyNDArrayWrappedDict, "1D Array"], Annotated[NumpyNDArrayWrappedDict, "2D Array"]]:
        V = NumpyNDArrayWrappedDict(
            [self._state.get_all_possible_state_values()],
            default_initial_value=0.0
        )
        V_track = NumpyNDArrayWrappedDict(
            [list(range(episodes)), self._state.get_all_possible_state_values()],
            default_initial_value=0.0
        )
        V_array, V_track_array = V.to_numpy(), V_track.to_numpy()
        alphas = decay_schedule(
            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
        )

        for i in range(episodes):
            self._state.state_index = self.initial_state_index
            done = False
            while not done:
                old_state_value = self._state.state_value
                action_value = self._policy.get_action_value(self._state.state_value)
                action_func = self._actions_dict[action_value]
                self._state = action_func(self._state)
                new_state_value = self._state.state_value
                reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
                done = self._state.is_terminal

                td_target = reward + self.gamma * V[new_state_value] * (not done)
                td_error = td_target - V[old_state_value]
                V[old_state_value] = V[old_state_value] + alphas[i] * td_error

            V_track_array[i, :] = V_array

        return V, V_track


class MultipleStepTemporalDifferenceLearner(AbstractStateValueFunctionTemporalDifferenceLearner):
    def learn(
            self,
            episodes: int,
            n_steps: int=3
    ) -> tuple[Annotated[NumpyNDArrayWrappedDict, "1D Array"], Annotated[NumpyNDArrayWrappedDict, "2D Array"]]:
        V = NumpyNDArrayWrappedDict(
            [self._state.get_all_possible_state_values()],
            default_initial_value=0.0
        )
        V_track = NumpyNDArrayWrappedDict(
            [list(range(episodes)), self._state.get_all_possible_state_values()],
            default_initial_value=0.0
        )
        V_array, V_track_array = V.to_numpy(), V_track.to_numpy()
        alphas = decay_schedule(
            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
        )
        discounts = np.logspace(0, n_steps-1, num=n_steps+1, base=self.gamma, endpoint=False)

        for i in range(episodes):
            self._state.state_index = self.initial_state_index
            done = False
            path = []

            while not done or path is not None:
                path = path[1:]  # worth revisiting this line

                new_state_value = self._state._get_state_value_from_index(self._state.nb_state_values-1)
                while not done and len(path) < n_steps:
                    old_state_value = self._state.state_value
                    action_value = self._policy.get_action_value(self._state.state_value)
                    action_func = self._actions_dict[action_value]
                    self._state = action_func(self._state)
                    new_state_value = self._state.state_value
                    reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
                    done = self._state.is_terminal

                    path.append(
                        TimeDifferencePathElements(
                            this_state_value=old_state_value,
                            reward=reward,
                            next_state_value=new_state_value,
                            done=done
                        )
                    )
                    if done:
                        break

                n = len(path)
                estimated_state_value = path[0].this_state_value
                rewards = np.array([this_moment.reward for this_moment in path])
                partial_return = discounts[n:] * rewards
                bs_val = discounts[-1] * V[new_state_value] * (not done)
                ntd_target = np.sum(np.append(partial_return, bs_val))
                ntd_error = ntd_target - V[estimated_state_value]
                V[(estimated_state_value,)] = V[estimated_state_value] + alphas[i] * ntd_error
                if len(path) == 1 and path[0].done:
                    path = None

            V_track_array[i, :] = V_array

        return V, V_track

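The inner loop of SingleStepTemporalDifferenceLearner.learn above is the textbook TD(0) update, V(s) <- V(s) + alpha * [r + gamma * V(s') * (1 - done) - V(s)]. As a minimal, self-contained sketch of just that update with plain NumPy (the states, reward, and step size below are hypothetical and not part of the package API):

import numpy as np

# hypothetical value table over three states, indexed 0..2
V = np.zeros(3)
gamma, alpha = 0.99, 0.1

# one observed transition: state 0 -> state 1, reward 1.0, not terminal
s, s_next, reward, done = 0, 1, 1.0, False

td_target = reward + gamma * V[s_next] * (not done)  # bootstrap on V(s') unless the episode ended
td_error = td_target - V[s]
V[s] += alpha * td_error                              # same update form as in learn() above
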
pyrlutils/td/utils.py
ADDED
@@ -0,0 +1,258 @@

from typing import Annotated, Union, Optional
from dataclasses import dataclass
from abc import ABC, abstractmethod

import numpy as np
from numpy.typing import NDArray
from npdict import NumpyNDArrayWrappedDict

from ..state import DiscreteStateValueType
from ..action import DiscreteActionValueType
from ..policy import DiscretePolicy
from ..transition import TransitionProbabilityFactory


def decay_schedule(
        init_value: float,
        min_value: float,
        decay_ratio: float,
        max_steps: int,
        log_start: int=-2,
        log_base: int=10
) -> Annotated[NDArray[np.float64], "1D Array"]:
    decay_steps = int(max_steps*decay_ratio)
    rem_steps = max_steps - decay_steps

    values = np.logspace(log_start, 0, decay_steps, base=log_base, endpoint=True)[::-1]
    values = (values - values.min()) / (values.max() - values.min())
    values = (init_value - min_value) * values + min_value
    values = np.pad(values, (0, rem_steps), 'edge')
    return values


def select_action(
        state_value: DiscreteStateValueType,
        Q: Union[Annotated[NDArray[np.float64], "2D Array"], NumpyNDArrayWrappedDict],
        epsilon: float,
) -> Union[DiscreteActionValueType, int]:
    if np.random.random() <= epsilon:
        if isinstance(Q, NumpyNDArrayWrappedDict):
            return np.random.choice(Q._lists_keystrings[1])
        else:
            return np.random.choice(np.arange(Q.shape[1]))

    q_matrix = Q.to_numpy() if isinstance(Q, NumpyNDArrayWrappedDict) else Q
    state_index = Q.get_key_index(0, state_value) if isinstance(Q, NumpyNDArrayWrappedDict) else state_value
    max_index = np.argmax(q_matrix[state_index, :])

    if isinstance(Q, NumpyNDArrayWrappedDict):
        return Q._lists_keystrings[1][max_index]
    else:
        return max_index


@dataclass
class TimeDifferencePathElements:
    this_state_value: DiscreteStateValueType
    reward: float
    next_state_value: DiscreteStateValueType
    done: bool


class AbstractStateValueFunctionTemporalDifferenceLearner(ABC):
    def __init__(
            self,
            transprobfac: TransitionProbabilityFactory,
            gamma: float=1.0,
            init_alpha: float=0.5,
            min_alpha: float=0.01,
            alpha_decay_ratio: float=0.3,
            policy: Optional[DiscretePolicy]=None,
            initial_state_index: int=0
    ):
        self._gamma = gamma
        self._init_alpha = init_alpha
        self._min_alpha = min_alpha
        try:
            assert 0.0 <= alpha_decay_ratio <= 1.0
        except AssertionError:
            raise ValueError("alpha_decay_ratio must be between 0 and 1!")
        self._alpha_decay_ratio = alpha_decay_ratio
        self._transprobfac = transprobfac
        self._state, self._actions_dict, self._indrewardfcn = self._transprobfac.generate_mdp_objects()
        self._action_names = list(self._actions_dict.keys())
        self._actions_to_indices = {action_value: idx for idx, action_value in enumerate(self._action_names)}
        self._policy = policy
        try:
            assert 0 <= initial_state_index < self._state.nb_state_values
        except AssertionError:
            raise ValueError(f"Initial state index must be between 0 and {self._state.nb_state_values}")
        self._init_state_index = initial_state_index

    @abstractmethod
    def learn(self, *args, **kwargs) -> tuple[Annotated[NDArray[np.float64], "1D Array"], Annotated[NDArray[np.float64], "2D Array"]]:
        raise NotImplementedError()

    @property
    def nb_states(self) -> int:
        return self._state.nb_state_values

    @property
    def policy(self) -> DiscretePolicy:
        return self._policy

    @policy.setter
    def policy(self, val: DiscretePolicy):
        self._policy = val

    @property
    def gamma(self) -> float:
        return self._gamma

    @gamma.setter
    def gamma(self, val: float):
        self._gamma = val

    @property
    def init_alpha(self) -> float:
        return self._init_alpha

    @init_alpha.setter
    def init_alpha(self, val: float):
        self._init_alpha = val

    @property
    def min_alpha(self) -> float:
        return self._min_alpha

    @min_alpha.setter
    def min_alpha(self, val: float):
        self._min_alpha = val

    @property
    def alpha_decay_ratio(self) -> float:
        return self._alpha_decay_ratio

    @property
    def initial_state_index(self) -> int:
        return self._init_state_index

    @initial_state_index.setter
    def initial_state_index(self, val: int):
        self._init_state_index = val


class AbstractStateActionValueFunctionTemporalDifferenceLearner(ABC):
    def __init__(
            self,
            transprobfac: TransitionProbabilityFactory,
            gamma: float=1.0,
            init_alpha: float=0.5,
            min_alpha: float=0.01,
            alpha_decay_ratio: float=0.3,
            init_epsilon: float=1.0,
            min_epsilon: float=0.1,
            epsilon_decay_ratio: float=0.9,
            policy: Optional[DiscretePolicy]=None,
            initial_state_index: int=0
    ):
        self._gamma = gamma
        self._init_alpha = init_alpha
        self._min_alpha = min_alpha
        try:
            assert 0.0 <= alpha_decay_ratio <= 1.0
        except AssertionError:
            raise ValueError("alpha_decay_ratio must be between 0 and 1!")
        self._alpha_decay_ratio = alpha_decay_ratio
        self._init_epsilon = init_epsilon
        self._min_epsilon = min_epsilon
        self._epsilon_decay_ratio = epsilon_decay_ratio

        self._transprobfac = transprobfac
        self._state, self._actions_dict, self._indrewardfcn = self._transprobfac.generate_mdp_objects()
        self._action_names = list(self._actions_dict.keys())
        self._actions_to_indices = {action_value: idx for idx, action_value in enumerate(self._action_names)}
        self._policy = policy
        try:
            assert 0 <= initial_state_index < self._state.nb_state_values
        except AssertionError:
            raise ValueError(f"Initial state index must be between 0 and {self._state.nb_state_values}")
        self._init_state_index = initial_state_index

    @abstractmethod
    def learn(self, *args, **kwargs) -> tuple[Annotated[NDArray[np.float64], "1D Array"], Annotated[NDArray[np.float64], "2D Array"]]:
        raise NotImplementedError()

    @property
    def nb_states(self) -> int:
        return self._state.nb_state_values

    @property
    def policy(self) -> DiscretePolicy:
        return self._policy

    @policy.setter
    def policy(self, val: DiscretePolicy):
        self._policy = val

    @property
    def gamma(self) -> float:
        return self._gamma

    @gamma.setter
    def gamma(self, val: float):
        self._gamma = val

    @property
    def init_alpha(self) -> float:
        return self._init_alpha

    @init_alpha.setter
    def init_alpha(self, val: float):
        self._init_alpha = val

    @property
    def min_alpha(self) -> float:
        return self._min_alpha

    @min_alpha.setter
    def min_alpha(self, val: float):
        self._min_alpha = val

    @property
    def alpha_decay_ratio(self) -> float:
        return self._alpha_decay_ratio

    @property
    def init_epsilon(self) -> float:
        return self._init_epsilon

    @init_epsilon.setter
    def init_epsilon(self, val: float):
        self._init_epsilon = val

    @property
    def min_epsilon(self) -> float:
        return self._min_epsilon

    @min_epsilon.setter
    def min_epsilon(self, val: float):
        self._min_epsilon = val

    @property
    def epsilon_decay_ratio(self) -> float:
        return self._epsilon_decay_ratio

    @epsilon_decay_ratio.setter
    def epsilon_decay_ratio(self, val: float):
        self._epsilon_decay_ratio = val

    @property
    def initial_state_index(self) -> int:
        return self._init_state_index

    @initial_state_index.setter
    def initial_state_index(self, val: int):
        self._init_state_index = val

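For reference, decay_schedule produces a log-spaced schedule that falls from init_value to min_value over the first decay_ratio fraction of the steps and then holds the floor for the remaining steps. A small usage sketch, assuming the wheel is installed so that pyrlutils.td.utils is importable (the numbers are illustrative):

from pyrlutils.td.utils import decay_schedule

# learning rates for 10 episodes: decay over the first 30% of steps, then hold at the floor
alphas = decay_schedule(init_value=0.5, min_value=0.01, decay_ratio=0.3, max_steps=10)
print(alphas.shape)           # (10,)
print(alphas[0], alphas[-1])  # starts at 0.5, padded out at 0.01
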
pyrlutils/transition.py
ADDED
@@ -0,0 +1,155 @@

from types import LambdaType, FunctionType
from typing import Union
from dataclasses import dataclass

import numpy as np

from .state import DiscreteState, DiscreteStateValueType
from .reward import IndividualRewardFunction
from .action import Action, DiscreteActionValueType


@dataclass
class NextStateTuple:
    next_state_value: DiscreteStateValueType
    probability: float
    reward: float
    terminal: bool


class TransitionProbabilityFactory:
    def __init__(self):
        self._transprobs = {}
        self._all_state_values = []
        self._all_action_values = []
        self._objects_generated = False

    def add_state_transitions(
            self,
            state_value: DiscreteStateValueType,
            action_values_to_next_state: dict[DiscreteActionValueType, Union[list[NextStateTuple], dict]]
    ):
        if state_value not in self._all_state_values:
            self._all_state_values.append(state_value)

        this_state_transition_dict = {}

        for action_value, next_state_tuples in action_values_to_next_state.items():
            this_state_transition_dict[action_value] = []
            for next_state_tuple in next_state_tuples:
                if action_value not in self._all_action_values:
                    self._all_action_values.append(action_value)
                if not isinstance(next_state_tuple, NextStateTuple):
                    if isinstance(next_state_tuple, dict):
                        next_state_tuple = NextStateTuple(
                            next_state_tuple['next_state_value'],
                            next_state_tuple['probability'],
                            next_state_tuple['reward'],
                            next_state_tuple['terminal']
                        )
                    else:
                        raise TypeError('"action_values_to_next_state" has to be a dictionary or NextStateTuple instance.')

                if next_state_tuple.next_state_value not in self._all_state_values:
                    self._all_state_values.append(next_state_tuple.next_state_value)

                this_state_transition_dict[action_value].append(next_state_tuple)

        self._transprobs[state_value] = this_state_transition_dict

    def _get_probs_for_eachstate(
            self,
            action_value: DiscreteActionValueType
    ) -> dict[DiscreteStateValueType, list[NextStateTuple]]:
        state_nexttuples = {}
        for state_value, action_nexttuples_pair in self._transprobs.items():
            for this_action_value, nexttuples in action_nexttuples_pair.items():
                if this_action_value == action_value:
                    state_nexttuples[state_value] = nexttuples
        return state_nexttuples

    def _generate_action_function(
            self,
            state_nexttuples: dict[DiscreteStateValueType, list[NextStateTuple]]
    ) -> Union[FunctionType, LambdaType]:

        def _action_function(state: DiscreteState) -> DiscreteState:
            nexttuples = state_nexttuples[state.state_value]
            nextstates = [nexttuple.next_state_value for nexttuple in nexttuples]
            probs = [nexttuple.probability for nexttuple in nexttuples]
            next_state_value = np.random.choice(nextstates, p=probs)
            state.set_state_value(next_state_value)
            return state

        return _action_function

    def _generate_individual_reward_function(self) -> IndividualRewardFunction:

        def _individual_reward_function(
                state_value: DiscreteStateValueType,
                action_value: DiscreteActionValueType,
                next_state_value: DiscreteStateValueType
        ) -> float:
            if state_value not in self._transprobs.keys():
                return 0.

            if action_value not in self._transprobs[state_value].keys():
                return 0.

            reward = 0.
            for next_tuple in self._transprobs[state_value][action_value]:
                if next_tuple.next_state_value == next_state_value:
                    reward += next_tuple.reward
            return reward

        class ThisIndividualRewardFunction(IndividualRewardFunction):
            def reward(
                    self,
                    state_value: DiscreteStateValueType,
                    action_value: DiscreteActionValueType,
                    next_state_value: DiscreteStateValueType
            ) -> float:
                return _individual_reward_function(state_value, action_value, next_state_value)

        return ThisIndividualRewardFunction()

    def get_probability(
            self,
            state_value: DiscreteStateValueType,
            action_value: DiscreteActionValueType,
            new_state_value: DiscreteStateValueType
    ) -> float:
        if state_value not in self._transprobs.keys():
            return 0.

        if action_value not in self._transprobs[state_value]:
            return 0.

        probs = 0.
        for next_state_tuple in self._transprobs[state_value][action_value]:
            if next_state_tuple.next_state_value == new_state_value:
                probs += next_state_tuple.probability
        return probs

    @property
    def transition_probabilities(self) -> dict[DiscreteStateValueType, dict[DiscreteActionValueType, list[NextStateTuple]]]:
        return self._transprobs

    def generate_mdp_objects(self) -> tuple[DiscreteState, dict[DiscreteActionValueType, Action], IndividualRewardFunction]:
        state = DiscreteState(self._all_state_values)
        actions_dict = {}
        for action_value in self._all_action_values:
            state_nexttuple = self._get_probs_for_eachstate(action_value)
            actions_dict[action_value] = Action(self._generate_action_function(state_nexttuple))
            for next_tuples in state_nexttuple.values():
                for next_tuple in next_tuples:
                    state._terminal_dict[next_tuple.next_state_value] = next_tuple.terminal

        individual_reward_fcn = self._generate_individual_reward_function()
        self._objects_generated = True
        return state, actions_dict, individual_reward_fcn

    @property
    def objects_generated(self) -> bool:
        return self._objects_generated

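TransitionProbabilityFactory is the entry point the TD learners consume: add_state_transitions registers, for one state, the per-action lists of NextStateTuple (or equivalent dicts), and generate_mdp_objects then returns the DiscreteState, the action dictionary, and the reward function used by the learners above. A sketch of wiring up a toy two-state MDP, assuming the wheel is installed; the state and action names are made up for illustration:

from pyrlutils.transition import TransitionProbabilityFactory, NextStateTuple

factory = TransitionProbabilityFactory()
factory.add_state_transitions(
    'left',
    {
        'go': [NextStateTuple(next_state_value='right', probability=1.0, reward=1.0, terminal=True)],
        'stay': [NextStateTuple(next_state_value='left', probability=1.0, reward=0.0, terminal=False)],
    }
)
factory.add_state_transitions(
    'right',
    {'stay': [NextStateTuple(next_state_value='right', probability=1.0, reward=0.0, terminal=True)]}
)

state, actions_dict, reward_fcn = factory.generate_mdp_objects()
print(factory.get_probability('left', 'go', 'right'))  # 1.0
print(reward_fcn('left', 'go', 'right'))               # 1.0, same call form the learners use
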
pyrlutils-0.1.2.dist-info/METADATA
ADDED
@@ -0,0 +1,43 @@
Metadata-Version: 2.4
Name: pyrlutils
Version: 0.1.2
Summary: Utility and Helpers for Reinforcement Learning
Author-email: Kwan Yuet Stephen Ho <stephenhky@yahoo.com.hk>
License: MIT
Project-URL: Repository, https://github.com/stephenhky/PyRLUtils
Project-URL: Issues, https://github.com/stephenhky/PyRLUtils/issues
Keywords: machine learning,reinforcement learning,artificial intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: License :: OSI Approved :: MIT License
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Classifier: Topic :: Software Development :: Version Control :: Git
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Developers
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy
Requires-Dist: npdict>=0.0.7
Requires-Dist: typing-extensions
Provides-Extra: openaigym
Requires-Dist: gymnasium; extra == "openaigym"
Provides-Extra: test
Requires-Dist: unittest; extra == "test"
Dynamic: license-file

# PyRLUtils

[](https://circleci.com/gh/stephenhky/PyRLUtils.svg)
[](https://github.com/stephenhky/pyqentangle/PyRLUtils)
[](https://pypi.org/project/pyqentangle/)
[](https://pypi.org/project/PyRLUtils/)
[](https://pyup.io/repos/github/stephenhky/PyRLUtils/)
[](https://pyup.io/repos/github/stephenhky/PyRLUtils/)

This is a Python package with utility classes and helper functions that facilitate
the development of reinforcement learning projects.

pyrlutils-0.1.2.dist-info/RECORD
ADDED
@@ -0,0 +1,26 @@
pyrlutils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
pyrlutils/action.py,sha256=QoBdtcGtK_EkYAjb50bruhoB_XIz0agLpQjdGFnGbRQ,732
pyrlutils/policy.py,sha256=A9bj2eVd6XjNNkClSYVJDoxoGuGkyoYVr1DpVdI0wzs,5120
pyrlutils/reward.py,sha256=are0swsobMqI1IbrBVBaPMYXWpJnp6lZwAyfgBEm2zg,1211
pyrlutils/state.py,sha256=h-OGrezt0fWfVdM9-BTfqdhx1Ert_utG0ORpIwHwXCw,11902
pyrlutils/transition.py,sha256=_32jxeYbsiKyaHR9Y2XceUQYbb1jslLCQO2AWL61_EU,6260
pyrlutils/bandit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
pyrlutils/bandit/algo.py,sha256=X2Pn4DOi-RXWz5CNg1h0RJCoV3VlAwEGHRMjkfbckfw,3969
pyrlutils/bandit/reward.py,sha256=l2H_gZk2qqDxZioHe1M28pD8N47fgSR-K0Q6muchVd0,282
pyrlutils/dp/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
pyrlutils/dp/valuefcns.py,sha256=0T7vzdKRIKhLMsaq7JgPqONMmq4lWRGPj7xPtuxVtbE,6546
pyrlutils/helpers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
pyrlutils/helpers/exceptions.py,sha256=4fPGW839BChfap-Gd7b-75Dz-Ed3foqbJQ1lg15TZ-4,192
pyrlutils/openai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
pyrlutils/openai/utils.py,sha256=PJc9WHZM8aM4Z9MlACUxUC8TO7VARp8taatba_ikhew,1056
pyrlutils/td/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
pyrlutils/td/doubleqlearn.py,sha256=BbNOBaTUSiKD9z_IPFjl2R40bEf_xzZESv1fNEn7-Jg,4375
pyrlutils/td/qlearn.py,sha256=ZibW_fuB89ZAST5snNYLe5H_zUIMZ93vuJXguXpccyo,3374
pyrlutils/td/sarsa.py,sha256=jtnfMdPHld9C8yzDMQd3xyLZ3BwGL6ShvDq5WpHfZEo,3281
pyrlutils/td/state_td.py,sha256=gMX-RuSZQ-UIoTWnsmR7xLZvL2jndRknXTExWnhixpM,4778
pyrlutils/td/utils.py,sha256=VM5MAfWLQIk6a_qENU-iWHglp1azlaP68qkHvl4jXro,8022
pyrlutils-0.1.2.dist-info/licenses/LICENSE,sha256=bnQPjIcaeBdr2ZofX-_j-nELs8pAx5fQ4Cdfgeaspew,1063
pyrlutils-0.1.2.dist-info/METADATA,sha256=wjoPuC3G6WLmZdBUd8eqjgjI32FYuBejiUkT-CAs7JY,2214
pyrlutils-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
pyrlutils-0.1.2.dist-info/top_level.txt,sha256=gOBuxugE2MA4WDXlLhzkQh_rUonZU6nvJnMuomeHMCU,10
pyrlutils-0.1.2.dist-info/RECORD,,

pyrlutils-0.1.2.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,19 @@
Copyright (c) 2023 Kwan Yuet Stephen Ho

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

pyrlutils-0.1.2.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
pyrlutils