pyrlutils 0.0.4__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pyrlutils might be problematic.
- pyrlutils/action.py +10 -2
- pyrlutils/bandit/reward.py +3 -2
- pyrlutils/dp/__init__.py +0 -0
- pyrlutils/{valuefcns.py → dp/valuefcns.py} +16 -11
- pyrlutils/helpers/__init__.py +0 -0
- pyrlutils/helpers/exceptions.py +5 -0
- pyrlutils/openai/utils.py +3 -3
- pyrlutils/policy.py +79 -12
- pyrlutils/state.py +166 -74
- pyrlutils/td/__init__.py +0 -0
- pyrlutils/td/td.py +101 -0
- pyrlutils/td/utils.py +119 -0
- pyrlutils/transition.py +44 -35
- {pyrlutils-0.0.4.dist-info → pyrlutils-0.1.0.dist-info}/METADATA +6 -6
- pyrlutils-0.1.0.dist-info/RECORD +23 -0
- {pyrlutils-0.0.4.dist-info → pyrlutils-0.1.0.dist-info}/WHEEL +1 -1
- pyrlutils-0.0.4.dist-info/RECORD +0 -17
- {pyrlutils-0.0.4.dist-info → pyrlutils-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {pyrlutils-0.0.4.dist-info → pyrlutils-0.1.0.dist-info}/top_level.txt +0 -0
pyrlutils/action.py
CHANGED
@@ -1,5 +1,5 @@
 
-from types import LambdaType
+from types import LambdaType, FunctionType
 from typing import Union
 
 from .state import State
@@ -8,7 +8,7 @@ from .state import State
 DiscreteActionValueType = Union[float, str]
 
 class Action:
-    def __init__(self, actionfunc: LambdaType):
+    def __init__(self, actionfunc: Union[FunctionType, LambdaType]):
         self._actionfunc = actionfunc
 
     def act(self, state: State, *args, **kwargs) -> State:
@@ -17,3 +17,11 @@ class Action:
 
     def __call__(self, state: State) -> State:
         return self.act(state)
+
+    @property
+    def action_function(self) -> Union[FunctionType, LambdaType]:
+        return self._actionfunc
+
+    @action_function.setter
+    def action_function(self, new_func: Union[FunctionType, LambdaType]) -> None:
+        self._actionfunc = new_func
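As a usage sketch (not part of the diff; the two-value state space below is made up): plain functions are now accepted alongside lambdas, and the wrapped callable can be swapped through the new action_function property.

from pyrlutils.action import Action
from pyrlutils.state import DiscreteState

state = DiscreteState(['home', 'work'])

def stay(s: DiscreteState) -> DiscreteState:
    return s                              # identity transition

def go_to_work(s: DiscreteState) -> DiscreteState:
    s.set_state_value('work')             # move the discrete state to 'work'
    return s

action = Action(stay)                     # a plain FunctionType, not only a lambda
action.action_function = go_to_work       # new setter introduced in 0.1.0
print(action(state).state_value)          # 'work'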
pyrlutils/bandit/reward.py
CHANGED
@@ -1,11 +1,12 @@
 
 from abc import ABC, abstractmethod
+from typing import Any
 
 
 class IndividualBanditRewardFunction(ABC):
     @abstractmethod
-    def reward(self, action_value) -> float:
+    def reward(self, action_value: Any) -> float:
         pass
 
-    def __call__(self, action_value) -> float:
+    def __call__(self, action_value: Any) -> float:
         return self.reward(action_value)
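A hedged sketch of a concrete subclass (the one-armed payout rule is invented): the only change above is the Any annotation, so subclasses still just implement reward() and let __call__ delegate.

from pyrlutils.bandit.reward import IndividualBanditRewardFunction

class FirstArmReward(IndividualBanditRewardFunction):   # hypothetical subclass
    def reward(self, action_value) -> float:
        return 1.0 if action_value == 0 else 0.0        # pay only for arm 0

print(FirstArmReward()(0))   # 1.0, via __call__ -> reward()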
pyrlutils/dp/__init__.py
ADDED
File without changes
pyrlutils/{valuefcns.py → dp/valuefcns.py}
CHANGED
@@ -1,18 +1,23 @@
 
 import random
 from copy import copy
-from typing import Tuple, Dict
 from itertools import product
+from typing import Annotated
 
 import numpy as np
+from numpy.typing import NDArray
 
-from
-from
-from
+from ..state import DiscreteStateValueType
+from ..transition import TransitionProbabilityFactory
+from ..policy import DiscreteDeterminsticPolicy
 
 
 class OptimalPolicyOnValueFunctions:
-    def __init__(
+    def __init__(
+        self,
+        discount_factor: float,
+        transprobfac: TransitionProbabilityFactory
+    ):
         try:
             assert 0. <= discount_factor <= 1.
         except AssertionError:
@@ -31,7 +36,7 @@ class OptimalPolicyOnValueFunctions:
         self._theta = 1e-10
         self._policy_evaluation_maxiter = 10000
 
-    def _policy_evaluation(self, policy: DiscreteDeterminsticPolicy) -> np.
+    def _policy_evaluation(self, policy: DiscreteDeterminsticPolicy) -> Annotated[NDArray[np.float64], "1D Array"]:
         prev_V = np.zeros(len(self._states_to_indices))
 
         for _ in range(self._policy_evaluation_maxiter):
@@ -55,7 +60,7 @@ class OptimalPolicyOnValueFunctions:
 
         return V
 
-    def _policy_improvement(self, V: np.
+    def _policy_improvement(self, V: Annotated[NDArray[np.float64], "1D Array"]) -> DiscreteDeterminsticPolicy:
         Q = np.zeros((len(self._states_to_indices), len(self._actions_to_indices)))
 
         for state_value in self._state_names:
@@ -78,7 +83,7 @@ class OptimalPolicyOnValueFunctions:
             optimal_policy.add_deterministic_rule(state_value, action_value)
         return optimal_policy
 
-    def _policy_iteration(self) ->
+    def _policy_iteration(self) -> tuple[Annotated[NDArray[np.float64], "1D Array"], DiscreteDeterminsticPolicy]:
         policy = DiscreteDeterminsticPolicy(self._actions_dict)
         for state_value in self._state_names:
             policy.add_deterministic_rule(state_value, random.choice(self._action_names))
@@ -97,7 +102,7 @@ class OptimalPolicyOnValueFunctions:
         return V, policy
 
 
-    def _value_iteration(self) ->
+    def _value_iteration(self) -> tuple[Annotated[NDArray[np.float64], "1D Array"], DiscreteDeterminsticPolicy]:
         V = np.zeros(len(self._state_names))
 
         for _ in range(self._policy_evaluation_maxiter):
@@ -127,7 +132,7 @@ class OptimalPolicyOnValueFunctions:
 
         return V, policy
 
-    def policy_iteration(self) ->
+    def policy_iteration(self) -> tuple[dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
         V, policy = self._policy_iteration()
         state_values_dict = {
             self._state_names[i]: V[i]
@@ -135,7 +140,7 @@ class OptimalPolicyOnValueFunctions:
         }
         return state_values_dict, policy
 
-    def value_iteration(self) ->
+    def value_iteration(self) -> tuple[dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
         V, policy = self._value_iteration()
         state_values_dict = {
             self._state_names[i]: V[i]
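A sketch under assumptions (the two-state MDP is made up, and only the constructor signature and the public policy_iteration/value_iteration methods are guaranteed by the hunks above): the dynamic-programming solver now lives in pyrlutils.dp.valuefcns and is fed a TransitionProbabilityFactory.

from pyrlutils.transition import TransitionProbabilityFactory, NextStateTuple
from pyrlutils.dp.valuefcns import OptimalPolicyOnValueFunctions

factory = TransitionProbabilityFactory()
factory.add_state_transitions('s0', {
    'go':   [NextStateTuple('s1', 1.0, 1.0, True)],   # deterministic, rewarded, terminal
    'stay': [NextStateTuple('s0', 1.0, 0.0, False)],
})
factory.add_state_transitions('s1', {
    'go':   [NextStateTuple('s1', 1.0, 0.0, True)],
    'stay': [NextStateTuple('s1', 1.0, 0.0, True)],
})

solver = OptimalPolicyOnValueFunctions(0.9, factory)
state_values, policy = solver.policy_iteration()      # dict of state values + optimal policy
print(state_values['s0'], policy.get_action_value('s0'))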
pyrlutils/helpers/__init__.py
ADDED
File without changes
pyrlutils/openai/utils.py
CHANGED
@@ -5,7 +5,7 @@ from ..transition import TransitionProbabilityFactory, NextStateTuple
 
 
 class OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory(TransitionProbabilityFactory):
-    def __init__(self, envname):
+    def __init__(self, envname: str):
         super().__init__()
         self._envname = envname
         self._gymenv = gym.make(envname)
@@ -23,9 +23,9 @@ class OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory(TransitionProbabi
         self.add_state_transitions(state_value, new_trans_dict)
 
     @property
-    def envname(self):
+    def envname(self) -> str:
         return self._envname
 
     @property
-    def gymenv(self):
+    def gymenv(self) -> gym.Env:
         return self._gymenv
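A hedged sketch (assumes the optional gymnasium extra is installed and that the 'FrozenLake-v1' environment is available): the factory converts a discrete Gym environment into the package's transition-probability representation, and the two properties above are now annotated.

from pyrlutils.openai.utils import OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory

factory = OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory('FrozenLake-v1')
print(factory.envname)                                 # 'FrozenLake-v1' (str)
state, actions_dict, reward_fcn = factory.generate_mdp_objects()
print(state.nb_state_values, sorted(actions_dict.keys()))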
pyrlutils/policy.py
CHANGED
@@ -1,9 +1,10 @@
 
 from abc import ABC, abstractmethod
-from typing import Union,
+from typing import Union, Annotated
 from warnings import warn
 
 import numpy as np
+from numpy.typing import NDArray
 
 from .state import State, DiscreteState, DiscreteStateValueType
 from .action import Action, DiscreteActionValueType
@@ -12,7 +13,11 @@ from .action import Action, DiscreteActionValueType
 class Policy(ABC):
     @abstractmethod
     def get_action(self, state: State) -> Action:
-
+        raise NotImplemented()
+
+    @abstractmethod
+    def get_action_value(self, state: State) -> DiscreteActionValueType:
+        raise NotImplemented()
 
     def __call__(self, state: State) -> Action:
         return self.get_action(state)
@@ -25,7 +30,7 @@ class Policy(ABC):
 class DeterministicPolicy(Policy):
     @abstractmethod
     def add_deterministic_rule(self, *args, **kwargs):
-
+        raise NotImplemented()
 
     @property
     def is_stochastic(self) -> bool:
@@ -33,16 +38,23 @@ class DeterministicPolicy(Policy):
 
 
 class DiscreteDeterminsticPolicy(DeterministicPolicy):
-    def __init__(self, actions_dict:
+    def __init__(self, actions_dict: dict[DiscreteActionValueType, Action]):
         self._state_to_action = {}
         self._actions_dict = actions_dict
 
-    def add_deterministic_rule(
+    def add_deterministic_rule(
+        self,
+        state_value: DiscreteStateValueType,
+        action_value: DiscreteActionValueType
+    ) -> None:
         if state_value in self._state_to_action:
             warn('State value {} exists in rule; it will be replaced.'.format(state_value))
         self._state_to_action[state_value] = action_value
 
-    def get_action_value(
+    def get_action_value(
+        self,
+        state_value: DiscreteStateValueType
+    ) -> DiscreteActionValueType:
         return self._state_to_action.get(state_value)
 
     def get_action(self, state: DiscreteState) -> Action:
@@ -62,10 +74,16 @@ class DiscreteDeterminsticPolicy(DeterministicPolicy):
         return True
 
 
+class DiscreteContinuousPolicy(DeterministicPolicy):
+    @abstractmethod
+    def get_action(self, state: State) -> Action:
+        raise NotImplemented()
+
+
 class StochasticPolicy(Policy):
     @abstractmethod
     def get_probability(self, *args, **kwargs) -> float:
-
+        raise NotImplemented()
 
     @property
     def is_stochastic(self) -> bool:
@@ -73,12 +91,61 @@ class StochasticPolicy(StochasticPolicy):
 
 
 class DiscreteStochasticPolicy(StochasticPolicy):
-
-
-
+    def __init__(self, actions_dict: dict[DiscreteActionValueType, Action]):
+        self._state_to_action = {}
+        self._actions_dict = actions_dict
+
+    def add_stochastic_rule(
+        self,
+        state_value: DiscreteStateValueType,
+        action_values: list[DiscreteActionValueType],
+        probs: Union[list[float], Annotated[NDArray[np.float64], "1D Array"]] = None
+    ):
+        if probs is not None:
+            assert len(action_values) == len(probs)
+            probs = np.array(probs)
+        else:
+            probs = np.repeat(1./len(action_values), len(action_values))
+
+        if state_value in self._state_to_action:
+            warn('State value {} exists in rule; it will be replaced.'.format(state_value))
+        self._state_to_action[state_value] = {
+            action_value: prob
+            for action_value, prob in zip(action_values, probs)
+        }
+
+    def get_probability(
+        self,
+        state_value: DiscreteStateValueType,
+        action_value: DiscreteActionValueType
+    ) -> float:
+        if state_value not in self._state_to_action:
+            return 0.0
+        if action_value in self._state_to_action[state_value]:
+            return self._state_to_action[state_value][action_value]
+        else:
+            return 0.0
+
+    def get_action_value(self, state: State) -> DiscreteActionValueType:
+        allowed_actions = list(self._state_to_action[state].keys())
+        probs = np.array(list(self._state_to_action[state].values()))
+        sumprobs = np.sum(probs)
+        return np.random.choice(allowed_actions, p=probs/sumprobs)
+
+    def get_action(self, state: DiscreteState) -> Action:
+        return self._actions_dict[self.get_action_value(state.state_value)]
 
 
 class ContinuousStochasticPolicy(StochasticPolicy):
     @abstractmethod
-    def get_probability(
-
+    def get_probability(
+        self,
+        state_value: Union[float, Annotated[NDArray[np.float64], "1D Array"]],
+        action_value: DiscreteActionValueType,
+        value: Union[float, Annotated[NDArray[np.float64], "1D Array"]]
+    ) -> float:
+        raise NotImplemented()
+
+
+DiscretePolicy = Union[DiscreteDeterminsticPolicy, DiscreteStochasticPolicy]
+ContinuousPolicy = Union[ContinuousStochasticPolicy]
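A minimal sketch of the new DiscreteStochasticPolicy (the two no-op actions and the state names are hypothetical): each rule maps a state value to a distribution over action values, with a uniform fallback when probs is omitted.

from pyrlutils.action import Action
from pyrlutils.policy import DiscreteStochasticPolicy

actions_dict = {'left': Action(lambda s: s), 'right': Action(lambda s: s)}
policy = DiscreteStochasticPolicy(actions_dict)
policy.add_stochastic_rule('s0', ['left', 'right'], probs=[0.25, 0.75])
policy.add_stochastic_rule('s1', ['left', 'right'])      # no probs -> uniform

print(policy.get_probability('s0', 'right'))   # 0.75
print(policy.get_probability('s2', 'left'))    # 0.0 for an unknown state
print(policy.get_action_value('s0'))           # 'left' or 'right', sampled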
pyrlutils/state.py
CHANGED
@@ -1,81 +1,84 @@
 
-
+import sys
+from abc import ABC
 from enum import Enum
-from
-from typing import Tuple, List, Optional, Union
+from typing import Optional, Union, Annotated, Literal
 
 import numpy as np
+from numpy.typing import NDArray
+if sys.version_info < (3, 11):
+    from typing_extensions import Self
+else:
+    from typing import Self
 
-
-class StateValue(ABC):
-    @property
-    @abstractmethod
-    def value(self):
-        pass
-
-
-@dataclass
-class DiscreteStateValue(StateValue):
-    enum: Enum
-
-    @property
-    def value(self):
-        return self.enum.value
-
-    def name(self):
-        return self.enum.name
-
-
-class ContinuousStateValue(StateValue):
-    _value: float
-
-    @property
-    def value(self) -> float:
-        return self._value
+from .helpers.exceptions import InvalidRangeError
 
 
 class State(ABC):
     @property
     def state_value(self):
-
-
-    @abstractmethod
-    def set_state_value(self, state_value):
-        pass
+        raise NotImplemented()
 
-    @abstractmethod
-    def get_state_value(self):
-        pass
-
-    @state_value.setter
-    def state_value(self, new_state_value):
-        self.set_state_value(new_state_value)
 
-
-DiscreteStateValueType = Union[float, str, Tuple[int], Enum]
+DiscreteStateValueType = Union[str, int, tuple[int], Enum]
 
 
 class DiscreteState(State):
-    def __init__(
+    def __init__(
+        self,
+        all_state_values: list[DiscreteStateValueType],
+        initial_value: Optional[DiscreteStateValueType] = None,
+        terminals: Optional[dict[DiscreteStateValueType, bool]]=None
+    ):
         super().__init__()
         self._all_state_values = all_state_values
-        self.
+        self._state_values_to_indices = {
+            state_value: idx
+            for idx, state_value in enumerate(self._all_state_values)
+        }
+        if initial_value is not None:
+            self._current_index = self._state_values_to_indices[initial_value]
+        else:
+            self._current_index = 0
+        if terminals is None:
+            self._terminal_dict = {
+                state_value: False
+                for state_value in self._all_state_values
+            }
+        else:
+            self._terminal_dict = terminals.copy()
+            for state_value in self._all_state_values:
+                if self._terminal_dict.get(state_value) is None:
+                    self._terminal_dict[state_value] = False
+
+    def _get_state_value_from_index(self, index: int) -> DiscreteStateValueType:
+        return self._all_state_values[index]
 
     def get_state_value(self) -> DiscreteStateValueType:
-        return self.
+        return self._get_state_value_from_index(self._current_index)
 
-    def set_state_value(self, state_value: DiscreteStateValueType):
+    def set_state_value(self, state_value: DiscreteStateValueType) -> None:
         if state_value in self._all_state_values:
-            self.
+            self._current_index = self._state_values_to_indices[state_value]
         else:
             raise ValueError('State value {} is invalid.'.format(state_value))
 
-    def get_all_possible_state_values(self) ->
+    def get_all_possible_state_values(self) -> list[DiscreteStateValueType]:
         return self._all_state_values
 
+    @property
+    def state_index(self) -> int:
+        return self._current_index
+
+    @state_index.setter
+    def state_index(self, new_index: int) -> None:
+        if new_index >= len(self._all_state_values):
+            raise ValueError(f"Invalid index {new_index}; it must be less than {len(self._all_state_values)}.")
+        self._current_index = new_index
+
     @property
     def state_value(self) -> DiscreteStateValueType:
-        return self.
+        return self._all_state_values[self._current_index]
 
     @state_value.setter
     def state_value(self, new_state_value: DiscreteStateValueType):
@@ -85,22 +88,53 @@ class DiscreteState(State):
     def state_space_size(self):
         return len(self._all_state_values)
 
+    @property
+    def nb_state_values(self) -> int:
+        return len(self._all_state_values)
+
+    @property
+    def is_terminal(self) -> bool:
+        return self._terminal_dict[self._all_state_values[self._current_index]]
 
-
-
-
-
+    def __hash__(self):
+        return self._current_index
+
+    def __eq__(self, other: Self) -> bool:
+        return self._current_index == other._current_index
 
 
 class ContinuousState(State):
-    def __init__(
+    def __init__(
+        self,
+        nbdims: int,
+        ranges: Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]],
+        init_value: Optional[Union[float, Annotated[NDArray[np.float64], "1D Array"]]] = None
+    ):
+        super().__init__()
         self._nbdims = nbdims
 
+        try:
+            assert isinstance(ranges, np.ndarray)
+        except AssertionError:
+            raise TypeError('Range must be a numpy array.')
+
         try:
             assert (ranges.dtype == np.float64) or (ranges.dtype == np.float32) or (ranges.dtype == np.float16)
         except AssertionError:
             raise TypeError('It has to be floating type numpy.ndarray.')
 
+        try:
+            assert ranges.ndim == 1 or ranges.ndim == 2
+            match ranges.ndim:
+                case 1:
+                    assert ranges.shape[0] == 2
+                case 2:
+                    assert ranges.shape[1] == 2
+                case _:
+                    raise ValueError("Ranges must be of shape (2, ) or (*, 2).")
+        except AssertionError:
+            raise ValueError("Ranges must be of shape (2, ) or (*, 2).")
+
         try:
             assert self._nbdims > 0
         except AssertionError:
@@ -146,50 +180,53 @@ class ContinuousState(State):
                raise ValueError('Initialized value does not have the right dimension.')
            for i in range(self._nbdims):
                try:
-                    assert
+                    assert self._ranges[i, 0] <= init_value[i] <= self.ranges[i, 1]
                except AssertionError:
                    raise InvalidRangeError('Initialized value at dimension {} (value: {}) is not within the permitted range ({} -> {})!'.format(i, init_value[i], self._ranges[i, 0], self._ranges[i, 1]))
        else:
            try:
-                assert
+                assert self._ranges[0, 0] <= init_value <= self.ranges[0, 1]
            except AssertionError:
                raise InvalidRangeError('Initialized value is out of range.')
        self._state_value = init_value
 
-    def set_state_value(self, state_value: Union[float, np.
-        if self.
+    def set_state_value(self, state_value: Union[float, Annotated[NDArray[np.float64], "1D Array"]]):
+        if self._nbdims > 1:
            try:
                assert state_value.shape[0] == self._nbdims
            except AssertionError:
                raise ValueError('Given value does not have the right dimension.')
-            for i in range(self.
+            for i in range(self._nbdims):
                try:
-                    assert
+                    assert self.ranges[i, 0] <= state_value[i] <= self.ranges[i, 1]
                except AssertionError:
                    raise InvalidRangeError()
        else:
            try:
-                assert
+                assert self.ranges[0, 0] <= state_value <= self.ranges[0, 1]
            except AssertionError:
                raise InvalidRangeError()
 
        self._state_value = state_value
 
-    def get_state_value(self) -> np.
+    def get_state_value(self) -> Annotated[NDArray[np.float64], "1D Array"]:
        return self._state_value
 
-    def get_state_value_ranges(self) -> np.
+    def get_state_value_ranges(self) -> Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]]:
        return self._ranges
 
-    def get_state_value_range_at_dimension(self, dimension: int) -> np.
-
+    def get_state_value_range_at_dimension(self, dimension: int) -> Annotated[NDArray[np.float64], Literal["2"]]:
+        if dimension < self._nbdims:
+            return self._ranges[dimension]
+        else:
+            raise ValueError(f"There are only {self._nbdims} dimensions!")
 
    @property
-    def ranges(self) -> np.
+    def ranges(self) -> Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]]:
        return self.get_state_value_ranges()
 
    @property
-    def state_value(self) -> Union[float, np.
+    def state_value(self) -> Union[float, NDArray[np.float64]]:
        return self.get_state_value()
 
    @state_value.setter
@@ -200,9 +237,28 @@ class ContinuousState(State):
    def nbdims(self) -> int:
        return self._nbdims
 
+    def __hash__(self):
+        return hash(tuple(self._state_value))
+
+    def __eq__(self, other: Self):
+        if self.nbdims != other.nbdims:
+            raise ValueError(f"The two states have two different dimensions. ({self.nbdims} vs. {other.nbdims})")
+        for i in range(self.nbdims):
+            if self.state_value[i] != other.state_value[i]:
+                return False
+        return True
+
 
 class Discrete2DCartesianState(DiscreteState):
-    def __init__(
+    def __init__(
+        self,
+        x_lowlim: int,
+        x_hilim: int,
+        y_lowlim: int,
+        y_hilim: int,
+        initial_coordinate: list[int]=None,
+        terminals: Optional[dict[DiscreteStateValueType, bool]] = None
+    ):
        self._x_lowlim = x_lowlim
        self._x_hilim = x_hilim
        self._y_lowlim = y_lowlim
@@ -212,14 +268,50 @@ class Discrete2DCartesianState(DiscreteState):
        if initial_coordinate is None:
            initial_coordinate = [self._x_lowlim, self._y_lowlim]
        initial_value = (initial_coordinate[1] - self._y_lowlim) * self._countx + (initial_coordinate[0] - self._x_lowlim)
-        super().__init__(list(range(self._countx*self._county)),
+        super().__init__(list(range(self._countx*self._county)), initial_value=initial_value, terminals=terminals)
 
    def _encode_coordinates(self, x, y) -> int:
        return (y - self._y_lowlim) * self._countx + (x - self._x_lowlim)
 
-    def encode_coordinates(self, coordinates:
-
+    def encode_coordinates(self, coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]]) -> int:
+        if isinstance(coordinates, list):
+            assert len(coordinates) == 2
        return self._encode_coordinates(coordinates[0], coordinates[1])
 
-    def decode_coordinates(self, hashcode) ->
-        return [hashcode % self._countx, hashcode // self._countx]
+    def decode_coordinates(self, hashcode) -> list[int]:
+        return [hashcode % self._countx + self._x_lowlim, hashcode // self._countx + self._y_lowlim]
+
+    def get_whether_terminal_given_coordinates(
+        self,
+        coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]]
+    ) -> bool:
+        if isinstance(coordinates, list):
+            assert len(coordinates) == 2
+        hashcode = self._encode_coordinates(coordinates[0], coordinates[1])
+        return self._terminal_dict.get(hashcode, False)
+
+    def set_terminal_given_coordinates(
+        self,
+        coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]],
+        terminal_value: bool
+    ) -> None:
+        if isinstance(coordinates, list):
+            assert len(coordinates) == 2
+        hashcode = self._encode_coordinates(coordinates[0], coordinates[1])
+        self._terminal_dict[hashcode] = terminal_value
+
+    @property
+    def x_lowlim(self) -> int:
+        return self._x_lowlim
+
+    @property
+    def x_hilim(self) -> int:
+        return self._x_hilim
+
+    @property
+    def y_lowlim(self) -> int:
+        return self._y_lowlim
+
+    @property
+    def y_hilim(self) -> int:
+        return self._y_hilim
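A short sketch of the expanded grid-state API (assuming inclusive limits, i.e. a 3x3 grid with an invented terminal corner): coordinates are encoded to a flat index, decoding now restores the lower-limit offsets, and terminal flags can be queried and set per cell.

from pyrlutils.state import Discrete2DCartesianState

grid = Discrete2DCartesianState(0, 2, 0, 2)            # x in [0, 2], y in [0, 2]
grid.set_terminal_given_coordinates([2, 2], True)      # new in 0.1.0

code = grid.encode_coordinates([2, 2])                 # flat index of the corner cell
print(grid.decode_coordinates(code))                   # [2, 2]
print(grid.get_whether_terminal_given_coordinates([2, 2]))   # True
print(grid.state_index, grid.is_terminal)              # 0 False at the start cell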
pyrlutils/td/__init__.py
ADDED
File without changes
pyrlutils/td/td.py
ADDED
@@ -0,0 +1,101 @@
+
+from typing import Annotated
+
+import numpy as np
+from numpy.typing import NDArray
+
+from .utils import decay_schedule, AbstractTemporalDifferenceLearner, TimeDifferencePathElements
+
+
+class SingleStepTemporalDifferenceLearner(AbstractTemporalDifferenceLearner):
+    def learn(
+        self,
+        episodes: int
+    ) -> tuple[Annotated[NDArray[np.float64], "1D Array"], Annotated[NDArray[np.float64], "2D Array"]]:
+        V = np.zeros(self.nb_states)
+        V_track = np.zeros((episodes, self.nb_states))
+        alphas = decay_schedule(
+            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
+        )
+
+        for i in range(episodes):
+            self._state.set_state_value(self.initial_state_index)
+            done = False
+            while not done:
+                old_state_index = self._state.state_index
+                old_state_value = self._state.state_value
+                action_value = self._policy.get_action_value(self._state.state_value)
+                action_func = self._actions_dict[action_value]
+                self._state = action_func(self._state)
+                new_state_index = self._state.state_index
+                new_state_value = self._state.state_value
+                reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
+                done = self._state.is_terminal
+
+                td_target = reward + self.gamma * V[new_state_index] * (not done)
+                td_error = td_target - V[old_state_index]
+                V[old_state_index] = V[old_state_index] + alphas[i] * td_error
+
+            V_track[i, :] = V
+
+        return V, V_track
+
+
+class MultipleStepTemporalDifferenceLearner(AbstractTemporalDifferenceLearner):
+    def learn(
+        self,
+        episodes: int,
+        n_steps: int=3
+    ) -> tuple[Annotated[NDArray[np.float64], "1D Array"], Annotated[NDArray[np.float64], "2D Array"]]:
+        V = np.zeros(self.nb_states)
+        V_track = np.zeros((episodes, self.nb_states))
+        alphas = decay_schedule(
+            self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
+        )
+        discounts = np.logspace(0, n_steps-1, num=n_steps+1, base=self.gamma, endpoint=False)
+
+        for i in range(episodes):
+            self._state.set_state_value(self.initial_state_index)
+            done = False
+            path = []
+
+            while not done or path is not None:
+                path = path[1:]  # worth revisiting this line
+
+                next_state_index = -1
+                while not done and len(path) < n_steps:
+                    old_state_index = self._state.state_index
+                    old_state_value = self._state.state_value
+                    action_value = self._policy.get_action_value(self._state.state_value)
+                    action_func = self._actions_dict[action_value]
+                    self._state = action_func(self._state)
+                    new_state_index = self._state.state_index
+                    new_state_value = self._state.state_value
+                    reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
+                    done = self._state.is_terminal
+
+                    path.append(
+                        TimeDifferencePathElements(
+                            this_state_index=old_state_index,
+                            reward=reward,
+                            next_state_index=new_state_index,
+                            done=done
+                        )
+                    )
+                    if done:
+                        break
+
+                n = len(path)
+                estimated_state_index = path[0].this_state_index
+                rewards = np.array([this_moment.reward for this_moment in path])
+                partial_return = discounts[n:] * rewards
+                bs_val = discounts[-1] * V[next_state_index] * (not done)
+                ntd_target = np.sum(np.append(partial_return, bs_val))
+                ntd_error = ntd_target - V[estimated_state_index]
+                V[estimated_state_index] = V[estimated_state_index] + alphas[i] * ntd_error
+                if len(path) == 1 and path[0].done:
+                    path = None
+
+            V_track[i, :] = V
+
+        return V, V_track
pyrlutils/td/utils.py
ADDED
@@ -0,0 +1,119 @@
+
+from abc import ABC, abstractmethod
+from typing import Optional, Annotated
+from dataclasses import dataclass
+
+import numpy as np
+from numpy.typing import NDArray
+
+from ..policy import DiscretePolicy
+from ..transition import TransitionProbabilityFactory
+
+
+def decay_schedule(
+    init_value: float,
+    min_value: float,
+    decay_ratio: float,
+    max_steps: int,
+    log_start: int=-2,
+    log_base: int=10
+) -> Annotated[NDArray[np.float64], "1D Array"]:
+    decay_steps = int(max_steps*decay_ratio)
+    rem_steps = max_steps - decay_steps
+
+    values = np.logspace(log_start, 0, decay_steps, base=log_base, endpoint=True)[::-1]
+    values = (values - values.min()) / (values.max() - values.min())
+    values = (init_value - min_value) * values + min_value
+    values = np.pad(values, (0, rem_steps), 'edge')
+    return values
+
+
+class AbstractTemporalDifferenceLearner(ABC):
+    def __init__(
+        self,
+        transprobfac: TransitionProbabilityFactory,
+        gamma: float=1.0,
+        init_alpha: float=0.5,
+        min_alpha: float=0.01,
+        alpha_decay_ratio: float=0.3,
+        policy: Optional[DiscretePolicy]=None,
+        initial_state_index: int=0
+    ):
+        self._gamma = gamma
+        self._init_alpha = init_alpha
+        self._min_alpha = min_alpha
+        try:
+            assert 0.0 <= alpha_decay_ratio <= 1.0
+        except AssertionError:
+            raise ValueError("alpha_decay_ratio must be between 0 and 1!")
+        self._alpha_decay_ratio = alpha_decay_ratio
+        self._transprobfac = transprobfac
+        self._state, self._actions_dict, self._indrewardfcn = self._transprobfac.generate_mdp_objects()
+        self._action_names = list(self._actions_dict.keys())
+        self._actions_to_indices = {action_value: idx for idx, action_value in enumerate(self._action_names)}
+        self._policy = policy
+        try:
+            assert 0 <= initial_state_index < self._state.nb_state_values
+        except AssertionError:
+            raise ValueError("Initial state index must be between 0 and {}".format(len(self._state_names)))
+        self._init_state_index = initial_state_index
+
+    @abstractmethod
+    def learn(self, *args, **kwargs) -> tuple[Annotated[NDArray[np.float64], "1D Array"], Annotated[NDArray[np.float64], "2D Array"]]:
+        raise NotImplementedError()
+
+    @property
+    def nb_states(self) -> int:
+        return self._state.nb_state_values
+
+    @property
+    def policy(self) -> DiscretePolicy:
+        return self._policy
+
+    @policy.setter
+    def policy(self, val: DiscretePolicy):
+        self._policy = val
+
+    @property
+    def gamma(self) -> float:
+        return self._gamma
+
+    @gamma.setter
+    def gamma(self, val: float):
+        self._gamma = val
+
+    @property
+    def init_alpha(self) -> float:
+        return self._init_alpha
+
+    @init_alpha.setter
+    def init_alpha(self, val: float):
+        self._init_alpha = val
+
+    @property
+    def min_alpha(self) -> float:
+        return self._min_alpha
+
+    @min_alpha.setter
+    def min_alpha(self, val: float):
+        self._min_alpha = val
+
+    @property
+    def alpha_decay_ratio(self) -> float:
+        return self._alpha_decay_ratio
+
+    @property
+    def initial_state_index(self) -> int:
+        return self._init_state_index
+
+    @initial_state_index.setter
+    def initial_state_index(self, val: int):
+        self._init_state_index = val
+
+
+@dataclass
+class TimeDifferencePathElements:
+    this_state_index: int
+    reward: float
+    next_state_index: int
+    done: bool
pyrlutils/transition.py
CHANGED
@@ -1,6 +1,7 @@
 
-from types import LambdaType
-from typing import
+from types import LambdaType, FunctionType
+from typing import Union
+from dataclasses import dataclass
 
 import numpy as np
 
@@ -9,28 +10,12 @@ from .reward import IndividualRewardFunction
 from .action import Action, DiscreteActionValueType
 
 
+@dataclass
 class NextStateTuple:
-
-
-
-
-        self._terminal = terminal
-
-    @property
-    def next_state_value(self) -> DiscreteStateValueType:
-        return self._next_state_value
-
-    @property
-    def probability(self) -> float:
-        return self._probability
-
-    @property
-    def reward(self) -> float:
-        return self._reward
-
-    @property
-    def terminal(self) -> bool:
-        return self._terminal
+    next_state_value: DiscreteStateValueType
+    probability: float
+    reward: float
+    terminal: bool
 
 
 class TransitionProbabilityFactory:
@@ -40,7 +25,11 @@ class TransitionProbabilityFactory:
         self._all_action_values = []
         self._objects_generated = False
 
-    def add_state_transitions(
+    def add_state_transitions(
+        self,
+        state_value: DiscreteStateValueType,
+        action_values_to_next_state: dict[DiscreteActionValueType, Union[list[NextStateTuple], dict]]
+    ):
         if state_value not in self._all_state_values:
             self._all_state_values.append(state_value)
 
@@ -69,7 +58,10 @@ class TransitionProbabilityFactory:
 
         self._transprobs[state_value] = this_state_transition_dict
 
-    def _get_probs_for_eachstate(
+    def _get_probs_for_eachstate(
+        self,
+        action_value: DiscreteActionValueType
+    ) -> dict[DiscreteStateValueType, list[NextStateTuple]]:
         state_nexttuples = {}
         for state_value, action_nexttuples_pair in self._transprobs.items():
             for this_action_value, nexttuples in action_nexttuples_pair.items():
@@ -77,7 +69,10 @@ class TransitionProbabilityFactory:
                 state_nexttuples[state_value] = nexttuples
         return state_nexttuples
 
-    def _generate_action_function(
+    def _generate_action_function(
+        self,
+        state_nexttuples: dict[DiscreteStateValueType, list[NextStateTuple]]
+    ) -> Union[FunctionType, LambdaType]:
 
         def _action_function(state: DiscreteState) -> DiscreteState:
             nexttuples = state_nexttuples[state.state_value]
@@ -91,7 +86,11 @@ class TransitionProbabilityFactory:
 
     def _generate_individual_reward_function(self) -> IndividualRewardFunction:
 
-        def _individual_reward_function(
+        def _individual_reward_function(
+            state_value: DiscreteStateValueType,
+            action_value: DiscreteActionValueType,
+            next_state_value: DiscreteStateValueType
+        ) -> float:
             if state_value not in self._transprobs.keys():
                 return 0.
 
@@ -105,15 +104,22 @@ class TransitionProbabilityFactory:
             return reward
 
         class ThisIndividualRewardFunction(IndividualRewardFunction):
-            def
-
-
+            def reward(
+                self,
+                state_value: DiscreteStateValueType,
+                action_value: DiscreteActionValueType,
+                next_state_value: DiscreteStateValueType
+            ) -> float:
                 return _individual_reward_function(state_value, action_value, next_state_value)
 
         return ThisIndividualRewardFunction()
 
-    def get_probability(
+    def get_probability(
+        self,
+        state_value: DiscreteStateValueType,
+        action_value: DiscreteActionValueType,
+        new_state_value: DiscreteStateValueType
+    ) -> float:
         if state_value not in self._transprobs.keys():
             return 0.
 
@@ -127,18 +133,21 @@ class TransitionProbabilityFactory:
         return probs
 
     @property
-    def transition_probabilities(self) -> dict:
+    def transition_probabilities(self) -> dict[DiscreteStateValueType, dict[DiscreteActionValueType, list[NextStateTuple]]]:
         return self._transprobs
 
-    def generate_mdp_objects(self) ->
+    def generate_mdp_objects(self) -> tuple[DiscreteState, dict[DiscreteActionValueType, Action], IndividualRewardFunction]:
         state = DiscreteState(self._all_state_values)
         actions_dict = {}
         for action_value in self._all_action_values:
             state_nexttuple = self._get_probs_for_eachstate(action_value)
             actions_dict[action_value] = Action(self._generate_action_function(state_nexttuple))
+            for next_tuples in state_nexttuple.values():
+                for next_tuple in next_tuples:
+                    state._terminal_dict[next_tuple.next_state_value] = next_tuple.terminal
 
         individual_reward_fcn = self._generate_individual_reward_function()
-
+        self._objects_generated = True
         return state, actions_dict, individual_reward_fcn
 
     @property
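A small sketch of the slimmed-down NextStateTuple (the tiny walk-to-goal MDP is invented): as a dataclass it is constructed positionally or by field name instead of through the removed hand-written properties, and generate_mdp_objects now propagates the terminal flags into the state it returns.

from pyrlutils.transition import TransitionProbabilityFactory, NextStateTuple

step = NextStateTuple(next_state_value=1, probability=1.0, reward=1.0, terminal=True)
print(step.reward, step.terminal)          # 1.0 True, plain dataclass fields

factory = TransitionProbabilityFactory()
factory.add_state_transitions(0, {'go': [step]})
factory.add_state_transitions(1, {'go': [NextStateTuple(1, 1.0, 0.0, True)]})

state, actions_dict, reward_fcn = factory.generate_mdp_objects()
print(reward_fcn.reward(0, 'go', 1))       # 1.0
state.set_state_value(1)
print(state.is_terminal)                   # True, taken from NextStateTuple.terminal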
{pyrlutils-0.0.4.dist-info → pyrlutils-0.1.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: pyrlutils
-Version: 0.0
+Version: 0.1.0
 Summary: Utility and Helpers for Reinformcement Learning
 Author-email: Kwan Yuet Stephen Ho <stephenhky@yahoo.com.hk>
 License: MIT
@@ -11,22 +11,22 @@ Classifier: Topic :: Scientific/Engineering :: Mathematics
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Classifier: Topic :: Software Development :: Version Control :: Git
-Classifier: Programming Language :: Python :: 3.7
-Classifier: Programming Language :: Python :: 3.8
-Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
 Classifier: Intended Audience :: Science/Research
 Classifier: Intended Audience :: Developers
-Requires-Python: >=3.
+Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: numpy
+Requires-Dist: typing-extensions
 Provides-Extra: openaigym
 Requires-Dist: gymnasium; extra == "openaigym"
 Provides-Extra: test
 Requires-Dist: unittest; extra == "test"
+Dynamic: license-file
 
 # PyRLUtils
 
pyrlutils-0.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,23 @@
+pyrlutils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pyrlutils/action.py,sha256=QoBdtcGtK_EkYAjb50bruhoB_XIz0agLpQjdGFnGbRQ,732
+pyrlutils/policy.py,sha256=A9bj2eVd6XjNNkClSYVJDoxoGuGkyoYVr1DpVdI0wzs,5120
+pyrlutils/reward.py,sha256=are0swsobMqI1IbrBVBaPMYXWpJnp6lZwAyfgBEm2zg,1211
+pyrlutils/state.py,sha256=A3XJSjNJrsInXUWsUvb1GE7Oq-CY6DNEB-ulrVa1rR4,11774
+pyrlutils/transition.py,sha256=_32jxeYbsiKyaHR9Y2XceUQYbb1jslLCQO2AWL61_EU,6260
+pyrlutils/bandit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pyrlutils/bandit/algo.py,sha256=X2Pn4DOi-RXWz5CNg1h0RJCoV3VlAwEGHRMjkfbckfw,3969
+pyrlutils/bandit/reward.py,sha256=l2H_gZk2qqDxZioHe1M28pD8N47fgSR-K0Q6muchVd0,282
+pyrlutils/dp/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pyrlutils/dp/valuefcns.py,sha256=0T7vzdKRIKhLMsaq7JgPqONMmq4lWRGPj7xPtuxVtbE,6546
+pyrlutils/helpers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pyrlutils/helpers/exceptions.py,sha256=4fPGW839BChfap-Gd7b-75Dz-Ed3foqbJQ1lg15TZ-4,192
+pyrlutils/openai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pyrlutils/openai/utils.py,sha256=PJc9WHZM8aM4Z9MlACUxUC8TO7VARp8taatba_ikhew,1056
+pyrlutils/td/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pyrlutils/td/td.py,sha256=EnecL84yyUm7rO2idaHgVfvtWW5LYPxEkefHhI1SVPQ,4269
+pyrlutils/td/utils.py,sha256=PALXGaDLd3PjFh8qDV9DY_MkaBuj3_GpfVWJOb424vE,3571
+pyrlutils-0.1.0.dist-info/licenses/LICENSE,sha256=bnQPjIcaeBdr2ZofX-_j-nELs8pAx5fQ4Cdfgeaspew,1063
+pyrlutils-0.1.0.dist-info/METADATA,sha256=qKVydib9iWVw-NXgMnB3y0JtDibVQcvclyc7zP2PYH0,2185
+pyrlutils-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+pyrlutils-0.1.0.dist-info/top_level.txt,sha256=gOBuxugE2MA4WDXlLhzkQh_rUonZU6nvJnMuomeHMCU,10
+pyrlutils-0.1.0.dist-info/RECORD,,
pyrlutils-0.0.4.dist-info/RECORD
DELETED
@@ -1,17 +0,0 @@
-pyrlutils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyrlutils/action.py,sha256=2kJqNZxsLOV8yOTl-RcpM8b0zu-WXNREJCrl49uZi2c,437
-pyrlutils/policy.py,sha256=Cx4vsIXzFZi_KEgI06S378Y5E6g-AfK90skDYoGsfOI,2794
-pyrlutils/reward.py,sha256=are0swsobMqI1IbrBVBaPMYXWpJnp6lZwAyfgBEm2zg,1211
-pyrlutils/state.py,sha256=w0YJ50FUyNboPoYduLMX1xaBJJHAOaSlsr3Og1dd0dY,7840
-pyrlutils/transition.py,sha256=lgh4YfOi-YjSIyymWfrXe-ugDWpZYK3MvjdeehgcQhk,5816
-pyrlutils/valuefcns.py,sha256=CJxu0EIFgrdbP0n0x6nzs3X08accFsuJW71tv1rMTkQ,6342
-pyrlutils/bandit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyrlutils/bandit/algo.py,sha256=X2Pn4DOi-RXWz5CNg1h0RJCoV3VlAwEGHRMjkfbckfw,3969
-pyrlutils/bandit/reward.py,sha256=S_uECjMOg3cmK24J-5uPcckLvtxmU4yllR7JEvMwAQE,249
-pyrlutils/openai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyrlutils/openai/utils.py,sha256=ehj1cGlDYjQLno3pKMCS3CzZwbZGSTmjxDlU07aSBFo,1033
-pyrlutils-0.0.4.dist-info/LICENSE,sha256=bnQPjIcaeBdr2ZofX-_j-nELs8pAx5fQ4Cdfgeaspew,1063
-pyrlutils-0.0.4.dist-info/METADATA,sha256=7ncLjVrpqIZpdMFMrRjqRNgfZl9LUKc_SZFkw_CoTFc,2228
-pyrlutils-0.0.4.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-pyrlutils-0.0.4.dist-info/top_level.txt,sha256=gOBuxugE2MA4WDXlLhzkQh_rUonZU6nvJnMuomeHMCU,10
-pyrlutils-0.0.4.dist-info/RECORD,,
{pyrlutils-0.0.4.dist-info → pyrlutils-0.1.0.dist-info/licenses}/LICENSE
File without changes
{pyrlutils-0.0.4.dist-info → pyrlutils-0.1.0.dist-info}/top_level.txt
File without changes