pyrlutils 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyrlutils/state.py +4 -1
- pyrlutils/td/qlearn.py +86 -0
- pyrlutils/td/sarsa.py +86 -0
- pyrlutils/td/{td.py → state_td.py} +38 -28
- pyrlutils/td/utils.py +149 -10
- {pyrlutils-0.1.0.dist-info → pyrlutils-0.1.1.dist-info}/METADATA +2 -1
- {pyrlutils-0.1.0.dist-info → pyrlutils-0.1.1.dist-info}/RECORD +10 -8
- {pyrlutils-0.1.0.dist-info → pyrlutils-0.1.1.dist-info}/WHEEL +0 -0
- {pyrlutils-0.1.0.dist-info → pyrlutils-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {pyrlutils-0.1.0.dist-info → pyrlutils-0.1.1.dist-info}/top_level.txt +0 -0
pyrlutils/state.py
CHANGED
@@ -66,6 +66,9 @@ class DiscreteState(State):
     def get_all_possible_state_values(self) -> list[DiscreteStateValueType]:
         return self._all_state_values
 
+    def query_state_index_from_value(self, value: DiscreteStateValueType) -> int:
+        return self._state_values_to_indices[value]
+
     @property
     def state_index(self) -> int:
         return self._current_index
@@ -73,7 +76,7 @@ class DiscreteState(State):
     @state_index.setter
     def state_index(self, new_index: int) -> None:
         if new_index >= len(self._all_state_values):
-            raise ValueError(f"Invalid index {new_index}; it must be less than {
+            raise ValueError(f"Invalid index {new_index}; it must be less than {self.nb_state_values}.")
         self._current_index = new_index
 
     @property
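The new query_state_index_from_value method is the reverse lookup of the existing index-based accessors: it maps a state value back to its integer index. A minimal sketch of how it can pair with the state_index setter; the DiscreteState instance named state and the helper jump_to_value are hypothetical, and construction of the state object is not part of this diff:

    from pyrlutils.state import DiscreteState

    def jump_to_value(state: DiscreteState, value) -> None:
        # new in 0.1.1: state value -> integer index
        idx = state.query_state_index_from_value(value)
        # existing setter; raises ValueError for an out-of-range index
        state.state_index = idx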
pyrlutils/td/qlearn.py
ADDED
@@ -0,0 +1,86 @@
+
+from typing import Annotated
+
+import numpy as np
+from npdict import NumpyNDArrayWrappedDict
+
+from .utils import AbstractStateActionValueFunctionTemporalDifferenceLearner, decay_schedule, select_action
+from ..policy import DiscreteDeterminsticPolicy
+
+
+class QLearner(AbstractStateActionValueFunctionTemporalDifferenceLearner):
+    def learn(
+        self,
+        episodes: int
+    ) -> tuple[
+        Annotated[NumpyNDArrayWrappedDict, "2D array"],
+        Annotated[NumpyNDArrayWrappedDict, "1D array"],
+        DiscreteDeterminsticPolicy,
+        Annotated[NumpyNDArrayWrappedDict, "3D array"],
+        list[DiscreteDeterminsticPolicy]
+    ]:
+        Q = NumpyNDArrayWrappedDict(
+            [
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        Q_track = NumpyNDArrayWrappedDict(
+            [
+                list(range(episodes)),
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        pi_track = []
+
+        Q_array, Q_track_array = Q.to_numpy(), Q_track.to_numpy()
+        alphas = decay_schedule(
+            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
+        )
+        epsilons = decay_schedule(
+            self.init_epsilon, self.min_epsilon, self.epsilon_decay_ratio, episodes
+        )
+
+        for i in range(episodes):
+            self._state.state_index = self.initial_state_index
+            done = False
+            action_value = select_action(self._state.state_value, Q, epsilons[i])
+            while not done:
+                old_state_value = self._state.state_value
+                new_action_value = select_action(self._state.state_value, Q, epsilons[i])
+                new_action_func = self._actions_dict[new_action_value]
+                self._state = new_action_func(self._state)
+                new_state_value = self._state.state_value
+                reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
+                done = self._state.is_terminal
+
+                new_state_index = Q.get_key_index(0, new_state_value)
+                max_Q_given_state = Q.to_numpy()[new_state_index, :].max()
+                td_target = reward + self.gamma * max_Q_given_state * (not done)
+                td_error = td_target - Q[old_state_value, action_value]
+                Q[old_state_value, action_value] = Q[old_state_value, action_value] + alphas[i] * td_error
+
+            Q_track_array[i, :, :] = Q_array
+            pi_track.append(DiscreteDeterminsticPolicy(
+                {
+                    state_value: select_action(state_value, Q, epsilon=0.0)
+                    for state_value in self._state.get_all_possible_state_values()
+                }
+            ))
+
+        V_array = np.max(Q_array, axis=1)
+        V = NumpyNDArrayWrappedDict.from_numpyarray_given_keywords(
+            [self._state.get_all_possible_state_values()],
+            V_array
+        )
+        pi = DiscreteDeterminsticPolicy(
+            {
+                state_value: select_action(state_value, Q, epsilon=0.0)
+                for state_value in self._state.get_all_possible_state_values()
+            }
+        )
+
+        return Q, V, pi, Q_track, pi_track
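A hedged usage sketch of the new class. Only learn() and its five return values appear in this file; the constructor keywords are those of AbstractStateActionValueFunctionTemporalDifferenceLearner defined in pyrlutils/td/utils.py below, the variable transprobfac is assumed to be a TransitionProbabilityFactory built elsewhere, and the numeric values are illustrative only:

    from pyrlutils.td.qlearn import QLearner

    learner = QLearner(
        transprobfac,                  # pre-built TransitionProbabilityFactory (assumed)
        gamma=0.99,
        init_alpha=0.5, min_alpha=0.01, alpha_decay_ratio=0.3,
        init_epsilon=1.0, min_epsilon=0.1, epsilon_decay_ratio=0.9,
        initial_state_index=0
    )
    # Q: state-action values, V: greedy state values, pi: greedy policy,
    # plus per-episode histories Q_track and pi_track.
    Q, V, pi, Q_track, pi_track = learner.learn(episodes=1000)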
pyrlutils/td/sarsa.py
ADDED
@@ -0,0 +1,86 @@
+
+from typing import Annotated
+
+import numpy as np
+from npdict import NumpyNDArrayWrappedDict
+
+from .utils import AbstractStateActionValueFunctionTemporalDifferenceLearner, decay_schedule, select_action
+from ..policy import DiscreteDeterminsticPolicy
+
+
+class SARSALearner(AbstractStateActionValueFunctionTemporalDifferenceLearner):
+    def learn(
+        self,
+        episodes: int
+    ) -> tuple[
+        Annotated[NumpyNDArrayWrappedDict, "2D array"],
+        Annotated[NumpyNDArrayWrappedDict, "1D array"],
+        DiscreteDeterminsticPolicy,
+        Annotated[NumpyNDArrayWrappedDict, "3D array"],
+        list[DiscreteDeterminsticPolicy]
+    ]:
+        Q = NumpyNDArrayWrappedDict(
+            [
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        Q_track = NumpyNDArrayWrappedDict(
+            [
+                list(range(episodes)),
+                self._state.get_all_possible_state_values(),
+                self._action_names
+            ],
+            default_initial_value=0.0
+        )
+        pi_track = []
+
+        Q_array, Q_track_array = Q.to_numpy(), Q_track.to_numpy()
+        alphas = decay_schedule(
+            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
+        )
+        epsilons = decay_schedule(
+            self.init_epsilon, self.min_epsilon, self.epsilon_decay_ratio, episodes
+        )
+
+        for i in range(episodes):
+            self._state.state_index = self.initial_state_index
+            done = False
+            action_value = select_action(self._state.state_value, Q, epsilons[i])
+            while not done:
+                old_state_value = self._state.state_value
+                action_func = self._actions_dict[action_value]
+                self._state = action_func(self._state)
+                new_state_value = self._state.state_value
+                reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
+                done = self._state.is_terminal
+                new_action_value = select_action(new_state_value, Q, epsilons[i])
+
+                td_target = reward + self.gamma * Q[new_state_value, new_action_value] * (not done)
+                td_error = td_target - Q[old_state_value, action_value]
+                Q[old_state_value, action_value] = Q[old_state_value, action_value] + alphas[i] * td_error
+
+                action_value = new_action_value
+
+            Q_track_array[i, :, :] = Q_array
+            pi_track.append(DiscreteDeterminsticPolicy(
+                {
+                    state_value: select_action(state_value, Q, epsilon=0.0)
+                    for state_value in self._state.get_all_possible_state_values()
+                }
+            ))
+
+        V_array = np.max(Q_array, axis=1)
+        V = NumpyNDArrayWrappedDict.from_numpyarray_given_keywords(
+            [self._state.get_all_possible_state_values()],
+            V_array
+        )
+        pi = DiscreteDeterminsticPolicy(
+            {
+                state_value: select_action(state_value, Q, epsilon=0.0)
+                for state_value in self._state.get_all_possible_state_values()
+            }
+        )
+
+        return Q, V, pi, Q_track, pi_track
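The only substantive difference from QLearner.learn above is the temporal-difference target: SARSA bootstraps from the action it actually selects next (on-policy), while Q-learning bootstraps from the best next action (off-policy). A self-contained sketch of the two targets; q_row and the two helper functions are illustrative stand-ins for one row of the Q table (the action values of the next state):

    import numpy as np

    def q_learning_target(reward: float, gamma: float, q_row: np.ndarray, done: bool) -> float:
        # off-policy: max over next-state action values
        return reward + gamma * q_row.max() * (not done)

    def sarsa_target(reward: float, gamma: float, q_row: np.ndarray, next_action_index: int, done: bool) -> float:
        # on-policy: the value of the action actually chosen next
        return reward + gamma * q_row[next_action_index] * (not done)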
pyrlutils/td/{td.py → state_td.py}
@@ -2,83 +2,93 @@
 from typing import Annotated
 
 import numpy as np
-from
+from npdict import NumpyNDArrayWrappedDict
 
-from .utils import decay_schedule,
+from .utils import decay_schedule, TimeDifferencePathElements, AbstractStateValueFunctionTemporalDifferenceLearner
 
 
-class SingleStepTemporalDifferenceLearner(
+class SingleStepTemporalDifferenceLearner(AbstractStateValueFunctionTemporalDifferenceLearner):
     def learn(
         self,
         episodes: int
-    ) -> tuple[Annotated[
-        V =
-
+    ) -> tuple[Annotated[NumpyNDArrayWrappedDict, "1D Array"], Annotated[NumpyNDArrayWrappedDict, "2D Array"]]:
+        V = NumpyNDArrayWrappedDict(
+            [self._state.get_all_possible_state_values()],
+            default_initial_value=0.0
+        )
+        V_track = NumpyNDArrayWrappedDict(
+            [list(range(episodes)), self._state.get_all_possible_state_values()],
+            default_initial_value=0.0
+        )
+        V_array, V_track_array = V.to_numpy(), V_track.to_numpy()
         alphas = decay_schedule(
             self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
         )
 
         for i in range(episodes):
-            self._state.
+            self._state.state_index = self.initial_state_index
             done = False
             while not done:
-                old_state_index = self._state.state_index
                 old_state_value = self._state.state_value
                 action_value = self._policy.get_action_value(self._state.state_value)
                 action_func = self._actions_dict[action_value]
                 self._state = action_func(self._state)
-                new_state_index = self._state.state_index
                 new_state_value = self._state.state_value
                 reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
                 done = self._state.is_terminal
 
-                td_target = reward + self.gamma * V[
-                td_error = td_target - V[
-                V[
+                td_target = reward + self.gamma * V[new_state_value] * (not done)
+                td_error = td_target - V[old_state_value]
+                V[old_state_value] = V[old_state_value] + alphas[i] * td_error
 
-
+            V_track_array[i, :] = V_array
 
         return V, V_track
 
 
-class MultipleStepTemporalDifferenceLearner(
+class MultipleStepTemporalDifferenceLearner(AbstractStateValueFunctionTemporalDifferenceLearner):
     def learn(
         self,
         episodes: int,
         n_steps: int=3
-    ) -> tuple[Annotated[
-        V =
-
+    ) -> tuple[Annotated[NumpyNDArrayWrappedDict, "1D Array"], Annotated[NumpyNDArrayWrappedDict, "2D Array"]]:
+        V = NumpyNDArrayWrappedDict(
+            [self._state.get_all_possible_state_values()],
+            default_initial_value=0.0
+        )
+        V_track = NumpyNDArrayWrappedDict(
+            [list(range(episodes)), self._state.get_all_possible_state_values()],
+            default_initial_value=0.0
+        )
+        V_array, V_track_array = V.to_numpy(), V_track.to_numpy()
         alphas = decay_schedule(
             self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
         )
         discounts = np.logspace(0, n_steps-1, num=n_steps+1, base=self.gamma, endpoint=False)
 
         for i in range(episodes):
-            self._state.
+            self._state.state_index = self.initial_state_index
             done = False
             path = []
 
             while not done or path is not None:
                 path = path[1:]  # worth revisiting this line
 
-
+                new_state_value = self._state._get_state_value_from_index(self._state.nb_state_values-1)
                 while not done and len(path) < n_steps:
-                    old_state_index = self._state.state_index
                     old_state_value = self._state.state_value
                     action_value = self._policy.get_action_value(self._state.state_value)
                     action_func = self._actions_dict[action_value]
                     self._state = action_func(self._state)
-                    new_state_index = self._state.state_index
                     new_state_value = self._state.state_value
                     reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
                     done = self._state.is_terminal
 
                     path.append(
                         TimeDifferencePathElements(
-
+                            this_state_value=old_state_value,
                             reward=reward,
-
+                            next_state_value=new_state_value,
                             done=done
                         )
                     )
@@ -86,16 +96,16 @@ class MultipleStepTemporalDifferenceLearner(AbstractTemporalDifferenceLearner):
                         break
 
                 n = len(path)
-
+                estimated_state_value = path[0].this_state_value
                 rewards = np.array([this_moment.reward for this_moment in path])
                 partial_return = discounts[n:] * rewards
-                bs_val = discounts[-1] * V[
+                bs_val = discounts[-1] * V[new_state_value] * (not done)
                 ntd_target = np.sum(np.append(partial_return, bs_val))
-                ntd_error = ntd_target - V[
-                V[
+                ntd_error = ntd_target - V[estimated_state_value]
+                V[(estimated_state_value,)] = V[estimated_state_value] + alphas[i] * ntd_error
                 if len(path) == 1 and path[0].done:
                     path = None
 
-
+            V_track_array[i, :] = V_array
 
         return V, V_track
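The n-step learner pre-computes its per-step discount factors with np.logspace. A quick standalone check of what that exact expression produces, evaluated here for gamma = 0.9 and n_steps = 3; with endpoint=False, numpy spreads num=n_steps+1 exponents over [0, n_steps-1), giving 0, 0.5, 1.0 and 1.5:

    import numpy as np

    gamma, n_steps = 0.9, 3
    # same expression as in MultipleStepTemporalDifferenceLearner.learn above
    discounts = np.logspace(0, n_steps - 1, num=n_steps + 1, base=gamma, endpoint=False)
    print(discounts)  # gamma ** [0, 0.5, 1.0, 1.5] ~= [1.0, 0.9487, 0.9, 0.8538]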
pyrlutils/td/utils.py
CHANGED
@@ -1,11 +1,14 @@
 
-from
-from typing import Optional, Annotated
+from typing import Annotated, Union, Optional
 from dataclasses import dataclass
+from abc import ABC, abstractmethod
 
 import numpy as np
 from numpy.typing import NDArray
+from npdict import NumpyNDArrayWrappedDict
 
+from ..state import DiscreteStateValueType
+from ..action import DiscreteActionValueType
 from ..policy import DiscretePolicy
 from ..transition import TransitionProbabilityFactory
 
@@ -28,7 +31,36 @@ def decay_schedule(
     return values
 
 
-
+def select_action(
+        state_value: DiscreteStateValueType,
+        Q: Union[Annotated[NDArray[np.float64], "2D Array"], NumpyNDArrayWrappedDict],
+        epsilon: float,
+) -> Union[DiscreteActionValueType, int]:
+    if np.random.random() <= epsilon:
+        if isinstance(Q, NumpyNDArrayWrappedDict):
+            return np.random.choice(Q._lists_keystrings[1])
+        else:
+            return np.random.choice(np.arange(Q.shape[1]))
+
+    q_matrix = Q.to_numpy() if isinstance(Q, NumpyNDArrayWrappedDict) else Q
+    state_index = Q.get_key_index(0, state_value) if isinstance(Q, NumpyNDArrayWrappedDict) else state_value
+    max_index = np.argmax(q_matrix[state_index, :])
+
+    if isinstance(Q, NumpyNDArrayWrappedDict):
+        return Q._lists_keystrings[1][max_index]
+    else:
+        return max_index
+
+
+@dataclass
+class TimeDifferencePathElements:
+    this_state_value: DiscreteStateValueType
+    reward: float
+    next_state_value: DiscreteStateValueType
+    done: bool
+
+
+class AbstractStateValueFunctionTemporalDifferenceLearner(ABC):
     def __init__(
         self,
         transprobfac: TransitionProbabilityFactory,
@@ -55,7 +87,7 @@ class AbstractTemporalDifferenceLearner(ABC):
         try:
             assert 0 <= initial_state_index < self._state.nb_state_values
         except AssertionError:
-            raise ValueError("Initial state index must be between 0 and {}"
+            raise ValueError(f"Initial state index must be between 0 and {self._state.nb_state_values}")
         self._init_state_index = initial_state_index
 
     @abstractmethod
@@ -111,9 +143,116 @@ class AbstractTemporalDifferenceLearner(ABC):
         self._init_state_index = val
 
 
-
-class
-
-
-
-
+
+class AbstractStateActionValueFunctionTemporalDifferenceLearner(ABC):
+    def __init__(
+        self,
+        transprobfac: TransitionProbabilityFactory,
+        gamma: float=1.0,
+        init_alpha: float=0.5,
+        min_alpha: float=0.01,
+        alpha_decay_ratio: float=0.3,
+        init_epsilon: float=1.0,
+        min_epsilon: float=0.1,
+        epsilon_decay_ratio: float=0.9,
+        policy: Optional[DiscretePolicy]=None,
+        initial_state_index: int=0
+    ):
+        self._gamma = gamma
+        self._init_alpha = init_alpha
+        self._min_alpha = min_alpha
+        try:
+            assert 0.0 <= alpha_decay_ratio <= 1.0
+        except AssertionError:
+            raise ValueError("alpha_decay_ratio must be between 0 and 1!")
+        self._alpha_decay_ratio = alpha_decay_ratio
+        self._init_epsilon = init_epsilon
+        self._min_epsilon = min_epsilon
+        self._epsilon_decay_ratio = epsilon_decay_ratio
+
+        self._transprobfac = transprobfac
+        self._state, self._actions_dict, self._indrewardfcn = self._transprobfac.generate_mdp_objects()
+        self._action_names = list(self._actions_dict.keys())
+        self._actions_to_indices = {action_value: idx for idx, action_value in enumerate(self._action_names)}
+        self._policy = policy
+        try:
+            assert 0 <= initial_state_index < self._state.nb_state_values
+        except AssertionError:
+            raise ValueError(f"Initial state index must be between 0 and {self._state.nb_state_values}")
+        self._init_state_index = initial_state_index
+
+    @abstractmethod
+    def learn(self, *args, **kwargs) -> tuple[Annotated[NDArray[np.float64], "1D Array"], Annotated[NDArray[np.float64], "2D Array"]]:
+        raise NotImplementedError()
+
+    @property
+    def nb_states(self) -> int:
+        return self._state.nb_state_values
+
+    @property
+    def policy(self) -> DiscretePolicy:
+        return self._policy
+
+    @policy.setter
+    def policy(self, val: DiscretePolicy):
+        self._policy = val
+
+    @property
+    def gamma(self) -> float:
+        return self._gamma
+
+    @gamma.setter
+    def gamma(self, val: float):
+        self._gamma = val
+
+    @property
+    def init_alpha(self) -> float:
+        return self._init_alpha
+
+    @init_alpha.setter
+    def init_alpha(self, val: float):
+        self._init_alpha = val
+
+    @property
+    def min_alpha(self) -> float:
+        return self._min_alpha
+
+    @min_alpha.setter
+    def min_alpha(self, val: float):
+        self._min_alpha = val
+
+    @property
+    def alpha_decay_ratio(self) -> float:
+        return self._alpha_decay_ratio
+
+    @property
+    def init_epsilon(self) -> float:
+        return self._init_epsilon
+
+    @init_epsilon.setter
+    def init_epsilon(self, val: float):
+        self._init_epsilon = val
+
+    @property
+    def min_epsilon(self) -> float:
+        return self._min_epsilon
+
+    @min_epsilon.setter
+    def min_epsilon(self, val: float):
+        self._min_epsilon = val
+
+    @property
+    def epsilon_decay_ratio(self) -> float:
+        return self._epsilon_decay_ratio
+
+    @epsilon_decay_ratio.setter
+    def epsilon_decay_ratio(self, val: float):
+        self._epsilon_decay_ratio = val
+
+    @property
+    def initial_state_index(self) -> int:
+        return self._init_state_index
+
+    @initial_state_index.setter
+    def initial_state_index(self, val: int):
+        self._init_state_index = val
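select_action implements epsilon-greedy selection and accepts either a NumpyNDArrayWrappedDict or a plain 2-D NumPy array whose rows index states and columns index actions. A small sketch exercising the plain-array branch; the q values are made up, and the decay_schedule call mirrors the positional usage in the learners above:

    import numpy as np
    from pyrlutils.td.utils import decay_schedule, select_action

    q = np.array([[0.0, 1.0, 0.2],
                  [0.5, 0.1, 0.9]])

    greedy_action = select_action(state_value=0, Q=q, epsilon=0.0)   # argmax of row 0 -> 1
    random_action = select_action(state_value=0, Q=q, epsilon=1.0)   # uniform over column indices

    alphas = decay_schedule(0.5, 0.01, 0.3, 100)   # 100 per-episode learning rates decaying from 0.5 toward 0.01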
{pyrlutils-0.1.0.dist-info → pyrlutils-0.1.1.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyrlutils
-Version: 0.1.0
+Version: 0.1.1
 Summary: Utility and Helpers for Reinformcement Learning
 Author-email: Kwan Yuet Stephen Ho <stephenhky@yahoo.com.hk>
 License: MIT
@@ -21,6 +21,7 @@ Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: numpy
+Requires-Dist: npdict>=0.0.7
 Requires-Dist: typing-extensions
 Provides-Extra: openaigym
 Requires-Dist: gymnasium; extra == "openaigym"
{pyrlutils-0.1.0.dist-info → pyrlutils-0.1.1.dist-info}/RECORD
@@ -2,7 +2,7 @@ pyrlutils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyrlutils/action.py,sha256=QoBdtcGtK_EkYAjb50bruhoB_XIz0agLpQjdGFnGbRQ,732
 pyrlutils/policy.py,sha256=A9bj2eVd6XjNNkClSYVJDoxoGuGkyoYVr1DpVdI0wzs,5120
 pyrlutils/reward.py,sha256=are0swsobMqI1IbrBVBaPMYXWpJnp6lZwAyfgBEm2zg,1211
-pyrlutils/state.py,sha256=
+pyrlutils/state.py,sha256=h-OGrezt0fWfVdM9-BTfqdhx1Ert_utG0ORpIwHwXCw,11902
 pyrlutils/transition.py,sha256=_32jxeYbsiKyaHR9Y2XceUQYbb1jslLCQO2AWL61_EU,6260
 pyrlutils/bandit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyrlutils/bandit/algo.py,sha256=X2Pn4DOi-RXWz5CNg1h0RJCoV3VlAwEGHRMjkfbckfw,3969
@@ -14,10 +14,12 @@ pyrlutils/helpers/exceptions.py,sha256=4fPGW839BChfap-Gd7b-75Dz-Ed3foqbJQ1lg15TZ
 pyrlutils/openai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyrlutils/openai/utils.py,sha256=PJc9WHZM8aM4Z9MlACUxUC8TO7VARp8taatba_ikhew,1056
 pyrlutils/td/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyrlutils/td/
-pyrlutils/td/
-pyrlutils
-pyrlutils
-pyrlutils-0.1.
-pyrlutils-0.1.
-pyrlutils-0.1.
+pyrlutils/td/qlearn.py,sha256=ZibW_fuB89ZAST5snNYLe5H_zUIMZ93vuJXguXpccyo,3374
+pyrlutils/td/sarsa.py,sha256=jtnfMdPHld9C8yzDMQd3xyLZ3BwGL6ShvDq5WpHfZEo,3281
+pyrlutils/td/state_td.py,sha256=gMX-RuSZQ-UIoTWnsmR7xLZvL2jndRknXTExWnhixpM,4778
+pyrlutils/td/utils.py,sha256=VM5MAfWLQIk6a_qENU-iWHglp1azlaP68qkHvl4jXro,8022
+pyrlutils-0.1.1.dist-info/licenses/LICENSE,sha256=bnQPjIcaeBdr2ZofX-_j-nELs8pAx5fQ4Cdfgeaspew,1063
+pyrlutils-0.1.1.dist-info/METADATA,sha256=1PmFggMx23mxdJlhG04qE7_lhAtEDz_qniaBBhTmiVI,2214
+pyrlutils-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+pyrlutils-0.1.1.dist-info/top_level.txt,sha256=gOBuxugE2MA4WDXlLhzkQh_rUonZU6nvJnMuomeHMCU,10
+pyrlutils-0.1.1.dist-info/RECORD,,
{pyrlutils-0.1.0.dist-info → pyrlutils-0.1.1.dist-info}/WHEEL
File without changes
{pyrlutils-0.1.0.dist-info → pyrlutils-0.1.1.dist-info}/licenses/LICENSE
File without changes
{pyrlutils-0.1.0.dist-info → pyrlutils-0.1.1.dist-info}/top_level.txt
File without changes