pyrlutils-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyrlutils/state.py ADDED
@@ -0,0 +1,320 @@
+
+import sys
+from abc import ABC
+from enum import Enum
+from typing import Optional, Union, Annotated, Literal
+
+import numpy as np
+from numpy.typing import NDArray
+if sys.version_info < (3, 11):
+    from typing_extensions import Self
+else:
+    from typing import Self
+
+from .helpers.exceptions import InvalidRangeError
+
+
+class State(ABC):
+    @property
+    def state_value(self):
+        raise NotImplementedError()
+
+
+DiscreteStateValueType = Union[str, int, tuple[int], Enum]
+
+
+class DiscreteState(State):
+    def __init__(
+            self,
+            all_state_values: list[DiscreteStateValueType],
+            initial_value: Optional[DiscreteStateValueType] = None,
+            terminals: Optional[dict[DiscreteStateValueType, bool]] = None
+    ):
+        super().__init__()
+        self._all_state_values = all_state_values
+        self._state_values_to_indices = {
+            state_value: idx
+            for idx, state_value in enumerate(self._all_state_values)
+        }
+        if initial_value is not None:
+            self._current_index = self._state_values_to_indices[initial_value]
+        else:
+            self._current_index = 0
+        if terminals is None:
+            self._terminal_dict = {
+                state_value: False
+                for state_value in self._all_state_values
+            }
+        else:
+            self._terminal_dict = terminals.copy()
+            for state_value in self._all_state_values:
+                if self._terminal_dict.get(state_value) is None:
+                    self._terminal_dict[state_value] = False
+
+    def _get_state_value_from_index(self, index: int) -> DiscreteStateValueType:
+        return self._all_state_values[index]
+
+    def get_state_value(self) -> DiscreteStateValueType:
+        return self._get_state_value_from_index(self._current_index)
+
+    def set_state_value(self, state_value: DiscreteStateValueType) -> None:
+        if state_value in self._all_state_values:
+            self._current_index = self._state_values_to_indices[state_value]
+        else:
+            raise ValueError('State value {} is invalid.'.format(state_value))
+
+    def get_all_possible_state_values(self) -> list[DiscreteStateValueType]:
+        return self._all_state_values
+
+    def query_state_index_from_value(self, value: DiscreteStateValueType) -> int:
+        return self._state_values_to_indices[value]
+
+    @property
+    def state_index(self) -> int:
+        return self._current_index
+
+    @state_index.setter
+    def state_index(self, new_index: int) -> None:
+        if new_index >= len(self._all_state_values):
+            raise ValueError(f"Invalid index {new_index}; it must be less than {self.nb_state_values}.")
+        self._current_index = new_index
+
+    @property
+    def state_value(self) -> DiscreteStateValueType:
+        return self._all_state_values[self._current_index]
+
+    @state_value.setter
+    def state_value(self, new_state_value: DiscreteStateValueType):
+        self.set_state_value(new_state_value)
+
+    @property
+    def state_space_size(self):
+        return len(self._all_state_values)
+
+    @property
+    def nb_state_values(self) -> int:
+        return len(self._all_state_values)
+
+    @property
+    def is_terminal(self) -> bool:
+        return self._terminal_dict[self._all_state_values[self._current_index]]
+
+    def __hash__(self):
+        return self._current_index
+
+    def __eq__(self, other: Self) -> bool:
+        return self._current_index == other._current_index
+
+
+class ContinuousState(State):
+    def __init__(
+            self,
+            nbdims: int,
+            ranges: Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]],
+            init_value: Optional[Union[float, Annotated[NDArray[np.float64], "1D Array"]]] = None
+    ):
+        super().__init__()
+        self._nbdims = nbdims
+
+        try:
+            assert isinstance(ranges, np.ndarray)
+        except AssertionError:
+            raise TypeError('Ranges must be a numpy array.')
+
+        try:
+            assert (ranges.dtype == np.float64) or (ranges.dtype == np.float32) or (ranges.dtype == np.float16)
+        except AssertionError:
+            raise TypeError('Ranges must be a floating-point numpy.ndarray.')
+
+        try:
+            assert ranges.ndim == 1 or ranges.ndim == 2
+            match ranges.ndim:
+                case 1:
+                    assert ranges.shape[0] == 2
+                case 2:
+                    assert ranges.shape[1] == 2
+                case _:
+                    raise ValueError("Ranges must be of shape (2, ) or (*, 2).")
+        except AssertionError:
+            raise ValueError("Ranges must be of shape (2, ) or (*, 2).")
+
+        try:
+            assert self._nbdims > 0
+        except AssertionError:
+            raise ValueError('Number of dimensions must be positive.')
+
+        if self._nbdims > 1:
+            try:
+                assert self._nbdims == ranges.shape[0]
+            except AssertionError:
+                raise ValueError('Number of ranges does not match the number of dimensions.')
+            try:
+                assert ranges.shape[1] == 2
+            except AssertionError:
+                raise ValueError("Each row of `ranges` must contain only the smallest and largest values.")
+        else:
+            try:
+                assert ranges.shape[0] == 2
+            except AssertionError:
+                raise ValueError("`ranges` must contain only the smallest and largest values.")
+
+        if self._nbdims > 1:
+            try:
+                for i in range(ranges.shape[0]):
+                    assert ranges[i, 0] <= ranges[i, 1]
+            except AssertionError:
+                raise InvalidRangeError()
+        else:
+            try:
+                assert ranges[0] <= ranges[1]
+            except AssertionError:
+                raise InvalidRangeError()
+
+        self._ranges = ranges if self._nbdims > 1 else np.expand_dims(ranges, axis=0)
+        if init_value is None:
+            self._state_value = np.zeros(self._nbdims)
+            for i in range(self._nbdims):
+                self._state_value[i] = np.random.uniform(self._ranges[i, 0], self._ranges[i, 1])
+        else:
+            if self._nbdims > 1:
+                try:
+                    assert init_value.shape[0] == self._nbdims
+                except AssertionError:
+                    raise ValueError('Initialized value does not have the right dimension.')
+                for i in range(self._nbdims):
+                    try:
+                        assert self._ranges[i, 0] <= init_value[i] <= self._ranges[i, 1]
+                    except AssertionError:
+                        raise InvalidRangeError('Initialized value at dimension {} (value: {}) is not within the permitted range ({} -> {})!'.format(i, init_value[i], self._ranges[i, 0], self._ranges[i, 1]))
+            else:
+                try:
+                    assert self._ranges[0, 0] <= init_value <= self._ranges[0, 1]
+                except AssertionError:
+                    raise InvalidRangeError('Initialized value is out of range.')
+            self._state_value = init_value
+
+    def set_state_value(self, state_value: Union[float, Annotated[NDArray[np.float64], "1D Array"]]):
+        if self._nbdims > 1:
+            try:
+                assert state_value.shape[0] == self._nbdims
+            except AssertionError:
+                raise ValueError('Given value does not have the right dimension.')
+            for i in range(self._nbdims):
+                try:
+                    assert self.ranges[i, 0] <= state_value[i] <= self.ranges[i, 1]
+                except AssertionError:
+                    raise InvalidRangeError()
+        else:
+            try:
+                assert self.ranges[0, 0] <= state_value <= self.ranges[0, 1]
+            except AssertionError:
+                raise InvalidRangeError()
+
+        self._state_value = state_value
+
+    def get_state_value(self) -> Annotated[NDArray[np.float64], "1D Array"]:
+        return self._state_value
+
+    def get_state_value_ranges(self) -> Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]]:
+        return self._ranges
+
+    def get_state_value_range_at_dimension(self, dimension: int) -> Annotated[NDArray[np.float64], Literal["2"]]:
+        if dimension < self._nbdims:
+            return self._ranges[dimension]
+        else:
+            raise ValueError(f"There are only {self._nbdims} dimensions!")
+
+    @property
+    def ranges(self) -> Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]]:
+        return self.get_state_value_ranges()
+
+    @property
+    def state_value(self) -> Union[float, NDArray[np.float64]]:
+        return self.get_state_value()
+
+    @state_value.setter
+    def state_value(self, new_state_value):
+        self.set_state_value(new_state_value)
+
+    @property
+    def nbdims(self) -> int:
+        return self._nbdims
+
+    def __hash__(self):
+        return hash(tuple(self._state_value))
+
+    def __eq__(self, other: Self):
+        if self.nbdims != other.nbdims:
+            raise ValueError(f"The two states have different dimensions ({self.nbdims} vs. {other.nbdims}).")
+        for i in range(self.nbdims):
+            if self.state_value[i] != other.state_value[i]:
+                return False
+        return True
+
+
+class Discrete2DCartesianState(DiscreteState):
+    def __init__(
+            self,
+            x_lowlim: int,
+            x_hilim: int,
+            y_lowlim: int,
+            y_hilim: int,
+            initial_coordinate: Optional[list[int]] = None,
+            terminals: Optional[dict[DiscreteStateValueType, bool]] = None
+    ):
+        self._x_lowlim = x_lowlim
+        self._x_hilim = x_hilim
+        self._y_lowlim = y_lowlim
+        self._y_hilim = y_hilim
+        self._countx = self._x_hilim - self._x_lowlim + 1
+        self._county = self._y_hilim - self._y_lowlim + 1
+        if initial_coordinate is None:
+            initial_coordinate = [self._x_lowlim, self._y_lowlim]
+        initial_value = (initial_coordinate[1] - self._y_lowlim) * self._countx + (initial_coordinate[0] - self._x_lowlim)
+        super().__init__(list(range(self._countx*self._county)), initial_value=initial_value, terminals=terminals)
+
+    def _encode_coordinates(self, x, y) -> int:
+        return (y - self._y_lowlim) * self._countx + (x - self._x_lowlim)
+
+    def encode_coordinates(self, coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]]) -> int:
+        if isinstance(coordinates, list):
+            assert len(coordinates) == 2
+        return self._encode_coordinates(coordinates[0], coordinates[1])
+
+    def decode_coordinates(self, hashcode: int) -> list[int]:
+        return [hashcode % self._countx + self._x_lowlim, hashcode // self._countx + self._y_lowlim]
+
+    def get_whether_terminal_given_coordinates(
+            self,
+            coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]]
+    ) -> bool:
+        if isinstance(coordinates, list):
+            assert len(coordinates) == 2
+        hashcode = self._encode_coordinates(coordinates[0], coordinates[1])
+        return self._terminal_dict.get(hashcode, False)
+
+    def set_terminal_given_coordinates(
+            self,
+            coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]],
+            terminal_value: bool
+    ) -> None:
+        if isinstance(coordinates, list):
+            assert len(coordinates) == 2
+        hashcode = self._encode_coordinates(coordinates[0], coordinates[1])
+        self._terminal_dict[hashcode] = terminal_value
+
+    @property
+    def x_lowlim(self) -> int:
+        return self._x_lowlim
+
+    @property
+    def x_hilim(self) -> int:
+        return self._x_hilim
+
+    @property
+    def y_lowlim(self) -> int:
+        return self._y_lowlim
+
+    @property
+    def y_hilim(self) -> int:
+        return self._y_hilim
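For orientation, a minimal usage sketch of the state classes added above, based only on the constructors and methods shown in this diff (the state labels and grid bounds are made up for illustration):

    from pyrlutils.state import DiscreteState, Discrete2DCartesianState

    # enumerated state space with one terminal state
    state = DiscreteState(['s0', 's1', 's2'], initial_value='s0', terminals={'s2': True})
    state.state_value = 's1'
    print(state.state_index)   # 1
    print(state.is_terminal)   # False

    # a 3x2 integer grid (x in 0..2, y in 0..1) flattened into indices 0..5
    grid = Discrete2DCartesianState(0, 2, 0, 1)
    code = grid.encode_coordinates([2, 1])   # (1 - 0) * 3 + (2 - 0) = 5
    print(grid.decode_coordinates(code))     # [2, 1]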
@@ -0,0 +1,110 @@
+
+from typing import Annotated
+
+import numpy as np
+from npdict import NumpyNDArrayWrappedDict
+
+from .utils import AbstractStateActionValueFunctionTemporalDifferenceLearner, decay_schedule, select_action
+from ..policy import DiscreteDeterminsticPolicy
+
+
+class DoubleQLearner(AbstractStateActionValueFunctionTemporalDifferenceLearner):
+    def learn(
+            self,
+            episodes: int
+    ) -> tuple[
+        Annotated[NumpyNDArrayWrappedDict, "2D array"],
+        Annotated[NumpyNDArrayWrappedDict, "1D array"],
+        DiscreteDeterminsticPolicy,
+        Annotated[NumpyNDArrayWrappedDict, "3D array"],
+        list[DiscreteDeterminsticPolicy]
+    ]:
+        Q1 = NumpyNDArrayWrappedDict(
+            [
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        Q1_track = NumpyNDArrayWrappedDict(
+            [
+                list(range(episodes)),
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        Q2 = NumpyNDArrayWrappedDict(
+            [
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        Q2_track = NumpyNDArrayWrappedDict(
+            [
+                list(range(episodes)),
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        pi_track = []
+
+        Q1_array, Q1_track_array = Q1.to_numpy(), Q1_track.to_numpy()
+        Q2_array, Q2_track_array = Q2.to_numpy(), Q2_track.to_numpy()
+        alphas = decay_schedule(
+            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
+        )
+        epsilons = decay_schedule(
+            self.init_epsilon, self.min_epsilon, self.epsilon_decay_ratio, episodes
+        )
+
+        for i in range(episodes):
+            self._state.state_index = self.initial_state_index
+            done = False
+            while not done:
+                # behaviour policy: epsilon-greedy over the average of the two tables
+                average_Q = Q1.generate_dict(0.5 * (Q1_array + Q2_array))
+                old_state_value = self._state.state_value
+                action_value = select_action(old_state_value, average_Q, epsilons[i])
+                action_func = self._actions_dict[action_value]
+                self._state = action_func(self._state)
+                new_state_value = self._state.state_value
+                reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
+                done = self._state.is_terminal
+
+                # randomly pick the table to update; the other one evaluates the greedy next action
+                Q_update, Q_eval = (Q1, Q2) if np.random.randint(2) else (Q2, Q1)
+                new_state_index = Q_update.get_key_index(0, new_state_value)
+                best_next_action_index = int(Q_update.to_numpy()[new_state_index, :].argmax())
+                max_Q_given_state = Q_eval.to_numpy()[new_state_index, best_next_action_index]
+                td_target = reward + self.gamma * max_Q_given_state * (not done)
+                td_error = td_target - Q_update[old_state_value, action_value]
+                Q_update[old_state_value, action_value] = Q_update[old_state_value, action_value] + alphas[i] * td_error
+
+            Q1_track_array[i, :, :] = Q1_array
+            Q2_track_array[i, :, :] = Q2_array
+            average_Q = Q1.generate_dict(0.5 * (Q1_array + Q2_array))
+            pi_track.append(DiscreteDeterminsticPolicy(
+                {
+                    state_value: select_action(state_value, average_Q, epsilon=0.0)
+                    for state_value in self._state.get_all_possible_state_values()
+                }
+            ))
+
+        Q = Q1.generate_dict(0.5 * (Q1_array + Q2_array))
+        V_array = np.max(Q.to_numpy(), axis=1)
+        V = NumpyNDArrayWrappedDict.from_numpyarray_given_keywords(
+            [self._state.get_all_possible_state_values()],
+            V_array
+        )
+        pi = DiscreteDeterminsticPolicy(
+            {
+                state_value: select_action(state_value, Q, epsilon=0.0)
+                for state_value in self._state.get_all_possible_state_values()
+            }
+        )
+        Q_track = Q1_track.generate_dict(0.5 * (Q1_track_array + Q2_track_array))
+
+        return Q, V, pi, Q_track, pi_track
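The while loop above decouples action selection from action evaluation, which is the defining feature of double Q-learning: one table proposes the greedy next action, the other provides its value estimate. As a self-contained reference, here is a plain-numpy sketch of a single update step, independent of the npdict wrapper used in the package (the table sizes, transition, and step sizes are illustrative only):

    import numpy as np

    gamma, alpha = 0.99, 0.1
    Q1, Q2 = np.zeros((4, 2)), np.zeros((4, 2))   # 4 states, 2 actions
    s, a, r, s_next, done = 0, 1, 1.0, 2, False

    # randomly choose the table to update; the other one evaluates
    Q_update, Q_eval = (Q1, Q2) if np.random.randint(2) else (Q2, Q1)
    best_next = int(Q_update[s_next].argmax())
    td_target = r + gamma * Q_eval[s_next, best_next] * (not done)
    Q_update[s, a] += alpha * (td_target - Q_update[s, a])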
pyrlutils/td/qlearn.py ADDED
@@ -0,0 +1,86 @@
+
+from typing import Annotated
+
+import numpy as np
+from npdict import NumpyNDArrayWrappedDict
+
+from .utils import AbstractStateActionValueFunctionTemporalDifferenceLearner, decay_schedule, select_action
+from ..policy import DiscreteDeterminsticPolicy
+
+
+class QLearner(AbstractStateActionValueFunctionTemporalDifferenceLearner):
+    def learn(
+            self,
+            episodes: int
+    ) -> tuple[
+        Annotated[NumpyNDArrayWrappedDict, "2D array"],
+        Annotated[NumpyNDArrayWrappedDict, "1D array"],
+        DiscreteDeterminsticPolicy,
+        Annotated[NumpyNDArrayWrappedDict, "3D array"],
+        list[DiscreteDeterminsticPolicy]
+    ]:
+        Q = NumpyNDArrayWrappedDict(
+            [
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        Q_track = NumpyNDArrayWrappedDict(
+            [
+                list(range(episodes)),
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        pi_track = []
+
+        Q_array, Q_track_array = Q.to_numpy(), Q_track.to_numpy()
+        alphas = decay_schedule(
+            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
+        )
+        epsilons = decay_schedule(
+            self.init_epsilon, self.min_epsilon, self.epsilon_decay_ratio, episodes
+        )
+
+        for i in range(episodes):
+            self._state.state_index = self.initial_state_index
+            done = False
+            while not done:
+                old_state_value = self._state.state_value
+                action_value = select_action(old_state_value, Q, epsilons[i])
+                action_func = self._actions_dict[action_value]
+                self._state = action_func(self._state)
+                new_state_value = self._state.state_value
+                reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
+                done = self._state.is_terminal
+
+                # off-policy TD target: greedy value of the next state
+                new_state_index = Q.get_key_index(0, new_state_value)
+                max_Q_given_state = Q.to_numpy()[new_state_index, :].max()
+                td_target = reward + self.gamma * max_Q_given_state * (not done)
+                td_error = td_target - Q[old_state_value, action_value]
+                Q[old_state_value, action_value] = Q[old_state_value, action_value] + alphas[i] * td_error
+
+            Q_track_array[i, :, :] = Q_array
+            pi_track.append(DiscreteDeterminsticPolicy(
+                {
+                    state_value: select_action(state_value, Q, epsilon=0.0)
+                    for state_value in self._state.get_all_possible_state_values()
+                }
+            ))
+
+        V_array = np.max(Q_array, axis=1)
+        V = NumpyNDArrayWrappedDict.from_numpyarray_given_keywords(
+            [self._state.get_all_possible_state_values()],
+            V_array
+        )
+        pi = DiscreteDeterminsticPolicy(
+            {
+                state_value: select_action(state_value, Q, epsilon=0.0)
+                for state_value in self._state.get_all_possible_state_values()
+            }
+        )
+
+        return Q, V, pi, Q_track, pi_track
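As a quick sanity check on the update rule in the loop above (illustrative numbers): with gamma = 0.9 and alpha = 0.5, a transition yielding reward 1.0 into a non-terminal state whose greedy value max_a Q(s', a) is 2.0 gives td_target = 1.0 + 0.9 * 2.0 = 2.8; if Q(s, a) was previously 1.0, the updated entry becomes 1.0 + 0.5 * (2.8 - 1.0) = 1.9. For a terminal transition the (not done) factor zeroes the bootstrap term, so the target reduces to the reward alone.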
pyrlutils/td/sarsa.py ADDED
@@ -0,0 +1,86 @@
+
+from typing import Annotated
+
+import numpy as np
+from npdict import NumpyNDArrayWrappedDict
+
+from .utils import AbstractStateActionValueFunctionTemporalDifferenceLearner, decay_schedule, select_action
+from ..policy import DiscreteDeterminsticPolicy
+
+
+class SARSALearner(AbstractStateActionValueFunctionTemporalDifferenceLearner):
+    def learn(
+            self,
+            episodes: int
+    ) -> tuple[
+        Annotated[NumpyNDArrayWrappedDict, "2D array"],
+        Annotated[NumpyNDArrayWrappedDict, "1D array"],
+        DiscreteDeterminsticPolicy,
+        Annotated[NumpyNDArrayWrappedDict, "3D array"],
+        list[DiscreteDeterminsticPolicy]
+    ]:
+        Q = NumpyNDArrayWrappedDict(
+            [
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        Q_track = NumpyNDArrayWrappedDict(
+            [
+                list(range(episodes)),
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        pi_track = []
+
+        Q_array, Q_track_array = Q.to_numpy(), Q_track.to_numpy()
+        alphas = decay_schedule(
+            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
+        )
+        epsilons = decay_schedule(
+            self.init_epsilon, self.min_epsilon, self.epsilon_decay_ratio, episodes
+        )
+
+        for i in range(episodes):
+            self._state.state_index = self.initial_state_index
+            done = False
+            action_value = select_action(self._state.state_value, Q, epsilons[i])
+            while not done:
+                old_state_value = self._state.state_value
+                action_func = self._actions_dict[action_value]
+                self._state = action_func(self._state)
+                new_state_value = self._state.state_value
+                reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
+                done = self._state.is_terminal
+                new_action_value = select_action(new_state_value, Q, epsilons[i])
+
+                td_target = reward + self.gamma * Q[new_state_value, new_action_value] * (not done)
+                td_error = td_target - Q[old_state_value, action_value]
+                Q[old_state_value, action_value] = Q[old_state_value, action_value] + alphas[i] * td_error
+
+                action_value = new_action_value
+
+            Q_track_array[i, :, :] = Q_array
+            pi_track.append(DiscreteDeterminsticPolicy(
+                {
+                    state_value: select_action(state_value, Q, epsilon=0.0)
+                    for state_value in self._state.get_all_possible_state_values()
+                }
+            ))
+
+        V_array = np.max(Q_array, axis=1)
+        V = NumpyNDArrayWrappedDict.from_numpyarray_given_keywords(
+            [self._state.get_all_possible_state_values()],
+            V_array
+        )
+        pi = DiscreteDeterminsticPolicy(
+            {
+                state_value: select_action(state_value, Q, epsilon=0.0)
+                for state_value in self._state.get_all_possible_state_values()
+            }
+        )
+
+        return Q, V, pi, Q_track, pi_track
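The only substantive difference from QLearner above is the TD target: SARSA is on-policy, bootstrapping from Q(s', a') for the action a' actually selected by the epsilon-greedy behaviour policy (and then executed on the next step), whereas Q-learning bootstraps from max_a Q(s', a) regardless of what is executed next. With the same illustrative numbers as before (gamma = 0.9, reward 1.0) but Q(s', a') = 1.5, the SARSA target is 1.0 + 0.9 * 1.5 = 2.35 rather than 2.8.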