pyrlutils-0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyrlutils/__init__.py +0 -0
- pyrlutils/action.py +27 -0
- pyrlutils/bandit/__init__.py +0 -0
- pyrlutils/bandit/algo.py +128 -0
- pyrlutils/bandit/reward.py +12 -0
- pyrlutils/dp/__init__.py +0 -0
- pyrlutils/dp/valuefcns.py +149 -0
- pyrlutils/helpers/__init__.py +0 -0
- pyrlutils/helpers/exceptions.py +5 -0
- pyrlutils/openai/__init__.py +0 -0
- pyrlutils/openai/utils.py +31 -0
- pyrlutils/policy.py +151 -0
- pyrlutils/reward.py +37 -0
- pyrlutils/state.py +320 -0
- pyrlutils/td/__init__.py +0 -0
- pyrlutils/td/doubleqlearn.py +110 -0
- pyrlutils/td/qlearn.py +86 -0
- pyrlutils/td/sarsa.py +86 -0
- pyrlutils/td/state_td.py +111 -0
- pyrlutils/td/utils.py +258 -0
- pyrlutils/transition.py +155 -0
- pyrlutils-0.1.2.dist-info/METADATA +43 -0
- pyrlutils-0.1.2.dist-info/RECORD +26 -0
- pyrlutils-0.1.2.dist-info/WHEEL +5 -0
- pyrlutils-0.1.2.dist-info/licenses/LICENSE +19 -0
- pyrlutils-0.1.2.dist-info/top_level.txt +1 -0
pyrlutils/__init__.py
ADDED
File without changes
pyrlutils/action.py
ADDED
@@ -0,0 +1,27 @@

from types import LambdaType, FunctionType
from typing import Union

from .state import State


DiscreteActionValueType = Union[float, str]

class Action:
    def __init__(self, actionfunc: Union[FunctionType, LambdaType]):
        self._actionfunc = actionfunc

    def act(self, state: State, *args, **kwargs) -> State:
        self._actionfunc(state, *args, **kwargs)
        return state

    def __call__(self, state: State) -> State:
        return self.act(state)

    @property
    def action_function(self) -> Union[FunctionType, LambdaType]:
        return self._actionfunc

    @action_function.setter
    def action_function(self, new_func: Union[FunctionType, LambdaType]) -> None:
        self._actionfunc = new_func
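A minimal usage sketch for `Action` (illustrative only, not part of the package): the toy state class below is a stand-in, since the real `State` hierarchy lives in `pyrlutils/state.py` and its interface is not shown in this diff.

```python
from pyrlutils.action import Action

class _ToyState:
    """Stand-in for a pyrlutils State; only duck typing matters to Action."""
    def __init__(self):
        self.counter = 0

def increment(state, step=1):
    # the wrapped function mutates the state in place
    state.counter += step

act = Action(increment)
s = _ToyState()
act.act(s, step=5)   # extra args/kwargs are forwarded to the wrapped function
act(s)               # __call__ delegates to act() with default arguments
print(s.counter)     # 6
```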
pyrlutils/bandit/__init__.py
ADDED

File without changes
pyrlutils/bandit/algo.py
ADDED
@@ -0,0 +1,128 @@

from abc import ABC, abstractmethod

import numpy as np

from .reward import IndividualBanditRewardFunction


class BanditAlgorithm(ABC):
    def __init__(self, action_values: list, reward_function: IndividualBanditRewardFunction):
        self._action_values = action_values
        self._reward_function = reward_function

    @abstractmethod
    def _go_one_loop(self):
        pass

    def loop(self, nbiterations: int):
        for _ in range(nbiterations):
            self._go_one_loop()

    def reward(self, action_value) -> float:
        return self._reward_function(action_value)

    @abstractmethod
    def get_action(self):
        pass

    @property
    def action_values(self):
        return self._action_values

    @property
    def reward_function(self) -> IndividualBanditRewardFunction:
        return self._reward_function


class SimpleBandit(BanditAlgorithm):
    def __init__(
        self,
        action_values: list,
        reward_function: IndividualBanditRewardFunction,
        epsilon: float=0.05
    ):
        super().__init__(action_values, reward_function)
        self._epsilon = epsilon
        self._initialize()

    def _initialize(self):
        self._Q = np.zeros(len(self._action_values))
        self._N = np.zeros(len(self._action_values), dtype=np.int32)

    def _go_one_loop(self):
        r = np.random.uniform()
        if r < self.epsilon:
            selected_action_idx = np.argmax(self._Q)
        else:
            selected_action_idx = np.random.choice(range(len(self._action_values)))
        reward = self._reward_function(self._action_values[selected_action_idx])
        self._N[selected_action_idx] += 1
        self._Q[selected_action_idx] += (reward - self._Q[selected_action_idx]) / self._N[selected_action_idx]

    def get_action(self):
        selected_action_idx = np.argmax(self._Q)
        return self._action_values[selected_action_idx]

    @property
    def epsilon(self) -> float:
        return self._epsilon

    @epsilon.setter
    def epsilon(self, val: float):
        self._epsilon = val


class GradientBandit(BanditAlgorithm):
    def __init__(self, action_values: list, reward_function: IndividualBanditRewardFunction, temperature: float=1.0, alpha: float=0.1):
        super().__init__(action_values, reward_function)
        self._T = temperature
        self._alpha = alpha
        self._initialize()

    def _initialize(self):
        self._preferences = np.zeros(len(self._action_values))
        self._rewards_over_time = []

    def _get_probs(self) -> np.ndarray:
        # getting probabilities using softmax
        exp_preferences = np.exp(self._preferences / self.T)
        sum_exp_preferences = np.sum(exp_preferences)
        return exp_preferences / sum_exp_preferences

    def get_action(self):
        selected_action_idx = np.argmax(self._preferences)
        return self._action_values[selected_action_idx]

    def _go_one_loop(self):
        probs = self._get_probs()
        selected_action_idx = np.random.choice(range(self._preferences.shape[0]), p=probs)
        reward = self._reward_function(self._action_values[selected_action_idx])
        self._rewards_over_time.append(reward)
        average_reward = np.mean(self._rewards_over_time) if len(self._rewards_over_time) > 0 else 0.

        for i in range(len(self._action_values)):
            if i == selected_action_idx:
                self._preferences[i] += self.alpha * (reward - average_reward) * (1 - probs[i])
            else:
                self._preferences[i] -= self.alpha * (reward - average_reward) * probs[i]

    @property
    def alpha(self) -> float:
        return self._alpha

    @alpha.setter
    def alpha(self, val: float):
        self._alpha = val

    @property
    def T(self) -> float:
        return self._T

    @T.setter
    def T(self, val: float):
        self._T = val

    @property
    def temperature(self) -> float:
        return self._T
pyrlutils/bandit/reward.py
ADDED

@@ -0,0 +1,12 @@

from abc import ABC, abstractmethod
from typing import Any


class IndividualBanditRewardFunction(ABC):
    @abstractmethod
    def reward(self, action_value: Any) -> float:
        pass

    def __call__(self, action_value: Any) -> float:
        return self.reward(action_value)
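A hedged end-to-end sketch tying the two bandit modules above together. Only `SimpleBandit`, `GradientBandit`, and `IndividualBanditRewardFunction` come from the package; the `BernoulliReward` class, its arm names, and payout probabilities are invented for illustration.

```python
import numpy as np
from pyrlutils.bandit.reward import IndividualBanditRewardFunction
from pyrlutils.bandit.algo import SimpleBandit, GradientBandit

class BernoulliReward(IndividualBanditRewardFunction):
    """Illustrative reward: each arm pays 1 with a fixed probability, else 0."""
    _payout = {'left': 0.3, 'right': 0.7}

    def reward(self, action_value) -> float:
        return float(np.random.uniform() < self._payout[action_value])

arms = ['left', 'right']

simple = SimpleBandit(arms, BernoulliReward(), epsilon=0.1)
simple.loop(1000)           # 1000 pull-and-update iterations
print(simple.get_action())  # arm with the highest running Q estimate

grad = GradientBandit(arms, BernoulliReward(), temperature=1.0, alpha=0.1)
grad.loop(1000)
print(grad.get_action())    # arm with the highest learned preference
```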
pyrlutils/dp/__init__.py
ADDED
File without changes
pyrlutils/dp/valuefcns.py
ADDED

@@ -0,0 +1,149 @@

import random
from copy import copy
from itertools import product
from typing import Annotated

import numpy as np
from numpy.typing import NDArray

from ..state import DiscreteStateValueType
from ..transition import TransitionProbabilityFactory
from ..policy import DiscreteDeterminsticPolicy


class OptimalPolicyOnValueFunctions:
    def __init__(
        self,
        discount_factor: float,
        transprobfac: TransitionProbabilityFactory
    ):
        try:
            assert 0. <= discount_factor <= 1.
        except AssertionError:
            raise ValueError('Discount factor must be between 0 and 1.')
        self._gamma = discount_factor
        self._transprobfac = transprobfac
        self._states, self._actions_dict, self._indrewardfcn = self._transprobfac.generate_mdp_objects()
        self._state_names = self._states.get_all_possible_state_values()
        self._states_to_indices = {state: idx for idx, state in enumerate(self._state_names)}
        self._action_names = list(self._actions_dict.keys())
        self._actions_to_indices = {action_value: idx for idx, action_value in enumerate(self._action_names)}

        self._evaluated = False
        self._improved = False

        self._theta = 1e-10
        self._policy_evaluation_maxiter = 10000

    def _policy_evaluation(self, policy: DiscreteDeterminsticPolicy) -> Annotated[NDArray[np.float64], "1D Array"]:
        prev_V = np.zeros(len(self._states_to_indices))

        for _ in range(self._policy_evaluation_maxiter):
            V = np.zeros(len(self._states_to_indices))
            for state_value in self._state_names:
                state_index = self._states_to_indices[state_value]
                action_value = policy.get_action_value(state_value)
                for next_state_tuple in self._transprobfac.transition_probabilities[state_value][action_value]:
                    prob = next_state_tuple.probability
                    reward = next_state_tuple.reward
                    next_state_value = next_state_tuple.next_state_value
                    next_state_index = self._states_to_indices[next_state_value]
                    terminal = next_state_tuple.terminal

                    V[state_index] += prob * (reward + (self._gamma*prev_V[next_state_index] if not terminal else 0.))

            if np.max(np.abs(V-prev_V)) < self._theta:
                break

            prev_V = V.copy()

        return V

    def _policy_improvement(self, V: Annotated[NDArray[np.float64], "1D Array"]) -> DiscreteDeterminsticPolicy:
        Q = np.zeros((len(self._states_to_indices), len(self._actions_to_indices)))

        for state_value in self._state_names:
            state_index = self._states_to_indices[state_value]
            for action_value in self._action_names:
                action_index = self._actions_to_indices[action_value]
                for next_state_tuple in self._transprobfac.transition_probabilities[state_value][action_value]:
                    prob = next_state_tuple.probability
                    reward = next_state_tuple.reward
                    next_state_value = next_state_tuple.next_state_value
                    next_state_index = self._states_to_indices[next_state_value]
                    terminal = next_state_tuple.terminal

                    Q[state_index, action_index] += prob * (reward + (self._gamma*V[next_state_index] if not terminal else 0.))

        optimal_policy = DiscreteDeterminsticPolicy(self._actions_dict)
        optimal_action_indices = np.argmax(Q, axis=1)
        for state_value, action_index in zip(self._state_names, optimal_action_indices):
            action_value = self._action_names[action_index]
            optimal_policy.add_deterministic_rule(state_value, action_value)
        return optimal_policy

    def _policy_iteration(self) -> tuple[Annotated[NDArray[np.float64], "1D Array"], DiscreteDeterminsticPolicy]:
        policy = DiscreteDeterminsticPolicy(self._actions_dict)
        for state_value in self._state_names:
            policy.add_deterministic_rule(state_value, random.choice(self._action_names))
        V = None

        done = False
        while not done:
            old_policy = copy(policy)

            V = self._policy_evaluation(policy)
            policy = self._policy_improvement(V)

            if policy == old_policy:
                done = True

        return V, policy


    def _value_iteration(self) -> tuple[Annotated[NDArray[np.float64], "1D Array"], DiscreteDeterminsticPolicy]:
        V = np.zeros(len(self._state_names))

        for _ in range(self._policy_evaluation_maxiter):
            Q = np.zeros((len(self._state_names), len(self._action_names)))
            for state_value, action_value in product(self._state_names, self._action_names):
                state_index = self._states_to_indices[state_value]
                action_index = self._actions_to_indices[action_value]
                for next_state_tuple in self._transprobfac.transition_probabilities[state_value][action_value]:
                    prob = next_state_tuple.probability
                    reward = next_state_tuple.reward
                    next_state_value = next_state_tuple.next_state_value
                    next_state_index = self._states_to_indices[next_state_value]
                    terminal = next_state_tuple.terminal

                    Q[state_index, action_index] += prob * (reward + (self._gamma * V[next_state_index] if not terminal else 0.))

            if np.max(np.abs(V-np.max(Q, axis=1))) < self._theta:
                break

            V = np.max(Q, axis=1)

        Qmaxj = np.argmax(Q, axis=1)

        policy = DiscreteDeterminsticPolicy(self._actions_dict)
        for state_value, action_index in zip(self._state_names, Qmaxj):
            policy.add_deterministic_rule(state_value, self._action_names[action_index])

        return V, policy

    def policy_iteration(self) -> tuple[dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
        V, policy = self._policy_iteration()
        state_values_dict = {
            self._state_names[i]: V[i]
            for i in range(V.shape[0])
        }
        return state_values_dict, policy

    def value_iteration(self) -> tuple[dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
        V, policy = self._value_iteration()
        state_values_dict = {
            self._state_names[i]: V[i]
            for i in range(V.shape[0])
        }
        return state_values_dict, policy
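A hedged sketch of driving `OptimalPolicyOnValueFunctions` with a hand-built factory. `pyrlutils/transition.py` is not included in this diff, so the factory calls below follow the usage visible in `valuefcns.py` above and `openai/utils.py` below, and assume `NextStateTuple` takes `(next_state_value, probability, reward, terminal)` in that order; the two-state MDP itself is invented for illustration.

```python
from pyrlutils.transition import TransitionProbabilityFactory, NextStateTuple
from pyrlutils.dp.valuefcns import OptimalPolicyOnValueFunctions

# Two states ('s0', 's1'), two actions ('stay', 'go'); 's1' is terminal.
factory = TransitionProbabilityFactory()
factory.add_state_transitions('s0', {
    'stay': [NextStateTuple('s0', 1.0, 0.0, False)],
    'go':   [NextStateTuple('s1', 1.0, 1.0, True)],
})
factory.add_state_transitions('s1', {
    'stay': [NextStateTuple('s1', 1.0, 0.0, True)],
    'go':   [NextStateTuple('s1', 1.0, 0.0, True)],
})

solver = OptimalPolicyOnValueFunctions(0.9, factory)
state_values, policy = solver.policy_iteration()   # or solver.value_iteration()
print(state_values)                                 # {'s0': ..., 's1': ...}
print(policy.get_action_value('s0'))                # 'go' maximizes the return here
```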
pyrlutils/helpers/__init__.py
ADDED

File without changes

pyrlutils/openai/__init__.py
ADDED

File without changes

pyrlutils/openai/utils.py
ADDED
@@ -0,0 +1,31 @@

import gymnasium as gym

from ..transition import TransitionProbabilityFactory, NextStateTuple


class OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory(TransitionProbabilityFactory):
    def __init__(self, envname: str):
        super().__init__()
        self._envname = envname
        self._gymenv = gym.make(envname)
        self._convert_openai_gymenv_to_transprob()

    def _convert_openai_gymenv_to_transprob(self):
        P = self._gymenv.env.env.env.P
        for state_value, trans_dict in P.items():
            new_trans_dict = {}
            for action_value, next_state_list in trans_dict.items():
                new_trans_dict[action_value] = [
                    NextStateTuple(next_state[1], next_state[0], next_state[2], next_state[3])
                    for next_state in next_state_list
                ]
            self.add_state_transitions(state_value, new_trans_dict)

    @property
    def envname(self) -> str:
        return self._envname

    @property
    def gymenv(self) -> gym.Env:
        return self._gymenv
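A short usage sketch, assuming `gymnasium` is installed and pairing the factory with the dynamic-programming solver shown earlier; `FrozenLake-v1` is just one example of a discrete environment that exposes a transition table `P`.

```python
from pyrlutils.openai.utils import OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory
from pyrlutils.dp.valuefcns import OptimalPolicyOnValueFunctions

# Build transition probabilities from a discrete gymnasium environment
factory = OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory('FrozenLake-v1')

solver = OptimalPolicyOnValueFunctions(discount_factor=0.99, transprobfac=factory)
state_values, policy = solver.value_iteration()

# The greedy action for the starting state (state 0 in FrozenLake)
print(policy.get_action_value(0))
```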
pyrlutils/policy.py
ADDED
@@ -0,0 +1,151 @@

from abc import ABC, abstractmethod
from typing import Union, Annotated
from warnings import warn

import numpy as np
from numpy.typing import NDArray

from .state import State, DiscreteState, DiscreteStateValueType
from .action import Action, DiscreteActionValueType


class Policy(ABC):
    @abstractmethod
    def get_action(self, state: State) -> Action:
        raise NotImplemented()

    @abstractmethod
    def get_action_value(self, state: State) -> DiscreteActionValueType:
        raise NotImplemented()

    def __call__(self, state: State) -> Action:
        return self.get_action(state)

    @property
    def is_stochastic(self) -> bool:
        raise NotImplemented()


class DeterministicPolicy(Policy):
    @abstractmethod
    def add_deterministic_rule(self, *args, **kwargs):
        raise NotImplemented()

    @property
    def is_stochastic(self) -> bool:
        return False


class DiscreteDeterminsticPolicy(DeterministicPolicy):
    def __init__(self, actions_dict: dict[DiscreteActionValueType, Action]):
        self._state_to_action = {}
        self._actions_dict = actions_dict

    def add_deterministic_rule(
        self,
        state_value: DiscreteStateValueType,
        action_value: DiscreteActionValueType
    ) -> None:
        if state_value in self._state_to_action:
            warn('State value {} exists in rule; it will be replaced.'.format(state_value))
        self._state_to_action[state_value] = action_value

    def get_action_value(
        self,
        state_value: DiscreteStateValueType
    ) -> DiscreteActionValueType:
        return self._state_to_action.get(state_value)

    def get_action(self, state: DiscreteState) -> Action:
        return self._actions_dict[self.get_action_value(state.state_value)]

    def __eq__(self, other) -> bool:
        if len(self._state_to_action) != len(set(self._state_to_action.keys()).union(other._state_to_action.keys())):
            return False
        if len(self._actions_dict) != len(set(self._actions_dict.keys()).union(other._actions_dict.keys())):
            return False
        for action in self._actions_dict.keys():
            if self._actions_dict[action] != other._actions_dict[action]:
                return False
        for state in self._state_to_action.keys():
            if self._state_to_action[state] != other._state_to_action[state]:
                return False
        return True


class DiscreteContinuousPolicy(DeterministicPolicy):
    @abstractmethod
    def get_action(self, state: State) -> Action:
        raise NotImplemented()


class StochasticPolicy(Policy):
    @abstractmethod
    def get_probability(self, *args, **kwargs) -> float:
        raise NotImplemented()

    @property
    def is_stochastic(self) -> bool:
        return True


class DiscreteStochasticPolicy(StochasticPolicy):
    def __init__(self, actions_dict: dict[DiscreteActionValueType, Action]):
        self._state_to_action = {}
        self._actions_dict = actions_dict

    def add_stochastic_rule(
        self,
        state_value: DiscreteStateValueType,
        action_values: list[DiscreteActionValueType],
        probs: Union[list[float], Annotated[NDArray[np.float64], "1D Array"]] = None
    ):
        if probs is not None:
            assert len(action_values) == len(probs)
            probs = np.array(probs)
        else:
            probs = np.repeat(1./len(action_values), len(action_values))

        if state_value in self._state_to_action:
            warn('State value {} exists in rule; it will be replaced.'.format(state_value))
        self._state_to_action[state_value] = {
            action_value: prob
            for action_value, prob in zip(action_values, probs)
        }

    def get_probability(
        self,
        state_value: DiscreteStateValueType,
        action_value: DiscreteActionValueType
    ) -> float:
        if state_value not in self._state_to_action:
            return 0.0
        if action_value in self._state_to_action[state_value]:
            return self._state_to_action[state_value][action_value]
        else:
            return 0.0

    def get_action_value(self, state: State) -> DiscreteActionValueType:
        allowed_actions = list(self._state_to_action[state].keys())
        probs = np.array(list(self._state_to_action[state].values()))
        sumprobs = np.sum(probs)
        return np.random.choice(allowed_actions, p=probs/sumprobs)

    def get_action(self, state: DiscreteState) -> Action:
        return self._actions_dict[self.get_action_value(state.state_value)]


class ContinuousStochasticPolicy(StochasticPolicy):
    @abstractmethod
    def get_probability(
        self,
        state_value: Union[float, Annotated[NDArray[np.float64], "1D Array"]],
        action_value: DiscreteActionValueType,
        value: Union[float, Annotated[NDArray[np.float64], "1D Array"]]
    ) -> float:
        raise NotImplemented()


DiscretePolicy = Union[DiscreteDeterminsticPolicy, DiscreteStochasticPolicy]
ContinuousPolicy = Union[ContinuousStochasticPolicy]
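A brief sketch of the two concrete policy classes above, with made-up state and action values; the `Action` instances come from `pyrlutils/action.py` shown earlier, and the lambdas here are trivial placeholders.

```python
from pyrlutils.action import Action
from pyrlutils.policy import DiscreteDeterminsticPolicy, DiscreteStochasticPolicy

actions = {
    'left':  Action(lambda state: state),   # placeholder action functions
    'right': Action(lambda state: state),
}

det = DiscreteDeterminsticPolicy(actions)
det.add_deterministic_rule('s0', 'right')
print(det.get_action_value('s0'))           # 'right'

sto = DiscreteStochasticPolicy(actions)
sto.add_stochastic_rule('s0', ['left', 'right'], probs=[0.2, 0.8])
print(sto.get_probability('s0', 'right'))   # 0.8
```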
pyrlutils/reward.py
ADDED
@@ -0,0 +1,37 @@

from abc import ABC, abstractmethod


class IndividualRewardFunction(ABC):
    @abstractmethod
    def reward(self, state_value, action_value, next_state_value) -> float:
        pass

    def __call__(self, state_value, action_value, next_state_value) -> float:
        return self.reward(state_value, action_value, next_state_value)


class RewardFunction(ABC):
    def __init__(self, discount_factor: float, individual_reward_function: IndividualRewardFunction):
        self._discount_factor = discount_factor
        self._individual_reward_function = individual_reward_function

    @property
    def discount_factor(self) -> float:
        return self._discount_factor

    @discount_factor.setter
    def discount_factor(self, discount_factor: float):
        self._discount_factor = discount_factor

    def individual_reward(self, state_value, action_value, next_state_value) -> float:
        return self._individual_reward_function(state_value, action_value, next_state_value)

    @abstractmethod
    def total_reward(self, state_value, action_value) -> float:
        pass

    def __call__(self, state_value, action_value) -> float:
        return self.total_reward(state_value, action_value)
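A hedged sketch of subclassing the two abstract classes above; the step-penalty values, the fixed-horizon discounted sum, and the class names `StepPenaltyReward` and `HorizonReward` are invented purely for illustration.

```python
from pyrlutils.reward import IndividualRewardFunction, RewardFunction

class StepPenaltyReward(IndividualRewardFunction):
    """Illustrative one-step reward: -1 per step unless the goal is reached."""
    def reward(self, state_value, action_value, next_state_value) -> float:
        return 0.0 if next_state_value == 'goal' else -1.0

class HorizonReward(RewardFunction):
    """Illustrative total reward: discounted sum over a fixed horizon."""
    def __init__(self, discount_factor, individual_reward_function, horizon=10):
        super().__init__(discount_factor, individual_reward_function)
        self._horizon = horizon

    def total_reward(self, state_value, action_value) -> float:
        # assumes the state stays unchanged over the horizon, purely for demonstration
        return sum(
            (self.discount_factor ** t) * self.individual_reward(state_value, action_value, state_value)
            for t in range(self._horizon)
        )

total = HorizonReward(0.9, StepPenaltyReward())('s0', 'go')
print(total)   # discounted sum of -1 over 10 steps (about -6.51)
```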