pyrlutils 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyrlutils/__init__.py +0 -0
- pyrlutils/action.py +27 -0
- pyrlutils/bandit/__init__.py +0 -0
- pyrlutils/bandit/algo.py +128 -0
- pyrlutils/bandit/reward.py +12 -0
- pyrlutils/dp/__init__.py +0 -0
- pyrlutils/dp/valuefcns.py +149 -0
- pyrlutils/helpers/__init__.py +0 -0
- pyrlutils/helpers/exceptions.py +5 -0
- pyrlutils/openai/__init__.py +0 -0
- pyrlutils/openai/utils.py +31 -0
- pyrlutils/policy.py +151 -0
- pyrlutils/reward.py +37 -0
- pyrlutils/state.py +320 -0
- pyrlutils/td/__init__.py +0 -0
- pyrlutils/td/doubleqlearn.py +110 -0
- pyrlutils/td/qlearn.py +86 -0
- pyrlutils/td/sarsa.py +86 -0
- pyrlutils/td/state_td.py +111 -0
- pyrlutils/td/utils.py +258 -0
- pyrlutils/transition.py +155 -0
- pyrlutils-0.1.2.dist-info/METADATA +43 -0
- pyrlutils-0.1.2.dist-info/RECORD +26 -0
- pyrlutils-0.1.2.dist-info/WHEEL +5 -0
- pyrlutils-0.1.2.dist-info/licenses/LICENSE +19 -0
- pyrlutils-0.1.2.dist-info/top_level.txt +1 -0
pyrlutils/state.py
ADDED
@@ -0,0 +1,320 @@

import sys
from abc import ABC
from enum import Enum
from typing import Optional, Union, Annotated, Literal

import numpy as np
from numpy.typing import NDArray
if sys.version_info < (3, 11):
    from typing_extensions import Self
else:
    from typing import Self

from .helpers.exceptions import InvalidRangeError


class State(ABC):
    @property
    def state_value(self):
        raise NotImplemented()


DiscreteStateValueType = Union[str, int, tuple[int], Enum]


class DiscreteState(State):
    def __init__(
        self,
        all_state_values: list[DiscreteStateValueType],
        initial_value: Optional[DiscreteStateValueType] = None,
        terminals: Optional[dict[DiscreteStateValueType, bool]] = None
    ):
        super().__init__()
        self._all_state_values = all_state_values
        self._state_values_to_indices = {
            state_value: idx
            for idx, state_value in enumerate(self._all_state_values)
        }
        if initial_value is not None:
            self._current_index = self._state_values_to_indices[initial_value]
        else:
            self._current_index = 0
        if terminals is None:
            self._terminal_dict = {
                state_value: False
                for state_value in self._all_state_values
            }
        else:
            self._terminal_dict = terminals.copy()
            for state_value in self._all_state_values:
                if self._terminal_dict.get(state_value) is None:
                    self._terminal_dict[state_value] = False

    def _get_state_value_from_index(self, index: int) -> DiscreteStateValueType:
        return self._all_state_values[index]

    def get_state_value(self) -> DiscreteStateValueType:
        return self._get_state_value_from_index(self._current_index)

    def set_state_value(self, state_value: DiscreteStateValueType) -> None:
        if state_value in self._all_state_values:
            self._current_index = self._state_values_to_indices[state_value]
        else:
            raise ValueError('State value {} is invalid.'.format(state_value))

    def get_all_possible_state_values(self) -> list[DiscreteStateValueType]:
        return self._all_state_values

    def query_state_index_from_value(self, value: DiscreteStateValueType) -> int:
        return self._state_values_to_indices[value]

    @property
    def state_index(self) -> int:
        return self._current_index

    @state_index.setter
    def state_index(self, new_index: int) -> None:
        if new_index >= len(self._all_state_values):
            raise ValueError(f"Invalid index {new_index}; it must be less than {self.nb_state_values}.")
        self._current_index = new_index

    @property
    def state_value(self) -> DiscreteStateValueType:
        return self._all_state_values[self._current_index]

    @state_value.setter
    def state_value(self, new_state_value: DiscreteStateValueType):
        self.set_state_value(new_state_value)

    @property
    def state_space_size(self):
        return len(self._all_state_values)

    @property
    def nb_state_values(self) -> int:
        return len(self._all_state_values)

    @property
    def is_terminal(self) -> bool:
        return self._terminal_dict[self._all_state_values[self._current_index]]

    def __hash__(self):
        return self._current_index

    def __eq__(self, other: Self) -> bool:
        return self._current_index == other._current_index


class ContinuousState(State):
    def __init__(
        self,
        nbdims: int,
        ranges: Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]],
        init_value: Optional[Union[float, Annotated[NDArray[np.float64], "1D Array"]]] = None
    ):
        super().__init__()
        self._nbdims = nbdims

        try:
            assert isinstance(ranges, np.ndarray)
        except AssertionError:
            raise TypeError('Range must be a numpy array.')

        try:
            assert (ranges.dtype == np.float64) or (ranges.dtype == np.float32) or (ranges.dtype == np.float16)
        except AssertionError:
            raise TypeError('It has to be floating type numpy.ndarray.')

        try:
            assert ranges.ndim == 1 or ranges.ndim == 2
            match ranges.ndim:
                case 1:
                    assert ranges.shape[0] == 2
                case 2:
                    assert ranges.shape[1] == 2
                case _:
                    raise ValueError("Ranges must be of shape (2, ) or (*, 2).")
        except AssertionError:
            raise ValueError("Ranges must be of shape (2, ) or (*, 2).")

        try:
            assert self._nbdims > 0
        except AssertionError:
            raise ValueError('Number of dimensions must be positive.')

        if self._nbdims > 1:
            try:
                assert self._nbdims == ranges.shape[0]
            except AssertionError:
                raise ValueError('Number of ranges does not meet the number of dimensions.')
            try:
                assert ranges.shape[1] == 2
            except AssertionError:
                raise ValueError("Only the smallest and largest values in `ranges'.")
        else:
            try:
                assert ranges.shape[0] == 2
            except AssertionError:
                raise ValueError("Only the smallest and largest values in `ranges'.")

        if self._nbdims > 1:
            try:
                for i in range(ranges.shape[0]):
                    assert ranges[i, 0] <= ranges[i, 1]
            except AssertionError:
                raise InvalidRangeError()
        else:
            try:
                assert ranges[0] <= ranges[1]
            except AssertionError:
                raise InvalidRangeError()

        self._ranges = ranges if self._nbdims > 1 else np.expand_dims(ranges, axis=0)
        if init_value is None:
            self._state_value = np.zeros(self._nbdims)
            for i in range(self._nbdims):
                self._state_value[i] = np.random.uniform(self._ranges[i, 0], self._ranges[i, 1])
        else:
            if self._nbdims > 1:
                try:
                    assert init_value.shape[0] == self._nbdims
                except AssertionError:
                    raise ValueError('Initialized value does not have the right dimension.')
                for i in range(self._nbdims):
                    try:
                        assert self._ranges[i, 0] <= init_value[i] <= self.ranges[i, 1]
                    except AssertionError:
                        raise InvalidRangeError('Initialized value at dimension {} (value: {}) is not within the permitted range ({} -> {})!'.format(i, init_value[i], self._ranges[i, 0], self._ranges[i, 1]))
            else:
                try:
                    assert self._ranges[0, 0] <= init_value <= self.ranges[0, 1]
                except AssertionError:
                    raise InvalidRangeError('Initialized value is out of range.')
            self._state_value = init_value

    def set_state_value(self, state_value: Union[float, Annotated[NDArray[np.float64], "1D Array"]]):
        if self._nbdims > 1:
            try:
                assert state_value.shape[0] == self._nbdims
            except AssertionError:
                raise ValueError('Given value does not have the right dimension.')
            for i in range(self._nbdims):
                try:
                    assert self.ranges[i, 0] <= state_value[i] <= self.ranges[i, 1]
                except AssertionError:
                    raise InvalidRangeError()
        else:
            try:
                assert self.ranges[0, 0] <= state_value <= self.ranges[0, 1]
            except AssertionError:
                raise InvalidRangeError()

        self._state_value = state_value

    def get_state_value(self) -> Annotated[NDArray[np.float64], "1D Array"]:
        return self._state_value

    def get_state_value_ranges(self) -> Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]]:
        return self._ranges

    def get_state_value_range_at_dimension(self, dimension: int) -> Annotated[NDArray[np.float64], Literal["2"]]:
        if dimension < self._nbdims:
            return self._ranges[dimension]
        else:
            raise ValueError(f"There are only {self._nbdims} dimensions!")

    @property
    def ranges(self) -> Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]]:
        return self.get_state_value_ranges()

    @property
    def state_value(self) -> Union[float, NDArray[np.float64]]:
        return self.get_state_value()

    @state_value.setter
    def state_value(self, new_state_value):
        self.set_state_value(new_state_value)

    @property
    def nbdims(self) -> int:
        return self._nbdims

    def __hash__(self):
        return hash(tuple(self._state_value))

    def __eq__(self, other: Self):
        if self.nbdims != other.nbdims:
            raise ValueError(f"The two states have two different dimensions. ({self.nbdims} vs. {other.nbdims})")
        for i in range(self.nbdims):
            if self.state_value[i] != other.state_value[i]:
                return False
        return True


class Discrete2DCartesianState(DiscreteState):
    def __init__(
        self,
        x_lowlim: int,
        x_hilim: int,
        y_lowlim: int,
        y_hilim: int,
        initial_coordinate: list[int] = None,
        terminals: Optional[dict[DiscreteStateValueType, bool]] = None
    ):
        self._x_lowlim = x_lowlim
        self._x_hilim = x_hilim
        self._y_lowlim = y_lowlim
        self._y_hilim = y_hilim
        self._countx = self._x_hilim - self._x_lowlim + 1
        self._county = self._y_hilim - self._y_lowlim + 1
        if initial_coordinate is None:
            initial_coordinate = [self._x_lowlim, self._y_lowlim]
        initial_value = (initial_coordinate[1] - self._y_lowlim) * self._countx + (initial_coordinate[0] - self._x_lowlim)
        super().__init__(list(range(self._countx*self._county)), initial_value=initial_value, terminals=terminals)

    def _encode_coordinates(self, x, y) -> int:
        return (y - self._y_lowlim) * self._countx + (x - self._x_lowlim)

    def encode_coordinates(self, coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]]) -> int:
        if isinstance(coordinates, list):
            assert len(coordinates) == 2
        return self._encode_coordinates(coordinates[0], coordinates[1])

    def decode_coordinates(self, hashcode) -> list[int]:
        return [hashcode % self._countx + self._x_lowlim, hashcode // self._countx + self._y_lowlim]

    def get_whether_terminal_given_coordinates(
        self,
        coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]]
    ) -> bool:
        if isinstance(coordinates, list):
            assert len(coordinates) == 2
        hashcode = self._encode_coordinates(coordinates[0], coordinates[1])
        return self._terminal_dict.get(hashcode, False)

    def set_terminal_given_coordinates(
        self,
        coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]],
        terminal_value: bool
    ) -> None:
        if isinstance(coordinates, list):
            assert len(coordinates) == 2
        hashcode = self._encode_coordinates(coordinates[0], coordinates[1])
        self._terminal_dict[hashcode] = terminal_value

    @property
    def x_lowlim(self) -> int:
        return self._x_lowlim

    @property
    def x_hilim(self) -> int:
        return self._x_hilim

    @property
    def y_lowlim(self) -> int:
        return self._y_lowlim

    @property
    def y_hilim(self) -> int:
        return self._y_hilim
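To make the API above concrete, here is a minimal usage sketch of DiscreteState and Discrete2DCartesianState. The state labels, grid limits, and the terminal cell below are made-up examples, not values from the package.

# Illustrative only: a small discrete state and a 4x3 grid built from the classes above.
from pyrlutils.state import DiscreteState, Discrete2DCartesianState

traffic_light = DiscreteState(['red', 'amber', 'green'], initial_value='red')
traffic_light.state_value = 'green'          # goes through set_state_value validation
assert traffic_light.state_index == 2

grid = Discrete2DCartesianState(0, 3, 0, 2)  # x in [0, 3], y in [0, 2]
code = grid.encode_coordinates([3, 2])       # hashcode of the top-right cell
assert grid.decode_coordinates(code) == [3, 2]
grid.set_terminal_given_coordinates([3, 2], True)
assert grid.get_whether_terminal_given_coordinates([3, 2])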
pyrlutils/td/__init__.py
ADDED
File without changes
pyrlutils/td/doubleqlearn.py
ADDED
@@ -0,0 +1,110 @@

from typing import Annotated

import numpy as np
from npdict import NumpyNDArrayWrappedDict

from .utils import AbstractStateActionValueFunctionTemporalDifferenceLearner, decay_schedule, select_action
from ..policy import DiscreteDeterminsticPolicy


class DoubleQLearner(AbstractStateActionValueFunctionTemporalDifferenceLearner):
    def learn(
        self,
        episodes: int
    ) -> tuple[
        Annotated[NumpyNDArrayWrappedDict, "2D array"],
        Annotated[NumpyNDArrayWrappedDict, "1D array"],
        DiscreteDeterminsticPolicy,
        Annotated[NumpyNDArrayWrappedDict, "3D array"],
        list[DiscreteDeterminsticPolicy]
    ]:
        Q1 = NumpyNDArrayWrappedDict(
            [
                self._state.get_all_possible_state_values(),
                self._action_names
            ],
            default_initial_value=0.0
        )
        Q1_track = NumpyNDArrayWrappedDict(
            [
                list(range(episodes)),
                self._state.get_all_possible_state_values(),
                self._action_names
            ],
            default_initial_value=0.0
        )
        Q2 = NumpyNDArrayWrappedDict(
            [
                self._state.get_all_possible_state_values(),
                self._action_names
            ],
            default_initial_value=0.0
        )
        Q2_track = NumpyNDArrayWrappedDict(
            [
                list(range(episodes)),
                self._state.get_all_possible_state_values(),
                self._action_names
            ],
            default_initial_value=0.0
        )
        pi_track = []

        Q1_array, Q1_track_array = Q1.to_numpy(), Q1_track.to_numpy()
        Q2_array, Q2_track_array = Q2.to_numpy(), Q2_track.to_numpy()
        alphas = decay_schedule(
            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
        )
        epsilons = decay_schedule(
            self.init_epsilon, self.min_epsilon, self.epsilon_decay_ratio, episodes
        )

        for i in range(episodes):
            self._state.state_index = self.initial_state_index
            average_Q = Q1.generate_dict(0.5*(Q1_array+Q2_array))
            done = False
            action_value = select_action(self._state.state_value, average_Q, epsilons[i])
            while not done:
                # decide whether to pick Q1 or Q2
                Q = Q1 if np.random.randint(2) else Q2

                old_state_value = self._state.state_value
                new_action_value = select_action(self._state.state_value, Q, epsilons[i])
                new_action_func = self._actions_dict[new_action_value]
                self._state = new_action_func(self._state)
                new_state_value = self._state.state_value
                reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
                done = self._state.is_terminal

                new_state_index = Q.get_key_index(0, new_state_value)
                max_Q_given_state = Q.to_numpy()[new_state_index, :].max()
                td_target = reward + self.gamma * max_Q_given_state * (not done)
                td_error = td_target - Q[old_state_value, action_value]
                Q[old_state_value, action_value] = Q[old_state_value, action_value] + alphas[i] * td_error

            Q1_track_array[i, :, :] = Q1_array
            Q2_track_array[i, :, :] = Q2_array
            average_Q = Q1.generate_dict(0.5 * (Q1_array + Q2_array))
            pi_track.append(DiscreteDeterminsticPolicy(
                {
                    state_value: select_action(state_value, average_Q, epsilon=0.0)
                    for state_value in self._state.get_all_possible_state_values()
                }
            ))

        Q = Q1.generate_dict(0.5 * (Q1_array + Q2_array))
        V_array = np.max(Q.to_numpy(), axis=1)
        V = NumpyNDArrayWrappedDict.from_numpyarray_given_keywords(
            [self._state.get_all_possible_state_values()],
            V_array
        )
        pi = DiscreteDeterminsticPolicy(
            {
                state_value: select_action(state_value, Q, epsilon=0.0)
                for state_value in self._state.get_all_possible_state_values()
            }
        )
        Q_track = Q1_track.generate_dict(0.5 * (Q1_track_array + Q2_track_array))

        return Q, V, pi, Q_track, pi_track
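For background, textbook double Q-learning (van Hasselt, 2010) keeps two tables and decouples action selection from evaluation: one table picks the greedy next action, the other scores it. The sketch below shows that update for a single transition with placeholder numbers; it is illustrative background only, independent of the learner class above.

# Standalone sketch of the textbook double Q-learning update for one
# transition (s, a, r, s'); all values here are arbitrary placeholders.
import numpy as np

n_states, n_actions = 5, 2
Q1 = np.zeros((n_states, n_actions))
Q2 = np.zeros((n_states, n_actions))
s, a, r, s_next, done = 0, 1, 1.0, 3, False
alpha, gamma = 0.1, 0.99

if np.random.randint(2):
    # Q1 selects the greedy next action, Q2 evaluates it.
    a_next = Q1[s_next].argmax()
    td_target = r + gamma * Q2[s_next, a_next] * (not done)
    Q1[s, a] += alpha * (td_target - Q1[s, a])
else:
    # Roles swapped: Q2 selects, Q1 evaluates.
    a_next = Q2[s_next].argmax()
    td_target = r + gamma * Q1[s_next, a_next] * (not done)
    Q2[s, a] += alpha * (td_target - Q2[s, a])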
pyrlutils/td/qlearn.py
ADDED
@@ -0,0 +1,86 @@

from typing import Annotated

import numpy as np
from npdict import NumpyNDArrayWrappedDict

from .utils import AbstractStateActionValueFunctionTemporalDifferenceLearner, decay_schedule, select_action
from ..policy import DiscreteDeterminsticPolicy


class QLearner(AbstractStateActionValueFunctionTemporalDifferenceLearner):
    def learn(
        self,
        episodes: int
    ) -> tuple[
        Annotated[NumpyNDArrayWrappedDict, "2D array"],
        Annotated[NumpyNDArrayWrappedDict, "1D array"],
        DiscreteDeterminsticPolicy,
        Annotated[NumpyNDArrayWrappedDict, "3D array"],
        list[DiscreteDeterminsticPolicy]
    ]:
        Q = NumpyNDArrayWrappedDict(
            [
                self._state.get_all_possible_state_values(),
                self._action_names
            ],
            default_initial_value=0.0
        )
        Q_track = NumpyNDArrayWrappedDict(
            [
                list(range(episodes)),
                self._state.get_all_possible_state_values(),
                self._action_names
            ],
            default_initial_value=0.0
        )
        pi_track = []

        Q_array, Q_track_array = Q.to_numpy(), Q_track.to_numpy()
        alphas = decay_schedule(
            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
        )
        epsilons = decay_schedule(
            self.init_epsilon, self.min_epsilon, self.epsilon_decay_ratio, episodes
        )

        for i in range(episodes):
            self._state.state_index = self.initial_state_index
            done = False
            action_value = select_action(self._state.state_value, Q, epsilons[i])
            while not done:
                old_state_value = self._state.state_value
                new_action_value = select_action(self._state.state_value, Q, epsilons[i])
                new_action_func = self._actions_dict[new_action_value]
                self._state = new_action_func(self._state)
                new_state_value = self._state.state_value
                reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
                done = self._state.is_terminal

                new_state_index = Q.get_key_index(0, new_state_value)
                max_Q_given_state = Q.to_numpy()[new_state_index, :].max()
                td_target = reward + self.gamma * max_Q_given_state * (not done)
                td_error = td_target - Q[old_state_value, action_value]
                Q[old_state_value, action_value] = Q[old_state_value, action_value] + alphas[i] * td_error

            Q_track_array[i, :, :] = Q_array
            pi_track.append(DiscreteDeterminsticPolicy(
                {
                    state_value: select_action(state_value, Q, epsilon=0.0)
                    for state_value in self._state.get_all_possible_state_values()
                }
            ))

        V_array = np.max(Q_array, axis=1)
        V = NumpyNDArrayWrappedDict.from_numpyarray_given_keywords(
            [self._state.get_all_possible_state_values()],
            V_array
        )
        pi = DiscreteDeterminsticPolicy(
            {
                state_value: select_action(state_value, Q, epsilon=0.0)
                for state_value in self._state.get_all_possible_state_values()
            }
        )

        return Q, V, pi, Q_track, pi_track
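select_action is imported from .utils, whose implementation is not part of this diff. A conventional epsilon-greedy rule over one row of a state-action table looks roughly like the sketch below; this is an assumption about what such a helper does, not the code shipped in pyrlutils.td.utils.

# Hypothetical epsilon-greedy action choice over a tabular Q function;
# an assumption for illustration, not the package's select_action.
import numpy as np

def epsilon_greedy(q_row: np.ndarray, epsilon: float) -> int:
    # q_row holds the action values for the current state.
    if np.random.random() < epsilon:
        return int(np.random.randint(len(q_row)))  # explore
    return int(q_row.argmax())                     # exploit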
pyrlutils/td/sarsa.py
ADDED
@@ -0,0 +1,86 @@

from typing import Annotated

import numpy as np
from npdict import NumpyNDArrayWrappedDict

from .utils import AbstractStateActionValueFunctionTemporalDifferenceLearner, decay_schedule, select_action
from ..policy import DiscreteDeterminsticPolicy


class SARSALearner(AbstractStateActionValueFunctionTemporalDifferenceLearner):
    def learn(
        self,
        episodes: int
    ) -> tuple[
        Annotated[NumpyNDArrayWrappedDict, "2D array"],
        Annotated[NumpyNDArrayWrappedDict, "1D array"],
        DiscreteDeterminsticPolicy,
        Annotated[NumpyNDArrayWrappedDict, "3D array"],
        list[DiscreteDeterminsticPolicy]
    ]:
        Q = NumpyNDArrayWrappedDict(
            [
                self._state.get_all_possible_state_values(),
                self._action_names
            ],
            default_initial_value=0.0
        )
        Q_track = NumpyNDArrayWrappedDict(
            [
                list(range(episodes)),
                self._state.get_all_possible_state_values(),
                self._action_names
            ],
            default_initial_value=0.0
        )
        pi_track = []

        Q_array, Q_track_array = Q.to_numpy(), Q_track.to_numpy()
        alphas = decay_schedule(
            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
        )
        epsilons = decay_schedule(
            self.init_epsilon, self.min_epsilon, self.epsilon_decay_ratio, episodes
        )

        for i in range(episodes):
            self._state.state_index = self.initial_state_index
            done = False
            action_value = select_action(self._state.state_value, Q, epsilons[i])
            while not done:
                old_state_value = self._state.state_value
                action_func = self._actions_dict[action_value]
                self._state = action_func(self._state)
                new_state_value = self._state.state_value
                reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
                done = self._state.is_terminal
                new_action_value = select_action(new_state_value, Q, epsilons[i])

                td_target = reward + self.gamma * Q[new_state_value, new_action_value] * (not done)
                td_error = td_target - Q[old_state_value, action_value]
                Q[old_state_value, action_value] = Q[old_state_value, action_value] + alphas[i] * td_error

                action_value = new_action_value

            Q_track_array[i, :, :] = Q_array
            pi_track.append(DiscreteDeterminsticPolicy(
                {
                    state_value: select_action(state_value, Q, epsilon=0.0)
                    for state_value in self._state.get_all_possible_state_values()
                }
            ))

        V_array = np.max(Q_array, axis=1)
        V = NumpyNDArrayWrappedDict.from_numpyarray_given_keywords(
            [self._state.get_all_possible_state_values()],
            V_array
        )
        pi = DiscreteDeterminsticPolicy(
            {
                state_value: select_action(state_value, Q, epsilon=0.0)
                for state_value in self._state.get_all_possible_state_values()
            }
        )

        return Q, V, pi, Q_track, pi_track
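The substantive difference between SARSALearner and QLearner is the bootstrap term: SARSA is on-policy and bootstraps from the action actually chosen in the next state, while Q-learning bootstraps from the greedy maximum. A short comparison of the two TD targets, using placeholder values rather than anything from the package:

# Contrast of the two TD targets for one transition (s, a, r, s');
# the table, indices, and hyperparameters are illustrative placeholders.
import numpy as np

Q = np.array([[0.0, 1.0],
              [2.0, 0.5]])
r, gamma, s_next, a_next, done = 1.0, 0.99, 1, 1, False

sarsa_target = r + gamma * Q[s_next, a_next] * (not done)     # on-policy: uses the a' actually chosen
qlearning_target = r + gamma * Q[s_next].max() * (not done)   # off-policy: uses the greedy max over a'
print(sarsa_target, qlearning_target)  # 1.495 vs 2.98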