pyrlutils 0.0.4__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of pyrlutils has been flagged as potentially problematic.

pyrlutils/td/qlearn.py ADDED
@@ -0,0 +1,86 @@
+
+from typing import Annotated
+
+import numpy as np
+from npdict import NumpyNDArrayWrappedDict
+
+from .utils import AbstractStateActionValueFunctionTemporalDifferenceLearner, decay_schedule, select_action
+from ..policy import DiscreteDeterminsticPolicy
+
+
+class QLearner(AbstractStateActionValueFunctionTemporalDifferenceLearner):
+    def learn(
+        self,
+        episodes: int
+    ) -> tuple[
+        Annotated[NumpyNDArrayWrappedDict, "2D array"],
+        Annotated[NumpyNDArrayWrappedDict, "1D array"],
+        DiscreteDeterminsticPolicy,
+        Annotated[NumpyNDArrayWrappedDict, "3D array"],
+        list[DiscreteDeterminsticPolicy]
+    ]:
+        Q = NumpyNDArrayWrappedDict(
+            [
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        Q_track = NumpyNDArrayWrappedDict(
+            [
+                list(range(episodes)),
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        pi_track = []
+
+        Q_array, Q_track_array = Q.to_numpy(), Q_track.to_numpy()
+        alphas = decay_schedule(
+            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
+        )
+        epsilons = decay_schedule(
+            self.init_epsilon, self.min_epsilon, self.epsilon_decay_ratio, episodes
+        )
+
+        for i in range(episodes):
+            self._state.state_index = self.initial_state_index
+            done = False
+            action_value = select_action(self._state.state_value, Q, epsilons[i])
+            while not done:
+                old_state_value = self._state.state_value
+                new_action_value = select_action(self._state.state_value, Q, epsilons[i])
+                new_action_func = self._actions_dict[new_action_value]
+                self._state = new_action_func(self._state)
+                new_state_value = self._state.state_value
+                reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
+                done = self._state.is_terminal
+
+                new_state_index = Q.get_key_index(0, new_state_value)
+                max_Q_given_state = Q.to_numpy()[new_state_index, :].max()
+                td_target = reward + self.gamma * max_Q_given_state * (not done)
+                td_error = td_target - Q[old_state_value, action_value]
+                Q[old_state_value, action_value] = Q[old_state_value, action_value] + alphas[i] * td_error
+
+            Q_track_array[i, :, :] = Q_array
+            pi_track.append(DiscreteDeterminsticPolicy(
+                {
+                    state_value: select_action(state_value, Q, epsilon=0.0)
+                    for state_value in self._state.get_all_possible_state_values()
+                }
+            ))
+
+        V_array = np.max(Q_array, axis=1)
+        V = NumpyNDArrayWrappedDict.from_numpyarray_given_keywords(
+            [self._state.get_all_possible_state_values()],
+            V_array
+        )
+        pi = DiscreteDeterminsticPolicy(
+            {
+                state_value: select_action(state_value, Q, epsilon=0.0)
+                for state_value in self._state.get_all_possible_state_values()
+            }
+        )
+
+        return Q, V, pi, Q_track, pi_track
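For orientation, here is a minimal, hypothetical usage sketch of the new QLearner. It wires a two-state MDP through TransitionProbabilityFactory and runs tabular Q-learning on it; the no-argument factory constructor and the DiscreteDeterminsticPolicy.get_action_value accessor are assumptions not shown in this diff.

```python
# Hypothetical usage sketch -- not part of the package diff.
from pyrlutils.transition import TransitionProbabilityFactory, NextStateTuple
from pyrlutils.td.qlearn import QLearner

factory = TransitionProbabilityFactory()   # assumed no-argument constructor
# From 'A', action 'go' reaches terminal 'B' with reward +1; 'stay' loops on 'A' with reward 0.
factory.add_state_transitions('A', {
    'go':   [NextStateTuple(next_state_value='B', probability=1.0, reward=1.0, terminal=True)],
    'stay': [NextStateTuple(next_state_value='A', probability=1.0, reward=0.0, terminal=False)],
})
factory.add_state_transitions('B', {
    'stay': [NextStateTuple(next_state_value='B', probability=1.0, reward=0.0, terminal=True)],
})

learner = QLearner(factory, gamma=0.99)            # other hyperparameters keep their defaults
Q, V, pi, Q_track, pi_track = learner.learn(episodes=500)
print(pi.get_action_value('A'))                    # expected to converge to 'go' (accessor name assumed)
```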
pyrlutils/td/sarsa.py ADDED
@@ -0,0 +1,86 @@
+
+from typing import Annotated
+
+import numpy as np
+from npdict import NumpyNDArrayWrappedDict
+
+from .utils import AbstractStateActionValueFunctionTemporalDifferenceLearner, decay_schedule, select_action
+from ..policy import DiscreteDeterminsticPolicy
+
+
+class SARSALearner(AbstractStateActionValueFunctionTemporalDifferenceLearner):
+    def learn(
+        self,
+        episodes: int
+    ) -> tuple[
+        Annotated[NumpyNDArrayWrappedDict, "2D array"],
+        Annotated[NumpyNDArrayWrappedDict, "1D array"],
+        DiscreteDeterminsticPolicy,
+        Annotated[NumpyNDArrayWrappedDict, "3D array"],
+        list[DiscreteDeterminsticPolicy]
+    ]:
+        Q = NumpyNDArrayWrappedDict(
+            [
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        Q_track = NumpyNDArrayWrappedDict(
+            [
+                list(range(episodes)),
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        pi_track = []
+
+        Q_array, Q_track_array = Q.to_numpy(), Q_track.to_numpy()
+        alphas = decay_schedule(
+            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
+        )
+        epsilons = decay_schedule(
+            self.init_epsilon, self.min_epsilon, self.epsilon_decay_ratio, episodes
+        )
+
+        for i in range(episodes):
+            self._state.state_index = self.initial_state_index
+            done = False
+            action_value = select_action(self._state.state_value, Q, epsilons[i])
+            while not done:
+                old_state_value = self._state.state_value
+                action_func = self._actions_dict[action_value]
+                self._state = action_func(self._state)
+                new_state_value = self._state.state_value
+                reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
+                done = self._state.is_terminal
+                new_action_value = select_action(new_state_value, Q, epsilons[i])
+
+                td_target = reward + self.gamma * Q[new_state_value, new_action_value] * (not done)
+                td_error = td_target - Q[old_state_value, action_value]
+                Q[old_state_value, action_value] = Q[old_state_value, action_value] + alphas[i] * td_error
+
+                action_value = new_action_value
+
+            Q_track_array[i, :, :] = Q_array
+            pi_track.append(DiscreteDeterminsticPolicy(
+                {
+                    state_value: select_action(state_value, Q, epsilon=0.0)
+                    for state_value in self._state.get_all_possible_state_values()
+                }
+            ))
+
+        V_array = np.max(Q_array, axis=1)
+        V = NumpyNDArrayWrappedDict.from_numpyarray_given_keywords(
+            [self._state.get_all_possible_state_values()],
+            V_array
+        )
+        pi = DiscreteDeterminsticPolicy(
+            {
+                state_value: select_action(state_value, Q, epsilon=0.0)
+                for state_value in self._state.get_all_possible_state_values()
+            }
+        )
+
+        return Q, V, pi, Q_track, pi_track
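SARSALearner is structurally identical to QLearner; the difference is the bootstrap term of the TD target. SARSA is on-policy and bootstraps on Q of the action the epsilon-greedy policy actually selects next, whereas Q-learning is off-policy and bootstraps on the greedy maximum over the next state's actions. A toy NumPy illustration of the two update rules (not package code):

```python
import numpy as np

# One TD backup for (s, a) under each rule, on a tiny 2-state x 2-action table.
gamma, alpha, reward, done = 0.99, 0.1, 1.0, False
Q = np.zeros((2, 2))
s, a, s_next, a_next = 0, 1, 1, 0       # a_next: action the behaviour policy picked in s_next

sarsa_target  = reward + gamma * Q[s_next, a_next] * (not done)   # on-policy (SARSA)
qlearn_target = reward + gamma * Q[s_next, :].max() * (not done)  # off-policy (Q-learning)
Q[s, a] += alpha * (sarsa_target - Q[s, a])                       # apply whichever target is used
print(Q)
```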
@@ -0,0 +1,111 @@
+
+from typing import Annotated
+
+import numpy as np
+from npdict import NumpyNDArrayWrappedDict
+
+from .utils import decay_schedule, TimeDifferencePathElements, AbstractStateValueFunctionTemporalDifferenceLearner
+
+
+class SingleStepTemporalDifferenceLearner(AbstractStateValueFunctionTemporalDifferenceLearner):
+    def learn(
+        self,
+        episodes: int
+    ) -> tuple[Annotated[NumpyNDArrayWrappedDict, "1D Array"], Annotated[NumpyNDArrayWrappedDict, "2D Array"]]:
+        V = NumpyNDArrayWrappedDict(
+            [self._state.get_all_possible_state_values()],
+            default_initial_value=0.0
+        )
+        V_track = NumpyNDArrayWrappedDict(
+            [list(range(episodes)), self._state.get_all_possible_state_values()],
+            default_initial_value=0.0
+        )
+        V_array, V_track_array = V.to_numpy(), V_track.to_numpy()
+        alphas = decay_schedule(
+            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
+        )
+
+        for i in range(episodes):
+            self._state.state_index = self.initial_state_index
+            done = False
+            while not done:
+                old_state_value = self._state.state_value
+                action_value = self._policy.get_action_value(self._state.state_value)
+                action_func = self._actions_dict[action_value]
+                self._state = action_func(self._state)
+                new_state_value = self._state.state_value
+                reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
+                done = self._state.is_terminal
+
+                td_target = reward + self.gamma * V[new_state_value] * (not done)
+                td_error = td_target - V[old_state_value]
+                V[old_state_value] = V[old_state_value] + alphas[i] * td_error
+
+            V_track_array[i, :] = V_array
+
+        return V, V_track
+
+
+class MultipleStepTemporalDifferenceLearner(AbstractStateValueFunctionTemporalDifferenceLearner):
+    def learn(
+        self,
+        episodes: int,
+        n_steps: int=3
+    ) -> tuple[Annotated[NumpyNDArrayWrappedDict, "1D Array"], Annotated[NumpyNDArrayWrappedDict, "2D Array"]]:
+        V = NumpyNDArrayWrappedDict(
+            [self._state.get_all_possible_state_values()],
+            default_initial_value=0.0
+        )
+        V_track = NumpyNDArrayWrappedDict(
+            [list(range(episodes)), self._state.get_all_possible_state_values()],
+            default_initial_value=0.0
+        )
+        V_array, V_track_array = V.to_numpy(), V_track.to_numpy()
+        alphas = decay_schedule(
+            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
+        )
+        discounts = np.logspace(0, n_steps-1, num=n_steps+1, base=self.gamma, endpoint=False)
+
+        for i in range(episodes):
+            self._state.state_index = self.initial_state_index
+            done = False
+            path = []
+
+            while not done or path is not None:
+                path = path[1:]  # worth revisiting this line
+
+                new_state_value = self._state._get_state_value_from_index(self._state.nb_state_values-1)
+                while not done and len(path) < n_steps:
+                    old_state_value = self._state.state_value
+                    action_value = self._policy.get_action_value(self._state.state_value)
+                    action_func = self._actions_dict[action_value]
+                    self._state = action_func(self._state)
+                    new_state_value = self._state.state_value
+                    reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
+                    done = self._state.is_terminal
+
+                    path.append(
+                        TimeDifferencePathElements(
+                            this_state_value=old_state_value,
+                            reward=reward,
+                            next_state_value=new_state_value,
+                            done=done
+                        )
+                    )
+                    if done:
+                        break
+
+                n = len(path)
+                estimated_state_value = path[0].this_state_value
+                rewards = np.array([this_moment.reward for this_moment in path])
+                partial_return = discounts[n:] * rewards
+                bs_val = discounts[-1] * V[new_state_value] * (not done)
+                ntd_target = np.sum(np.append(partial_return, bs_val))
+                ntd_error = ntd_target - V[estimated_state_value]
+                V[(estimated_state_value,)] = V[estimated_state_value] + alphas[i] * ntd_error
+                if len(path) == 1 and path[0].done:
+                    path = None
+
+            V_track_array[i, :] = V_array
+
+        return V, V_track
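MultipleStepTemporalDifferenceLearner buffers up to n_steps transitions in path and then applies an n-step backup to the oldest state in the buffer. As a reference point, here is a toy NumPy computation of the standard n-step TD target that this kind of learner approximates; it uses illustrative values and integer powers of gamma, and does not reproduce the exact discounts slicing above.

```python
import numpy as np

# n-step TD target: r_0 + gamma*r_1 + ... + gamma^(n-1)*r_{n-1} + gamma^n * V(s_n) * (not done)
gamma = 0.9
rewards = np.array([1.0, 0.0, 2.0])     # rewards along the stored path (n = 3)
v_bootstrap, done = 0.5, False          # value estimate at the state reached at the end of the path

discounts = gamma ** np.arange(len(rewards))                                   # [1.0, 0.9, 0.81]
ntd_target = discounts @ rewards + gamma ** len(rewards) * v_bootstrap * (not done)
print(ntd_target)                        # 1.0 + 0.0 + 1.62 + 0.3645 = 2.9845
```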
pyrlutils/td/utils.py ADDED
@@ -0,0 +1,258 @@
+
+from typing import Annotated, Union, Optional
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+
+import numpy as np
+from numpy.typing import NDArray
+from npdict import NumpyNDArrayWrappedDict
+
+from ..state import DiscreteStateValueType
+from ..action import DiscreteActionValueType
+from ..policy import DiscretePolicy
+from ..transition import TransitionProbabilityFactory
+
+
+def decay_schedule(
+    init_value: float,
+    min_value: float,
+    decay_ratio: float,
+    max_steps: int,
+    log_start: int=-2,
+    log_base: int=10
+) -> Annotated[NDArray[np.float64], "1D Array"]:
+    decay_steps = int(max_steps*decay_ratio)
+    rem_steps = max_steps - decay_steps
+
+    values = np.logspace(log_start, 0, decay_steps, base=log_base, endpoint=True)[::-1]
+    values = (values - values.min()) / (values.max() - values.min())
+    values = (init_value - min_value) * values + min_value
+    values = np.pad(values, (0, rem_steps), 'edge')
+    return values
+
+
+def select_action(
+    state_value: DiscreteStateValueType,
+    Q: Union[Annotated[NDArray[np.float64], "2D Array"], NumpyNDArrayWrappedDict],
+    epsilon: float,
+) -> Union[DiscreteActionValueType, int]:
+    if np.random.random() <= epsilon:
+        if isinstance(Q, NumpyNDArrayWrappedDict):
+            return np.random.choice(Q._lists_keystrings[1])
+        else:
+            return np.random.choice(np.arange(Q.shape[1]))
+
+    q_matrix = Q.to_numpy() if isinstance(Q, NumpyNDArrayWrappedDict) else Q
+    state_index = Q.get_key_index(0, state_value) if isinstance(Q, NumpyNDArrayWrappedDict) else state_value
+    max_index = np.argmax(q_matrix[state_index, :])
+
+    if isinstance(Q, NumpyNDArrayWrappedDict):
+        return Q._lists_keystrings[1][max_index]
+    else:
+        return max_index
+
+
+@dataclass
+class TimeDifferencePathElements:
+    this_state_value: DiscreteStateValueType
+    reward: float
+    next_state_value: DiscreteStateValueType
+    done: bool
+
+
+class AbstractStateValueFunctionTemporalDifferenceLearner(ABC):
+    def __init__(
+        self,
+        transprobfac: TransitionProbabilityFactory,
+        gamma: float=1.0,
+        init_alpha: float=0.5,
+        min_alpha: float=0.01,
+        alpha_decay_ratio: float=0.3,
+        policy: Optional[DiscretePolicy]=None,
+        initial_state_index: int=0
+    ):
+        self._gamma = gamma
+        self._init_alpha = init_alpha
+        self._min_alpha = min_alpha
+        try:
+            assert 0.0 <= alpha_decay_ratio <= 1.0
+        except AssertionError:
+            raise ValueError("alpha_decay_ratio must be between 0 and 1!")
+        self._alpha_decay_ratio = alpha_decay_ratio
+        self._transprobfac = transprobfac
+        self._state, self._actions_dict, self._indrewardfcn = self._transprobfac.generate_mdp_objects()
+        self._action_names = list(self._actions_dict.keys())
+        self._actions_to_indices = {action_value: idx for idx, action_value in enumerate(self._action_names)}
+        self._policy = policy
+        try:
+            assert 0 <= initial_state_index < self._state.nb_state_values
+        except AssertionError:
+            raise ValueError(f"Initial state index must be between 0 and {self._state.nb_state_values}")
+        self._init_state_index = initial_state_index
+
+    @abstractmethod
+    def learn(self, *args, **kwargs) -> tuple[Annotated[NDArray[np.float64], "1D Array"], Annotated[NDArray[np.float64], "2D Array"]]:
+        raise NotImplementedError()
+
+    @property
+    def nb_states(self) -> int:
+        return self._state.nb_state_values
+
+    @property
+    def policy(self) -> DiscretePolicy:
+        return self._policy
+
+    @policy.setter
+    def policy(self, val: DiscretePolicy):
+        self._policy = val
+
+    @property
+    def gamma(self) -> float:
+        return self._gamma
+
+    @gamma.setter
+    def gamma(self, val: float):
+        self._gamma = val
+
+    @property
+    def init_alpha(self) -> float:
+        return self._init_alpha
+
+    @init_alpha.setter
+    def init_alpha(self, val: float):
+        self._init_alpha = val
+
+    @property
+    def min_alpha(self) -> float:
+        return self._min_alpha
+
+    @min_alpha.setter
+    def min_alpha(self, val: float):
+        self._min_alpha = val
+
+    @property
+    def alpha_decay_ratio(self) -> float:
+        return self._alpha_decay_ratio
+
+    @property
+    def initial_state_index(self) -> int:
+        return self._init_state_index
+
+    @initial_state_index.setter
+    def initial_state_index(self, val: int):
+        self._init_state_index = val
+
+
+
+class AbstractStateActionValueFunctionTemporalDifferenceLearner(ABC):
+    def __init__(
+        self,
+        transprobfac: TransitionProbabilityFactory,
+        gamma: float=1.0,
+        init_alpha: float=0.5,
+        min_alpha: float=0.01,
+        alpha_decay_ratio: float=0.3,
+        init_epsilon: float=1.0,
+        min_epsilon: float=0.1,
+        epsilon_decay_ratio: float=0.9,
+        policy: Optional[DiscretePolicy]=None,
+        initial_state_index: int=0
+    ):
+        self._gamma = gamma
+        self._init_alpha = init_alpha
+        self._min_alpha = min_alpha
+        try:
+            assert 0.0 <= alpha_decay_ratio <= 1.0
+        except AssertionError:
+            raise ValueError("alpha_decay_ratio must be between 0 and 1!")
+        self._alpha_decay_ratio = alpha_decay_ratio
+        self._init_epsilon = init_epsilon
+        self._min_epsilon = min_epsilon
+        self._epsilon_decay_ratio = epsilon_decay_ratio
+
+        self._transprobfac = transprobfac
+        self._state, self._actions_dict, self._indrewardfcn = self._transprobfac.generate_mdp_objects()
+        self._action_names = list(self._actions_dict.keys())
+        self._actions_to_indices = {action_value: idx for idx, action_value in enumerate(self._action_names)}
+        self._policy = policy
+        try:
+            assert 0 <= initial_state_index < self._state.nb_state_values
+        except AssertionError:
+            raise ValueError(f"Initial state index must be between 0 and {self._state.nb_state_values}")
+        self._init_state_index = initial_state_index
+
+    @abstractmethod
+    def learn(self, *args, **kwargs) -> tuple[Annotated[NDArray[np.float64], "1D Array"], Annotated[NDArray[np.float64], "2D Array"]]:
+        raise NotImplementedError()
+
+    @property
+    def nb_states(self) -> int:
+        return self._state.nb_state_values
+
+    @property
+    def policy(self) -> DiscretePolicy:
+        return self._policy
+
+    @policy.setter
+    def policy(self, val: DiscretePolicy):
+        self._policy = val
+
+    @property
+    def gamma(self) -> float:
+        return self._gamma
+
+    @gamma.setter
+    def gamma(self, val: float):
+        self._gamma = val
+
+    @property
+    def init_alpha(self) -> float:
+        return self._init_alpha
+
+    @init_alpha.setter
+    def init_alpha(self, val: float):
+        self._init_alpha = val
+
+    @property
+    def min_alpha(self) -> float:
+        return self._min_alpha
+
+    @min_alpha.setter
+    def min_alpha(self, val: float):
+        self._min_alpha = val
+
+    @property
+    def alpha_decay_ratio(self) -> float:
+        return self._alpha_decay_ratio
+
+    @property
+    def init_epsilon(self) -> float:
+        return self._init_epsilon
+
+    @init_epsilon.setter
+    def init_epsilon(self, val: float):
+        self._init_epsilon = val
+
+    @property
+    def min_epsilon(self) -> float:
+        return self._min_epsilon
+
+    @min_epsilon.setter
+    def min_epsilon(self, val: float):
+        self._min_epsilon = val
+
+    @property
+    def epsilon_decay_ratio(self) -> float:
+        return self._epsilon_decay_ratio
+
+    @epsilon_decay_ratio.setter
+    def epsilon_decay_ratio(self, val: float):
+        self._epsilon_decay_ratio = val
+
+    @property
+    def initial_state_index(self) -> int:
+        return self._init_state_index
+
+    @initial_state_index.setter
+    def initial_state_index(self, val: int):
+        self._init_state_index = val
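The two helpers at the top of utils.py are generic: decay_schedule produces a log-spaced schedule that falls from init_value to min_value over the first decay_ratio fraction of the steps and is then padded flat, and select_action performs an epsilon-greedy choice over either a plain NumPy Q-matrix or a NumpyNDArrayWrappedDict. A quick sanity check on plain arrays:

```python
import numpy as np
from pyrlutils.td.utils import decay_schedule, select_action

alphas = decay_schedule(init_value=0.5, min_value=0.01, decay_ratio=0.3, max_steps=10)
print(alphas.shape)             # (10,)
print(alphas[0], alphas[-1])    # 0.5 at the start, held at 0.01 after the decay window

Q = np.array([[0.1, 0.9],
              [0.4, 0.2]])
print(select_action(0, Q, epsilon=0.0))   # greedy branch: argmax of row 0 -> 1
print(select_action(1, Q, epsilon=1.0))   # exploration branch: random column, 0 or 1
```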
pyrlutils/transition.py CHANGED
@@ -1,6 +1,7 @@
 
-from types import LambdaType
-from typing import Tuple, Dict
+from types import LambdaType, FunctionType
+from typing import Union
+from dataclasses import dataclass
 
 import numpy as np
 
@@ -9,28 +10,12 @@ from .reward import IndividualRewardFunction
 from .action import Action, DiscreteActionValueType
 
 
+@dataclass
 class NextStateTuple:
-    def __init__(self, next_state_value: DiscreteStateValueType, probability: float, reward: float, terminal: bool):
-        self._next_state_value = next_state_value
-        self._probability = probability
-        self._reward = reward
-        self._terminal = terminal
-
-    @property
-    def next_state_value(self) -> DiscreteStateValueType:
-        return self._next_state_value
-
-    @property
-    def probability(self) -> float:
-        return self._probability
-
-    @property
-    def reward(self) -> float:
-        return self._reward
-
-    @property
-    def terminal(self) -> bool:
-        return self._terminal
+    next_state_value: DiscreteStateValueType
+    probability: float
+    reward: float
+    terminal: bool
 
 
 class TransitionProbabilityFactory:
@@ -40,7 +25,11 @@ class TransitionProbabilityFactory:
         self._all_action_values = []
         self._objects_generated = False
 
-    def add_state_transitions(self, state_value: DiscreteStateValueType, action_values_to_next_state: dict):
+    def add_state_transitions(
+        self,
+        state_value: DiscreteStateValueType,
+        action_values_to_next_state: dict[DiscreteActionValueType, Union[list[NextStateTuple], dict]]
+    ):
         if state_value not in self._all_state_values:
             self._all_state_values.append(state_value)
 
@@ -69,7 +58,10 @@ class TransitionProbabilityFactory:
 
         self._transprobs[state_value] = this_state_transition_dict
 
-    def _get_probs_for_eachstate(self, action_value: DiscreteActionValueType) -> Dict[DiscreteStateValueType, NextStateTuple]:
+    def _get_probs_for_eachstate(
+        self,
+        action_value: DiscreteActionValueType
+    ) -> dict[DiscreteStateValueType, list[NextStateTuple]]:
         state_nexttuples = {}
         for state_value, action_nexttuples_pair in self._transprobs.items():
             for this_action_value, nexttuples in action_nexttuples_pair.items():
@@ -77,7 +69,10 @@ class TransitionProbabilityFactory:
                     state_nexttuples[state_value] = nexttuples
         return state_nexttuples
 
-    def _generate_action_function(self, state_nexttuples: dict) -> LambdaType:
+    def _generate_action_function(
+        self,
+        state_nexttuples: dict[DiscreteStateValueType, list[NextStateTuple]]
+    ) -> Union[FunctionType, LambdaType]:
 
        def _action_function(state: DiscreteState) -> DiscreteState:
            nexttuples = state_nexttuples[state.state_value]
@@ -91,7 +86,11 @@ class TransitionProbabilityFactory:
 
    def _generate_individual_reward_function(self) -> IndividualRewardFunction:
 
-        def _individual_reward_function(state_value, action_value, next_state_value) -> float:
+        def _individual_reward_function(
+            state_value: DiscreteStateValueType,
+            action_value: DiscreteActionValueType,
+            next_state_value: DiscreteStateValueType
+        ) -> float:
             if state_value not in self._transprobs.keys():
                 return 0.
 
@@ -105,15 +104,22 @@ class TransitionProbabilityFactory:
             return reward
 
         class ThisIndividualRewardFunction(IndividualRewardFunction):
-            def __init__(self):
-                super().__init__()
-
-            def reward(self, state_value, action_value, next_state_value) -> float:
+            def reward(
+                self,
+                state_value: DiscreteStateValueType,
+                action_value: DiscreteActionValueType,
+                next_state_value: DiscreteStateValueType
+            ) -> float:
                 return _individual_reward_function(state_value, action_value, next_state_value)
 
         return ThisIndividualRewardFunction()
 
-    def get_probability(self, state_value, action_value, new_state_value) -> float:
+    def get_probability(
+        self,
+        state_value: DiscreteStateValueType,
+        action_value: DiscreteActionValueType,
+        new_state_value: DiscreteStateValueType
+    ) -> float:
         if state_value not in self._transprobs.keys():
             return 0.
 
@@ -127,18 +133,21 @@ class TransitionProbabilityFactory:
         return probs
 
     @property
-    def transition_probabilities(self) -> dict:
+    def transition_probabilities(self) -> dict[DiscreteStateValueType, dict[DiscreteActionValueType, list[NextStateTuple]]]:
         return self._transprobs
 
-    def generate_mdp_objects(self) -> Tuple[DiscreteState, Dict[DiscreteActionValueType, Action], IndividualRewardFunction]:
+    def generate_mdp_objects(self) -> tuple[DiscreteState, dict[DiscreteActionValueType, Action], IndividualRewardFunction]:
         state = DiscreteState(self._all_state_values)
         actions_dict = {}
         for action_value in self._all_action_values:
             state_nexttuple = self._get_probs_for_eachstate(action_value)
             actions_dict[action_value] = Action(self._generate_action_function(state_nexttuple))
+            for next_tuples in state_nexttuple.values():
+                for next_tuple in next_tuples:
+                    state._terminal_dict[next_tuple.next_state_value] = next_tuple.terminal
 
         individual_reward_fcn = self._generate_individual_reward_function()
-
+        self._objects_generated = True
         return state, actions_dict, individual_reward_fcn
 
     @property
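The NextStateTuple rewrite from a hand-rolled class with read-only properties to a @dataclass preserves construction and attribute access, and additionally gives value equality and a readable repr for free (the fields do become plain, writable attributes). A short illustration of the standard dataclass behaviour this implies:

```python
from pyrlutils.transition import NextStateTuple

t1 = NextStateTuple(next_state_value='B', probability=1.0, reward=1.0, terminal=True)
t2 = NextStateTuple('B', 1.0, 1.0, True)
print(t1)          # NextStateTuple(next_state_value='B', probability=1.0, reward=1.0, terminal=True)
print(t1 == t2)    # True: __eq__ is generated by @dataclass
print(t1.reward)   # 1.0: formerly a read-only property, now a plain field
```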