pyrlutils 0.0.4__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pyrlutils might be problematic.
- pyrlutils/action.py +10 -2
- pyrlutils/bandit/reward.py +3 -2
- pyrlutils/dp/__init__.py +0 -0
- pyrlutils/{valuefcns.py → dp/valuefcns.py} +16 -11
- pyrlutils/helpers/__init__.py +0 -0
- pyrlutils/helpers/exceptions.py +5 -0
- pyrlutils/openai/utils.py +3 -3
- pyrlutils/policy.py +79 -12
- pyrlutils/state.py +169 -74
- pyrlutils/td/__init__.py +0 -0
- pyrlutils/td/qlearn.py +86 -0
- pyrlutils/td/sarsa.py +86 -0
- pyrlutils/td/state_td.py +111 -0
- pyrlutils/td/utils.py +258 -0
- pyrlutils/transition.py +44 -35
- {pyrlutils-0.0.4.dist-info → pyrlutils-0.1.1.dist-info}/METADATA +7 -6
- pyrlutils-0.1.1.dist-info/RECORD +25 -0
- {pyrlutils-0.0.4.dist-info → pyrlutils-0.1.1.dist-info}/WHEEL +1 -1
- pyrlutils-0.0.4.dist-info/RECORD +0 -17
- {pyrlutils-0.0.4.dist-info → pyrlutils-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {pyrlutils-0.0.4.dist-info → pyrlutils-0.1.1.dist-info}/top_level.txt +0 -0
pyrlutils/action.py CHANGED
@@ -1,5 +1,5 @@
 
-from types import LambdaType
+from types import LambdaType, FunctionType
 from typing import Union
 
 from .state import State
@@ -8,7 +8,7 @@ from .state import State
 DiscreteActionValueType = Union[float, str]
 
 class Action:
-    def __init__(self, actionfunc: LambdaType):
+    def __init__(self, actionfunc: Union[FunctionType, LambdaType]):
         self._actionfunc = actionfunc
 
     def act(self, state: State, *args, **kwargs) -> State:
@@ -17,3 +17,11 @@
 
     def __call__(self, state: State) -> State:
         return self.act(state)
+
+    @property
+    def action_function(self) -> Union[FunctionType, LambdaType]:
+        return self._actionfunc
+
+    @action_function.setter
+    def action_function(self, new_func: Union[FunctionType, LambdaType]) -> None:
+        self._actionfunc = new_func
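The upshot of these changes: `Action` now formally accepts plain named functions as well as lambdas, and exposes the wrapped callable through the new read/write `action_function` property. A minimal sketch of the new surface, assuming pyrlutils 0.1.1 is installed (the `flip` function and the state names are illustrative, not part of the package):

    from pyrlutils.action import Action
    from pyrlutils.state import DiscreteState

    # A toy two-value state starting at 'low'.
    state = DiscreteState(['low', 'high'], initial_value='low')

    def flip(s):
        # A named function now type-checks as an actionfunc, not only a lambda.
        s.set_state_value('high' if s.state_value == 'low' else 'low')
        return s

    action = Action(flip)
    action(state)
    print(state.state_value)                # 'high'

    # New in 0.1.1: inspect or replace the wrapped callable.
    print(action.action_function is flip)   # True
    action.action_function = lambda s: s    # swap in a no-op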
pyrlutils/bandit/reward.py CHANGED
@@ -1,11 +1,12 @@
 
 from abc import ABC, abstractmethod
+from typing import Any
 
 
 class IndividualBanditRewardFunction(ABC):
     @abstractmethod
-    def reward(self, action_value) -> float:
+    def reward(self, action_value: Any) -> float:
         pass
 
-    def __call__(self, action_value) -> float:
+    def __call__(self, action_value: Any) -> float:
         return self.reward(action_value)
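Only annotations changed here, but they spell out the subclassing contract. A hedged sketch (the two-arm reward and its arm names are invented for illustration):

    from pyrlutils.bandit.reward import IndividualBanditRewardFunction

    class TwoArmReward(IndividualBanditRewardFunction):
        # The action value is explicitly Any-typed as of 0.1.1.
        def reward(self, action_value) -> float:
            return 1.0 if action_value == 'good_arm' else 0.0

    reward_fn = TwoArmReward()
    print(reward_fn('good_arm'))   # 1.0, dispatched through __call__
    print(reward_fn('bad_arm'))    # 0.0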
pyrlutils/dp/__init__.py ADDED
File without changes
pyrlutils/{valuefcns.py → dp/valuefcns.py} RENAMED
@@ -1,18 +1,23 @@
 
 import random
 from copy import copy
-from typing import Tuple, Dict
 from itertools import product
+from typing import Annotated
 
 import numpy as np
+from numpy.typing import NDArray
 
-from .state import DiscreteStateValueType
-from .transition import TransitionProbabilityFactory
-from .policy import DiscreteDeterminsticPolicy
+from ..state import DiscreteStateValueType
+from ..transition import TransitionProbabilityFactory
+from ..policy import DiscreteDeterminsticPolicy
 
 
 class OptimalPolicyOnValueFunctions:
-    def __init__(self, discount_factor: float, transprobfac: TransitionProbabilityFactory):
+    def __init__(
+            self,
+            discount_factor: float,
+            transprobfac: TransitionProbabilityFactory
+    ):
         try:
             assert 0. <= discount_factor <= 1.
         except AssertionError:
@@ -31,7 +36,7 @@ class OptimalPolicyOnValueFunctions:
         self._theta = 1e-10
         self._policy_evaluation_maxiter = 10000
 
-    def _policy_evaluation(self, policy: DiscreteDeterminsticPolicy) -> np.ndarray:
+    def _policy_evaluation(self, policy: DiscreteDeterminsticPolicy) -> Annotated[NDArray[np.float64], "1D Array"]:
         prev_V = np.zeros(len(self._states_to_indices))
 
         for _ in range(self._policy_evaluation_maxiter):
@@ -55,7 +60,7 @@ class OptimalPolicyOnValueFunctions:
 
         return V
 
-    def _policy_improvement(self, V: np.ndarray) -> DiscreteDeterminsticPolicy:
+    def _policy_improvement(self, V: Annotated[NDArray[np.float64], "1D Array"]) -> DiscreteDeterminsticPolicy:
         Q = np.zeros((len(self._states_to_indices), len(self._actions_to_indices)))
 
         for state_value in self._state_names:
@@ -78,7 +83,7 @@ class OptimalPolicyOnValueFunctions:
             optimal_policy.add_deterministic_rule(state_value, action_value)
         return optimal_policy
 
-    def _policy_iteration(self) -> Tuple[np.ndarray, DiscreteDeterminsticPolicy]:
+    def _policy_iteration(self) -> tuple[Annotated[NDArray[np.float64], "1D Array"], DiscreteDeterminsticPolicy]:
         policy = DiscreteDeterminsticPolicy(self._actions_dict)
         for state_value in self._state_names:
             policy.add_deterministic_rule(state_value, random.choice(self._action_names))
@@ -97,7 +102,7 @@ class OptimalPolicyOnValueFunctions:
         return V, policy
 
 
-    def _value_iteration(self) -> Tuple[np.ndarray, DiscreteDeterminsticPolicy]:
+    def _value_iteration(self) -> tuple[Annotated[NDArray[np.float64], "1D Array"], DiscreteDeterminsticPolicy]:
         V = np.zeros(len(self._state_names))
 
         for _ in range(self._policy_evaluation_maxiter):
@@ -127,7 +132,7 @@ class OptimalPolicyOnValueFunctions:
 
         return V, policy
 
-    def policy_iteration(self) -> Tuple[Dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
+    def policy_iteration(self) -> tuple[dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
         V, policy = self._policy_iteration()
         state_values_dict = {
             self._state_names[i]: V[i]
@@ -135,7 +140,7 @@ class OptimalPolicyOnValueFunctions:
         }
         return state_values_dict, policy
 
-    def value_iteration(self) -> Tuple[Dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
+    def value_iteration(self) -> tuple[dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
         V, policy = self._value_iteration()
         state_values_dict = {
             self._state_names[i]: V[i]
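Two things happened to this module: it moved from the package root into the new dp subpackage (hence the `.` → `..` relative imports), and its NumPy annotations were tightened to `NDArray[np.float64]`. Downstream code only needs the new import path; a sketch of the migration, with the `TransitionProbabilityFactory` construction elided because transition.py is not reproduced in this diff:

    # 0.0.4
    # from pyrlutils.valuefcns import OptimalPolicyOnValueFunctions
    # 0.1.1: dynamic-programming solvers live under pyrlutils.dp
    from pyrlutils.dp.valuefcns import OptimalPolicyOnValueFunctions

    # Given a populated factory `transprobfac` (construction omitted here):
    # solver = OptimalPolicyOnValueFunctions(0.9, transprobfac)
    # state_values, policy = solver.policy_iteration()  # dict of V(s), greedy policy

For orientation, `_policy_evaluation` iterates the standard Bellman expectation backup $V_{k+1}(s) = \sum_{s',r} p(s',r \mid s, \pi(s))\,[r + \gamma V_k(s')]$, stopping once $\max_s |V_{k+1}(s) - V_k(s)| < \theta$; the unchanged context above fixes $\theta = 10^{-10}$ and caps the sweeps at 10000.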
pyrlutils/helpers/__init__.py ADDED
File without changes
pyrlutils/openai/utils.py CHANGED
@@ -5,7 +5,7 @@ from ..transition import TransitionProbabilityFactory, NextStateTuple
 
 
 class OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory(TransitionProbabilityFactory):
-    def __init__(self, envname):
+    def __init__(self, envname: str):
         super().__init__()
         self._envname = envname
         self._gymenv = gym.make(envname)
@@ -23,9 +23,9 @@ class OpenAIGymDiscreteEnvironmentTransitionProbabi
         self.add_state_transitions(state_value, new_trans_dict)
 
     @property
-    def envname(self):
+    def envname(self) -> str:
         return self._envname
 
     @property
-    def gymenv(self):
+    def gymenv(self) -> gym.Env:
         return self._gymenv
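The factory's behavior is untouched; the new annotations just document its small public surface. A sketch, assuming a discrete Gym environment id such as 'FrozenLake-v1' exists in the installed gym version:

    from pyrlutils.openai.utils import OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory

    factory = OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory('FrozenLake-v1')
    print(factory.envname)   # 'FrozenLake-v1', now annotated as str
    print(factory.gymenv)    # the wrapped environment, now annotated as gym.Env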
pyrlutils/policy.py CHANGED
@@ -1,9 +1,10 @@
 
 from abc import ABC, abstractmethod
-from typing import Union,
+from typing import Union, Annotated
 from warnings import warn
 
 import numpy as np
+from numpy.typing import NDArray
 
 from .state import State, DiscreteState, DiscreteStateValueType
 from .action import Action, DiscreteActionValueType
@@ -12,7 +13,11 @@ from .action import Action, DiscreteActionValueType
 class Policy(ABC):
     @abstractmethod
     def get_action(self, state: State) -> Action:
-        pass
+        raise NotImplemented()
+
+    @abstractmethod
+    def get_action_value(self, state: State) -> DiscreteActionValueType:
+        raise NotImplemented()
 
     def __call__(self, state: State) -> Action:
         return self.get_action(state)
@@ -25,7 +30,7 @@
 class DeterministicPolicy(Policy):
     @abstractmethod
     def add_deterministic_rule(self, *args, **kwargs):
-        pass
+        raise NotImplemented()
 
     @property
     def is_stochastic(self) -> bool:
@@ -33,16 +38,23 @@
 
 
 class DiscreteDeterminsticPolicy(DeterministicPolicy):
-    def __init__(self, actions_dict:
+    def __init__(self, actions_dict: dict[DiscreteActionValueType, Action]):
         self._state_to_action = {}
         self._actions_dict = actions_dict
 
-    def add_deterministic_rule(
+    def add_deterministic_rule(
+            self,
+            state_value: DiscreteStateValueType,
+            action_value: DiscreteActionValueType
+    ) -> None:
         if state_value in self._state_to_action:
             warn('State value {} exists in rule; it will be replaced.'.format(state_value))
         self._state_to_action[state_value] = action_value
 
-    def get_action_value(
+    def get_action_value(
+            self,
+            state_value: DiscreteStateValueType
+    ) -> DiscreteActionValueType:
         return self._state_to_action.get(state_value)
 
     def get_action(self, state: DiscreteState) -> Action:
@@ -62,10 +74,16 @@ class DiscreteDeterminsticPolicy(DeterministicPolicy):
         return True
 
 
+class DiscreteContinuousPolicy(DeterministicPolicy):
+    @abstractmethod
+    def get_action(self, state: State) -> Action:
+        raise NotImplemented()
+
+
 class StochasticPolicy(Policy):
     @abstractmethod
     def get_probability(self, *args, **kwargs) -> float:
-        pass
+        raise NotImplemented()
 
     @property
     def is_stochastic(self) -> bool:
@@ -73,12 +91,61 @@ class StochasticPolicy(Policy):
 
 
 class DiscreteStochasticPolicy(StochasticPolicy):
-
-
-
+    def __init__(self, actions_dict: dict[DiscreteActionValueType, Action]):
+        self._state_to_action = {}
+        self._actions_dict = actions_dict
+
+    def add_stochastic_rule(
+            self,
+            state_value: DiscreteStateValueType,
+            action_values: list[DiscreteActionValueType],
+            probs: Union[list[float], Annotated[NDArray[np.float64], "1D Array"]] = None
+    ):
+        if probs is not None:
+            assert len(action_values) == len(probs)
+            probs = np.array(probs)
+        else:
+            probs = np.repeat(1./len(action_values), len(action_values))
+
+        if state_value in self._state_to_action:
+            warn('State value {} exists in rule; it will be replaced.'.format(state_value))
+        self._state_to_action[state_value] = {
+            action_value: prob
+            for action_value, prob in zip(action_values, probs)
+        }
+
+    def get_probability(
+            self,
+            state_value: DiscreteStateValueType,
+            action_value: DiscreteActionValueType
+    ) -> float:
+        if state_value not in self._state_to_action:
+            return 0.0
+        if action_value in self._state_to_action[state_value]:
+            return self._state_to_action[state_value][action_value]
+        else:
+            return 0.0
+
+    def get_action_value(self, state: State) -> DiscreteActionValueType:
+        allowed_actions = list(self._state_to_action[state].keys())
+        probs = np.array(list(self._state_to_action[state].values()))
+        sumprobs = np.sum(probs)
+        return np.random.choice(allowed_actions, p=probs/sumprobs)
+
+    def get_action(self, state: DiscreteState) -> Action:
+        return self._actions_dict[self.get_action_value(state.state_value)]
 
 
 class ContinuousStochasticPolicy(StochasticPolicy):
     @abstractmethod
-    def get_probability(
-
+    def get_probability(
+            self,
+            state_value: Union[float, Annotated[NDArray[np.float64], "1D Array"]],
+            action_value: DiscreteActionValueType,
+            value: Union[float, Annotated[NDArray[np.float64], "1D Array"]]
+    ) -> float:
+        raise NotImplemented()
+
+
+DiscretePolicy = Union[DiscreteDeterminsticPolicy, DiscreteStochasticPolicy]
+ContinuousPolicy = Union[ContinuousStochasticPolicy]
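`DiscreteStochasticPolicy` is the substantive addition: an empty stub in 0.0.4, it now keeps a per-state table of action probabilities, defaulting to a uniform distribution when `probs` is omitted. A sketch of the new API (state and action names invented; note that `get_action_value` indexes the table by state *value*, which is exactly what `get_action` passes it):

    from pyrlutils.action import Action
    from pyrlutils.policy import DiscreteStochasticPolicy

    actions = {'left': Action(lambda s: s), 'right': Action(lambda s: s)}
    policy = DiscreteStochasticPolicy(actions)

    # Explicit probabilities; omitting probs would give each action 0.5.
    policy.add_stochastic_rule('start', ['left', 'right'], probs=[0.25, 0.75])

    print(policy.get_probability('start', 'right'))   # 0.75
    print(policy.get_probability('start', 'jump'))    # 0.0 for unlisted actions
    print(policy.get_probability('goal', 'left'))     # 0.0 for unknown states
    print(policy.get_action_value('start'))           # samples 'left' or 'right'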
pyrlutils/state.py CHANGED
@@ -1,81 +1,87 @@
 
-from abc import ABC, abstractmethod
+import sys
+from abc import ABC
 from enum import Enum
-from dataclasses import dataclass
-from typing import Tuple, List, Optional, Union
+from typing import Optional, Union, Annotated, Literal
 
 import numpy as np
+from numpy.typing import NDArray
+if sys.version_info < (3, 11):
+    from typing_extensions import Self
+else:
+    from typing import Self
 
-
-class StateValue(ABC):
-    @property
-    @abstractmethod
-    def value(self):
-        pass
-
-
-@dataclass
-class DiscreteStateValue(StateValue):
-    enum: Enum
-
-    @property
-    def value(self):
-        return self.enum.value
-
-    def name(self):
-        return self.enum.name
-
-
-class ContinuousStateValue(StateValue):
-    _value: float
-
-    @property
-    def value(self) -> float:
-        return self._value
+from .helpers.exceptions import InvalidRangeError
 
 
 class State(ABC):
     @property
     def state_value(self):
-        pass
-
-    @abstractmethod
-    def set_state_value(self, state_value):
-        pass
+        raise NotImplemented()
 
-    @abstractmethod
-    def get_state_value(self):
-        pass
-
-    @state_value.setter
-    def state_value(self, new_state_value):
-        self.set_state_value(new_state_value)
 
-
-DiscreteStateValueType = Union[float, str, Tuple[int], Enum]
+DiscreteStateValueType = Union[str, int, tuple[int], Enum]
 
 
 class DiscreteState(State):
-    def __init__(
+    def __init__(
+            self,
+            all_state_values: list[DiscreteStateValueType],
+            initial_value: Optional[DiscreteStateValueType] = None,
+            terminals: Optional[dict[DiscreteStateValueType, bool]]=None
+    ):
         super().__init__()
         self._all_state_values = all_state_values
-        self.
+        self._state_values_to_indices = {
+            state_value: idx
+            for idx, state_value in enumerate(self._all_state_values)
+        }
+        if initial_value is not None:
+            self._current_index = self._state_values_to_indices[initial_value]
+        else:
+            self._current_index = 0
+        if terminals is None:
+            self._terminal_dict = {
+                state_value: False
+                for state_value in self._all_state_values
+            }
+        else:
+            self._terminal_dict = terminals.copy()
+            for state_value in self._all_state_values:
+                if self._terminal_dict.get(state_value) is None:
+                    self._terminal_dict[state_value] = False
+
+    def _get_state_value_from_index(self, index: int) -> DiscreteStateValueType:
+        return self._all_state_values[index]
 
     def get_state_value(self) -> DiscreteStateValueType:
-        return self.
+        return self._get_state_value_from_index(self._current_index)
 
-    def set_state_value(self, state_value: DiscreteStateValueType):
+    def set_state_value(self, state_value: DiscreteStateValueType) -> None:
         if state_value in self._all_state_values:
-            self.
+            self._current_index = self._state_values_to_indices[state_value]
         else:
             raise ValueError('State value {} is invalid.'.format(state_value))
 
-    def get_all_possible_state_values(self) -> List[DiscreteStateValueType]:
+    def get_all_possible_state_values(self) -> list[DiscreteStateValueType]:
         return self._all_state_values
 
+    def query_state_index_from_value(self, value: DiscreteStateValueType) -> int:
+        return self._state_values_to_indices[value]
+
+    @property
+    def state_index(self) -> int:
+        return self._current_index
+
+    @state_index.setter
+    def state_index(self, new_index: int) -> None:
+        if new_index >= len(self._all_state_values):
+            raise ValueError(f"Invalid index {new_index}; it must be less than {self.nb_state_values}.")
+        self._current_index = new_index
+
     @property
     def state_value(self) -> DiscreteStateValueType:
-        return self.
+        return self._all_state_values[self._current_index]
 
     @state_value.setter
     def state_value(self, new_state_value: DiscreteStateValueType):
@@ -85,22 +91,53 @@ class DiscreteState(State):
     def state_space_size(self):
         return len(self._all_state_values)
 
+    @property
+    def nb_state_values(self) -> int:
+        return len(self._all_state_values)
 
-
-    def
-        self.
-
+    @property
+    def is_terminal(self) -> bool:
+        return self._terminal_dict[self._all_state_values[self._current_index]]
+
+    def __hash__(self):
+        return self._current_index
+
+    def __eq__(self, other: Self) -> bool:
+        return self._current_index == other._current_index
 
 
 class ContinuousState(State):
-    def __init__(
+    def __init__(
+            self,
+            nbdims: int,
+            ranges: Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]],
+            init_value: Optional[Union[float, Annotated[NDArray[np.float64], "1D Array"]]] = None
+    ):
+        super().__init__()
         self._nbdims = nbdims
 
+        try:
+            assert isinstance(ranges, np.ndarray)
+        except AssertionError:
+            raise TypeError('Range must be a numpy array.')
+
         try:
             assert (ranges.dtype == np.float64) or (ranges.dtype == np.float32) or (ranges.dtype == np.float16)
         except AssertionError:
             raise TypeError('It has to be floating type numpy.ndarray.')
 
+        try:
+            assert ranges.ndim == 1 or ranges.ndim == 2
+            match ranges.ndim:
+                case 1:
+                    assert ranges.shape[0] == 2
+                case 2:
+                    assert ranges.shape[1] == 2
+                case _:
+                    raise ValueError("Ranges must be of shape (2, ) or (*, 2).")
+        except AssertionError:
+            raise ValueError("Ranges must be of shape (2, ) or (*, 2).")
+
         try:
             assert self._nbdims > 0
         except AssertionError:
@@ -146,50 +183,53 @@ class ContinuousState(State):
             raise ValueError('Initialized value does not have the right dimension.')
             for i in range(self._nbdims):
                 try:
-                    assert
+                    assert self._ranges[i, 0] <= init_value[i] <= self.ranges[i, 1]
                 except AssertionError:
                     raise InvalidRangeError('Initialized value at dimension {} (value: {}) is not within the permitted range ({} -> {})!'.format(i, init_value[i], self._ranges[i, 0], self._ranges[i, 1]))
         else:
             try:
-                assert
+                assert self._ranges[0, 0] <= init_value <= self.ranges[0, 1]
             except AssertionError:
                 raise InvalidRangeError('Initialized value is out of range.')
         self._state_value = init_value
 
-    def set_state_value(self, state_value: Union[float, np.ndarray]):
-        if self.
+    def set_state_value(self, state_value: Union[float, Annotated[NDArray[np.float64], "1D Array"]]):
+        if self._nbdims > 1:
             try:
                 assert state_value.shape[0] == self._nbdims
             except AssertionError:
                 raise ValueError('Given value does not have the right dimension.')
-            for i in range(self.
+            for i in range(self._nbdims):
                 try:
-                    assert
+                    assert self.ranges[i, 0] <= state_value[i] <= self.ranges[i, 1]
                 except AssertionError:
                     raise InvalidRangeError()
         else:
             try:
-                assert
+                assert self.ranges[0, 0] <= state_value <= self.ranges[0, 1]
            except AssertionError:
                 raise InvalidRangeError()
 
         self._state_value = state_value
 
-    def get_state_value(self) -> np.ndarray:
+    def get_state_value(self) -> Annotated[NDArray[np.float64], "1D Array"]:
        return self._state_value
 
-    def get_state_value_ranges(self) -> np.ndarray:
+    def get_state_value_ranges(self) -> Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]]:
         return self._ranges
 
-    def get_state_value_range_at_dimension(self, dimension: int) -> np.ndarray:
-        return self._ranges[dimension]
+    def get_state_value_range_at_dimension(self, dimension: int) -> Annotated[NDArray[np.float64], Literal["2"]]:
+        if dimension < self._nbdims:
+            return self._ranges[dimension]
+        else:
+            raise ValueError(f"There are only {self._nbdims} dimensions!")
 
     @property
-    def ranges(self) -> np.ndarray:
+    def ranges(self) -> Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]]:
         return self.get_state_value_ranges()
 
     @property
-    def state_value(self) -> Union[float, np.ndarray]:
+    def state_value(self) -> Union[float, NDArray[np.float64]]:
         return self.get_state_value()
 
     @state_value.setter
@@ -200,9 +240,28 @@ class ContinuousState(State):
     def nbdims(self) -> int:
         return self._nbdims
 
+    def __hash__(self):
+        return hash(tuple(self._state_value))
+
+    def __eq__(self, other: Self):
+        if self.nbdims != other.nbdims:
+            raise ValueError(f"The two states have two different dimensions. ({self.nbdims} vs. {other.nbdims})")
+        for i in range(self.nbdims):
+            if self.state_value[i] != other.state_value[i]:
+                return False
+        return True
+
 
 class Discrete2DCartesianState(DiscreteState):
-    def __init__(
+    def __init__(
+            self,
+            x_lowlim: int,
+            x_hilim: int,
+            y_lowlim: int,
+            y_hilim: int,
+            initial_coordinate: list[int]=None,
+            terminals: Optional[dict[DiscreteStateValueType, bool]] = None
+    ):
         self._x_lowlim = x_lowlim
         self._x_hilim = x_hilim
         self._y_lowlim = y_lowlim
@@ -212,14 +271,50 @@ class Discrete2DCartesianState(DiscreteState):
         if initial_coordinate is None:
             initial_coordinate = [self._x_lowlim, self._y_lowlim]
         initial_value = (initial_coordinate[1] - self._y_lowlim) * self._countx + (initial_coordinate[0] - self._x_lowlim)
-        super().__init__(list(range(self._countx*self._county)),
+        super().__init__(list(range(self._countx*self._county)), initial_value=initial_value, terminals=terminals)
 
     def _encode_coordinates(self, x, y) -> int:
         return (y - self._y_lowlim) * self._countx + (x - self._x_lowlim)
 
-    def encode_coordinates(self, coordinates:
-
+    def encode_coordinates(self, coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]]) -> int:
+        if isinstance(coordinates, list):
+            assert len(coordinates) == 2
         return self._encode_coordinates(coordinates[0], coordinates[1])
 
-    def decode_coordinates(self, hashcode) -> List[int]:
-        return [hashcode % self._countx, hashcode // self._countx]
+    def decode_coordinates(self, hashcode) -> list[int]:
+        return [hashcode % self._countx + self._x_lowlim, hashcode // self._countx + self._y_lowlim]
+
+    def get_whether_terminal_given_coordinates(
+            self,
+            coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]]
+    ) -> bool:
+        if isinstance(coordinates, list):
+            assert len(coordinates) == 2
+        hashcode = self._encode_coordinates(coordinates[0], coordinates[1])
+        return self._terminal_dict.get(hashcode, False)
+
+    def set_terminal_given_coordinates(
+            self,
+            coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]],
+            terminal_value: bool
+    ) -> None:
+        if isinstance(coordinates, list):
+            assert len(coordinates) == 2
+        hashcode = self._encode_coordinates(coordinates[0], coordinates[1])
+        self._terminal_dict[hashcode] = terminal_value
+
+    @property
+    def x_lowlim(self) -> int:
+        return self._x_lowlim
+
+    @property
+    def x_hilim(self) -> int:
+        return self._x_hilim
+
+    @property
+    def y_lowlim(self) -> int:
+        return self._y_lowlim
+
+    @property
+    def y_hilim(self) -> int:
+        return self._y_hilim
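Beyond the typing overhaul, `DiscreteState` gains index bookkeeping, hashing, and terminal-state flags, and the grid world gains terminal management plus a `decode_coordinates` that adds the lower limits back (restoring symmetry with `encode_coordinates`). A sketch on a 3x2 grid, assuming the internal `_countx` spans the inclusive x-range as the encoding formula implies:

    from pyrlutils.state import Discrete2DCartesianState

    # x in [0, 2], y in [0, 1]: six states, starting at (0, 0).
    grid = Discrete2DCartesianState(0, 2, 0, 1, initial_coordinate=[0, 0])
    grid.set_terminal_given_coordinates([2, 1], True)

    code = grid.encode_coordinates([2, 1])
    print(code)                            # (1 - 0) * 3 + (2 - 0) = 5
    print(grid.decode_coordinates(code))   # [2, 1]; 0.1.1 adds the x/y lower limits back
    print(grid.get_whether_terminal_given_coordinates([2, 1]))   # True
    print(grid.is_terminal)                # False: the current state is still (0, 0)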
pyrlutils/td/__init__.py ADDED
File without changes