pyrlutils 0.0.4__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


pyrlutils/action.py CHANGED
@@ -1,5 +1,5 @@
 
- from types import LambdaType
+ from types import LambdaType, FunctionType
  from typing import Union
 
  from .state import State
@@ -8,7 +8,7 @@ from .state import State
  DiscreteActionValueType = Union[float, str]
 
  class Action:
-     def __init__(self, actionfunc: LambdaType):
+     def __init__(self, actionfunc: Union[FunctionType, LambdaType]):
          self._actionfunc = actionfunc
 
      def act(self, state: State, *args, **kwargs) -> State:
@@ -17,3 +17,11 @@ class Action:
 
      def __call__(self, state: State) -> State:
          return self.act(state)
+
+     @property
+     def action_function(self) -> Union[FunctionType, LambdaType]:
+         return self._actionfunc
+
+     @action_function.setter
+     def action_function(self, new_func: Union[FunctionType, LambdaType]) -> None:
+         self._actionfunc = new_func
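For orientation, a minimal sketch of how the new action_function accessor might be used; the switch_on function and the 'off'/'on' state values below are invented for illustration, not taken from the package:

    from pyrlutils.action import Action
    from pyrlutils.state import DiscreteState

    # made-up two-value state and an action that switches it on
    state = DiscreteState(['off', 'on'])

    def switch_on(s: DiscreteState) -> DiscreteState:
        s.set_state_value('on')
        return s

    action = Action(switch_on)
    action(state)                                 # equivalent to action.act(state)
    print(state.state_value)                      # 'on'

    # 0.1.0 exposes the wrapped callable as a read/write property
    print(action.action_function is switch_on)    # True
    action.action_function = lambda s: s          # swap in a no-op callable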
pyrlutils/bandit/reward.py CHANGED
@@ -1,11 +1,12 @@
 
  from abc import ABC, abstractmethod
+ from typing import Any
 
 
  class IndividualBanditRewardFunction(ABC):
      @abstractmethod
-     def reward(self, action_value) -> float:
+     def reward(self, action_value: Any) -> float:
          pass
 
-     def __call__(self, action_value) -> float:
+     def __call__(self, action_value: Any) -> float:
          return self.reward(action_value)
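As a hedged illustration of this (otherwise unchanged) contract, a concrete subclass only needs to implement reward; the coin-flip payout below is made up:

    from pyrlutils.bandit.reward import IndividualBanditRewardFunction

    class CoinFlipReward(IndividualBanditRewardFunction):
        # made-up payout: 1.0 for 'heads', 0.0 for anything else
        def reward(self, action_value) -> float:
            return 1.0 if action_value == 'heads' else 0.0

    reward_fcn = CoinFlipReward()
    print(reward_fcn('heads'))   # __call__ delegates to reward -> 1.0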
pyrlutils/valuefcns.py → pyrlutils/dp/valuefcns.py RENAMED
@@ -1,18 +1,23 @@
 
  import random
  from copy import copy
- from typing import Tuple, Dict
  from itertools import product
+ from typing import Annotated
 
  import numpy as np
+ from numpy.typing import NDArray
 
- from .state import DiscreteStateValueType
- from .transition import TransitionProbabilityFactory
- from .policy import DiscreteDeterminsticPolicy
+ from ..state import DiscreteStateValueType
+ from ..transition import TransitionProbabilityFactory
+ from ..policy import DiscreteDeterminsticPolicy
 
 
  class OptimalPolicyOnValueFunctions:
-     def __init__(self, discount_factor: float, transprobfac: TransitionProbabilityFactory):
+     def __init__(
+         self,
+         discount_factor: float,
+         transprobfac: TransitionProbabilityFactory
+     ):
          try:
              assert 0. <= discount_factor <= 1.
          except AssertionError:
@@ -31,7 +36,7 @@ class OptimalPolicyOnValueFunctions:
          self._theta = 1e-10
          self._policy_evaluation_maxiter = 10000
 
-     def _policy_evaluation(self, policy: DiscreteDeterminsticPolicy) -> np.ndarray:
+     def _policy_evaluation(self, policy: DiscreteDeterminsticPolicy) -> Annotated[NDArray[np.float64], "1D Array"]:
          prev_V = np.zeros(len(self._states_to_indices))
 
          for _ in range(self._policy_evaluation_maxiter):
@@ -55,7 +60,7 @@ class OptimalPolicyOnValueFunctions:
 
          return V
 
-     def _policy_improvement(self, V: np.ndarray) -> DiscreteDeterminsticPolicy:
+     def _policy_improvement(self, V: Annotated[NDArray[np.float64], "1D Array"]) -> DiscreteDeterminsticPolicy:
          Q = np.zeros((len(self._states_to_indices), len(self._actions_to_indices)))
 
          for state_value in self._state_names:
@@ -78,7 +83,7 @@ class OptimalPolicyOnValueFunctions:
              optimal_policy.add_deterministic_rule(state_value, action_value)
          return optimal_policy
 
-     def _policy_iteration(self) -> Tuple[np.ndarray, DiscreteDeterminsticPolicy]:
+     def _policy_iteration(self) -> tuple[Annotated[NDArray[np.float64], "1D Array"], DiscreteDeterminsticPolicy]:
          policy = DiscreteDeterminsticPolicy(self._actions_dict)
          for state_value in self._state_names:
              policy.add_deterministic_rule(state_value, random.choice(self._action_names))
@@ -97,7 +102,7 @@ class OptimalPolicyOnValueFunctions:
          return V, policy
 
 
-     def _value_iteration(self) -> Tuple[np.ndarray, DiscreteDeterminsticPolicy]:
+     def _value_iteration(self) -> tuple[Annotated[NDArray[np.float64], "1D Array"], DiscreteDeterminsticPolicy]:
          V = np.zeros(len(self._state_names))
 
          for _ in range(self._policy_evaluation_maxiter):
@@ -127,7 +132,7 @@ class OptimalPolicyOnValueFunctions:
 
          return V, policy
 
-     def policy_iteration(self) -> Tuple[Dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
+     def policy_iteration(self) -> tuple[dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
          V, policy = self._policy_iteration()
          state_values_dict = {
              self._state_names[i]: V[i]
@@ -135,7 +140,7 @@ class OptimalPolicyOnValueFunctions:
          }
          return state_values_dict, policy
 
-     def value_iteration(self) -> Tuple[Dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
+     def value_iteration(self) -> tuple[dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
          V, policy = self._value_iteration()
          state_values_dict = {
              self._state_names[i]: V[i]
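Since the module now lives under pyrlutils.dp, imports have to follow the new path. A hedged sketch of driving policy iteration end to end, with a made-up two-state MDP (state 0 and state 1, one action 'wait'); the labels, probabilities, and rewards are invented:

    from pyrlutils.transition import TransitionProbabilityFactory, NextStateTuple
    from pyrlutils.dp.valuefcns import OptimalPolicyOnValueFunctions

    factory = TransitionProbabilityFactory()
    factory.add_state_transitions(0, {'wait': [NextStateTuple(0, 0.7, 0.0, False),
                                               NextStateTuple(1, 0.3, 1.0, False)]})
    factory.add_state_transitions(1, {'wait': [NextStateTuple(1, 1.0, 1.0, True)]})

    solver = OptimalPolicyOnValueFunctions(discount_factor=0.9, transprobfac=factory)
    state_values, optimal_policy = solver.policy_iteration()
    print(state_values)                         # {0: ..., 1: ...}
    print(optimal_policy.get_action_value(0))   # 'wait'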
pyrlutils/helpers/exceptions.py ADDED
@@ -0,0 +1,5 @@
+
+ class InvalidRangeError(Exception):
+     def __init__(self, message=None):
+         self.message = "Invalid range error!" if message is None else message
+         super().__init__(self.message)
pyrlutils/openai/utils.py CHANGED
@@ -5,7 +5,7 @@ from ..transition import TransitionProbabilityFactory, NextStateTuple
 
 
  class OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory(TransitionProbabilityFactory):
-     def __init__(self, envname):
+     def __init__(self, envname: str):
          super().__init__()
          self._envname = envname
          self._gymenv = gym.make(envname)
@@ -23,9 +23,9 @@ class OpenAIGymDiscreteEnvironmentTransitionProbabi
              self.add_state_transitions(state_value, new_trans_dict)
 
      @property
-     def envname(self):
+     def envname(self) -> str:
          return self._envname
 
      @property
-     def gymenv(self):
+     def gymenv(self) -> gym.Env:
          return self._gymenv
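A hedged sketch of the factory against a gymnasium environment; it needs the openaigym extra, and 'FrozenLake-v1' is only an example of a discrete environment, not something the diff itself prescribes:

    # pip install "pyrlutils[openaigym]"   (installs gymnasium)
    from pyrlutils.openai.utils import OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory

    factory = OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory('FrozenLake-v1')
    print(factory.envname)      # 'FrozenLake-v1'
    print(factory.gymenv)       # the underlying gym.Env, now reflected in the return annotation
    state, actions_dict, reward_fcn = factory.generate_mdp_objects()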
pyrlutils/policy.py CHANGED
@@ -1,9 +1,10 @@
 
  from abc import ABC, abstractmethod
- from typing import Union, Dict
+ from typing import Union, Annotated
  from warnings import warn
 
  import numpy as np
+ from numpy.typing import NDArray
 
  from .state import State, DiscreteState, DiscreteStateValueType
  from .action import Action, DiscreteActionValueType
@@ -12,7 +13,11 @@ from .action import Action, DiscreteActionValueType
  class Policy(ABC):
      @abstractmethod
      def get_action(self, state: State) -> Action:
-         pass
+         raise NotImplemented()
+
+     @abstractmethod
+     def get_action_value(self, state: State) -> DiscreteActionValueType:
+         raise NotImplemented()
 
      def __call__(self, state: State) -> Action:
          return self.get_action(state)
@@ -25,7 +30,7 @@ class Policy(ABC):
  class DeterministicPolicy(Policy):
      @abstractmethod
      def add_deterministic_rule(self, *args, **kwargs):
-         pass
+         raise NotImplemented()
 
      @property
      def is_stochastic(self) -> bool:
@@ -33,16 +38,23 @@ class DeterministicPolicy(Policy):
 
 
  class DiscreteDeterminsticPolicy(DeterministicPolicy):
-     def __init__(self, actions_dict: Dict[DiscreteActionValueType, Action]):
+     def __init__(self, actions_dict: dict[DiscreteActionValueType, Action]):
          self._state_to_action = {}
          self._actions_dict = actions_dict
 
-     def add_deterministic_rule(self, state_value: DiscreteStateValueType, action_value: DiscreteActionValueType):
+     def add_deterministic_rule(
+         self,
+         state_value: DiscreteStateValueType,
+         action_value: DiscreteActionValueType
+     ) -> None:
          if state_value in self._state_to_action:
              warn('State value {} exists in rule; it will be replaced.'.format(state_value))
          self._state_to_action[state_value] = action_value
 
-     def get_action_value(self, state_value: DiscreteStateValueType) -> DiscreteActionValueType:
+     def get_action_value(
+         self,
+         state_value: DiscreteStateValueType
+     ) -> DiscreteActionValueType:
          return self._state_to_action.get(state_value)
 
      def get_action(self, state: DiscreteState) -> Action:
@@ -62,10 +74,16 @@ class DiscreteDeterminsticPolicy(DeterministicPolicy):
          return True
 
 
+ class DiscreteContinuousPolicy(DeterministicPolicy):
+     @abstractmethod
+     def get_action(self, state: State) -> Action:
+         raise NotImplemented()
+
+
  class StochasticPolicy(Policy):
      @abstractmethod
      def get_probability(self, *args, **kwargs) -> float:
-         pass
+         raise NotImplemented()
 
      @property
      def is_stochastic(self) -> bool:
@@ -73,12 +91,61 @@ class StochasticPolicy(Policy):
 
 
  class DiscreteStochasticPolicy(StochasticPolicy):
-     @abstractmethod
-     def get_probability(self, state_value: DiscreteStateValueType, action_value: DiscreteActionValueType) -> float:
-         pass
+     def __init__(self, actions_dict: dict[DiscreteActionValueType, Action]):
+         self._state_to_action = {}
+         self._actions_dict = actions_dict
+
+     def add_stochastic_rule(
+         self,
+         state_value: DiscreteStateValueType,
+         action_values: list[DiscreteActionValueType],
+         probs: Union[list[float], Annotated[NDArray[np.float64], "1D Array"]] = None
+     ):
+         if probs is not None:
+             assert len(action_values) == len(probs)
+             probs = np.array(probs)
+         else:
+             probs = np.repeat(1./len(action_values), len(action_values))
+
+         if state_value in self._state_to_action:
+             warn('State value {} exists in rule; it will be replaced.'.format(state_value))
+         self._state_to_action[state_value] = {
+             action_value: prob
+             for action_value, prob in zip(action_values, probs)
+         }
+
+     def get_probability(
+         self,
+         state_value: DiscreteStateValueType,
+         action_value: DiscreteActionValueType
+     ) -> float:
+         if state_value not in self._state_to_action:
+             return 0.0
+         if action_value in self._state_to_action[state_value]:
+             return self._state_to_action[state_value][action_value]
+         else:
+             return 0.0
+
+     def get_action_value(self, state: State) -> DiscreteActionValueType:
+         allowed_actions = list(self._state_to_action[state].keys())
+         probs = np.array(list(self._state_to_action[state].values()))
+         sumprobs = np.sum(probs)
+         return np.random.choice(allowed_actions, p=probs/sumprobs)
+
+     def get_action(self, state: DiscreteState) -> Action:
+         return self._actions_dict[self.get_action_value(state.state_value)]
 
 
  class ContinuousStochasticPolicy(StochasticPolicy):
      @abstractmethod
-     def get_probability(self, state_value: Union[float, np.ndarray], action_value: DiscreteActionValueType, value: Union[float, np.ndarray]) -> float:
-         pass
+     def get_probability(
+         self,
+         state_value: Union[float, Annotated[NDArray[np.float64], "1D Array"]],
+         action_value: DiscreteActionValueType,
+         value: Union[float, Annotated[NDArray[np.float64], "1D Array"]]
+     ) -> float:
+         raise NotImplemented()
+
+
+ DiscretePolicy = Union[DiscreteDeterminsticPolicy, DiscreteStochasticPolicy]
+ ContinuousPolicy = Union[ContinuousStochasticPolicy]
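To make the newly concrete DiscreteStochasticPolicy tangible, here is a hedged sketch; the state/action names, the no-op lambdas, and the probabilities are all invented:

    import numpy as np
    from pyrlutils.action import Action
    from pyrlutils.policy import DiscreteStochasticPolicy

    # made-up actions keyed by action value; no-op lambdas stand in for real dynamics
    actions_dict = {'left': Action(lambda s: s), 'right': Action(lambda s: s)}

    policy = DiscreteStochasticPolicy(actions_dict)
    policy.add_stochastic_rule('start', ['left', 'right'], probs=[0.25, 0.75])

    print(policy.get_probability('start', 'right'))   # 0.75
    print(policy.get_action_value('start'))           # 'left' or 'right', sampled by weight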
pyrlutils/state.py CHANGED
@@ -1,81 +1,84 @@
 
- from abc import ABC, abstractmethod
+ import sys
+ from abc import ABC
  from enum import Enum
- from dataclasses import dataclass
- from typing import Tuple, List, Optional, Union
+ from typing import Optional, Union, Annotated, Literal
 
  import numpy as np
+ from numpy.typing import NDArray
+ if sys.version_info < (3, 11):
+     from typing_extensions import Self
+ else:
+     from typing import Self
 
-
- class StateValue(ABC):
-     @property
-     @abstractmethod
-     def value(self):
-         pass
-
-
- @dataclass
- class DiscreteStateValue(StateValue):
-     enum: Enum
-
-     @property
-     def value(self):
-         return self.enum.value
-
-     def name(self):
-         return self.enum.name
-
-
- class ContinuousStateValue(StateValue):
-     _value: float
-
-     @property
-     def value(self) -> float:
-         return self._value
+ from .helpers.exceptions import InvalidRangeError
 
 
  class State(ABC):
      @property
      def state_value(self):
-         return self.get_state_value()
-
-     @abstractmethod
-     def set_state_value(self, state_value):
-         pass
+         raise NotImplemented()
 
-     @abstractmethod
-     def get_state_value(self):
-         pass
-
-     @state_value.setter
-     def state_value(self, new_state_value):
-         self.set_state_value(new_state_value)
 
-
- DiscreteStateValueType = Union[float, str, Tuple[int], Enum]
+ DiscreteStateValueType = Union[str, int, tuple[int], Enum]
 
 
  class DiscreteState(State):
-     def __init__(self, all_state_values: List[DiscreteStateValueType], initial_values: Optional[List[DiscreteStateValueType]] = None):
+     def __init__(
+         self,
+         all_state_values: list[DiscreteStateValueType],
+         initial_value: Optional[DiscreteStateValueType] = None,
+         terminals: Optional[dict[DiscreteStateValueType, bool]]=None
+     ):
          super().__init__()
          self._all_state_values = all_state_values
-         self._state_value = initial_values if initial_values is not None and initial_values in self._all_state_values else self._all_state_values[0]
+         self._state_values_to_indices = {
+             state_value: idx
+             for idx, state_value in enumerate(self._all_state_values)
+         }
+         if initial_value is not None:
+             self._current_index = self._state_values_to_indices[initial_value]
+         else:
+             self._current_index = 0
+         if terminals is None:
+             self._terminal_dict = {
+                 state_value: False
+                 for state_value in self._all_state_values
+             }
+         else:
+             self._terminal_dict = terminals.copy()
+             for state_value in self._all_state_values:
+                 if self._terminal_dict.get(state_value) is None:
+                     self._terminal_dict[state_value] = False
+
+     def _get_state_value_from_index(self, index: int) -> DiscreteStateValueType:
+         return self._all_state_values[index]
 
      def get_state_value(self) -> DiscreteStateValueType:
-         return self._state_value
+         return self._get_state_value_from_index(self._current_index)
 
-     def set_state_value(self, state_value: DiscreteStateValueType):
+     def set_state_value(self, state_value: DiscreteStateValueType) -> None:
          if state_value in self._all_state_values:
-             self._state_value = state_value
+             self._current_index = self._state_values_to_indices[state_value]
          else:
              raise ValueError('State value {} is invalid.'.format(state_value))
 
-     def get_all_possible_state_values(self) -> List[DiscreteStateValueType]:
+     def get_all_possible_state_values(self) -> list[DiscreteStateValueType]:
          return self._all_state_values
 
+     @property
+     def state_index(self) -> int:
+         return self._current_index
+
+     @state_index.setter
+     def state_index(self, new_index: int) -> None:
+         if new_index >= len(self._all_state_values):
+             raise ValueError(f"Invalid index {new_index}; it must be less than {len(self._all_state_values)}.")
+         self._current_index = new_index
+
      @property
      def state_value(self) -> DiscreteStateValueType:
-         return self._state_value
+         return self._all_state_values[self._current_index]
 
      @state_value.setter
      def state_value(self, new_state_value: DiscreteStateValueType):
@@ -85,22 +88,53 @@ class DiscreteState(State):
      def state_space_size(self):
          return len(self._all_state_values)
 
+     @property
+     def nb_state_values(self) -> int:
+         return len(self._all_state_values)
+
+     @property
+     def is_terminal(self) -> bool:
+         return self._terminal_dict[self._all_state_values[self._current_index]]
 
- class InvalidRangeError(Exception):
-     def __init__(self, message=None):
-         self.message = "Invalid range error!" if message is None else message
-         super().__init__(self.message)
+     def __hash__(self):
+         return self._current_index
+
+     def __eq__(self, other: Self) -> bool:
+         return self._current_index == other._current_index
 
 
  class ContinuousState(State):
-     def __init__(self, nbdims: int, ranges: np.array, init_value: Optional[Union[float, np.ndarray]] = None):
+     def __init__(
+         self,
+         nbdims: int,
+         ranges: Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]],
+         init_value: Optional[Union[float, Annotated[NDArray[np.float64], "1D Array"]]] = None
+     ):
+         super().__init__()
          self._nbdims = nbdims
 
+         try:
+             assert isinstance(ranges, np.ndarray)
+         except AssertionError:
+             raise TypeError('Range must be a numpy array.')
+
          try:
              assert (ranges.dtype == np.float64) or (ranges.dtype == np.float32) or (ranges.dtype == np.float16)
          except AssertionError:
              raise TypeError('It has to be floating type numpy.ndarray.')
 
+         try:
+             assert ranges.ndim == 1 or ranges.ndim == 2
+             match ranges.ndim:
+                 case 1:
+                     assert ranges.shape[0] == 2
+                 case 2:
+                     assert ranges.shape[1] == 2
+                 case _:
+                     raise ValueError("Ranges must be of shape (2, ) or (*, 2).")
+         except AssertionError:
+             raise ValueError("Ranges must be of shape (2, ) or (*, 2).")
+
          try:
              assert self._nbdims > 0
          except AssertionError:
@@ -146,50 +180,53 @@ class ContinuousState(State):
                     raise ValueError('Initialized value does not have the right dimension.')
                 for i in range(self._nbdims):
                     try:
-                         assert (init_value[i] >= self._ranges[i, 0]) and (init_value[i] <= self.ranges[i, 1])
+                         assert self._ranges[i, 0] <= init_value[i] <= self.ranges[i, 1]
                     except AssertionError:
                         raise InvalidRangeError('Initialized value at dimension {} (value: {}) is not within the permitted range ({} -> {})!'.format(i, init_value[i], self._ranges[i, 0], self._ranges[i, 1]))
             else:
                 try:
-                     assert (init_value >= self._ranges[0, 0]) and (init_value <= self.ranges[0, 1])
+                     assert self._ranges[0, 0] <= init_value <= self.ranges[0, 1]
                 except AssertionError:
                     raise InvalidRangeError('Initialized value is out of range.')
         self._state_value = init_value
 
-     def set_state_value(self, state_value: Union[float, np.ndarray]):
-         if self.nbdims > 1:
+     def set_state_value(self, state_value: Union[float, Annotated[NDArray[np.float64], "1D Array"]]):
+         if self._nbdims > 1:
             try:
                 assert state_value.shape[0] == self._nbdims
             except AssertionError:
                 raise ValueError('Given value does not have the right dimension.')
-             for i in range(self.nbdims):
+             for i in range(self._nbdims):
                 try:
-                     assert state_value[i] >= self.ranges[i, 0] and state_value[i] <= self.ranges[i, 1]
+                     assert self.ranges[i, 0] <= state_value[i] <= self.ranges[i, 1]
                 except AssertionError:
                     raise InvalidRangeError()
         else:
            try:
-                 assert state_value >= self.ranges[0, 0] and state_value <= self.ranges[0, 1]
+                 assert self.ranges[0, 0] <= state_value <= self.ranges[0, 1]
            except AssertionError:
                raise InvalidRangeError()
 
         self._state_value = state_value
 
-     def get_state_value(self) -> np.ndarray:
+     def get_state_value(self) -> Annotated[NDArray[np.float64], "1D Array"]:
         return self._state_value
 
-     def get_state_value_ranges(self) -> np.ndarray:
+     def get_state_value_ranges(self) -> Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]]:
         return self._ranges
 
-     def get_state_value_range_at_dimension(self, dimension: int) -> np.ndarray:
-         return self._ranges[dimension]
+     def get_state_value_range_at_dimension(self, dimension: int) -> Annotated[NDArray[np.float64], Literal["2"]]:
+         if dimension < self._nbdims:
+             return self._ranges[dimension]
+         else:
+             raise ValueError(f"There are only {self._nbdims} dimensions!")
 
      @property
-     def ranges(self) -> np.ndarray:
+     def ranges(self) -> Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]]:
         return self.get_state_value_ranges()
 
      @property
-     def state_value(self) -> Union[float, np.ndarray]:
+     def state_value(self) -> Union[float, NDArray[np.float64]]:
         return self.get_state_value()
 
      @state_value.setter
@@ -200,9 +237,28 @@ class ContinuousState(State):
      def nbdims(self) -> int:
          return self._nbdims
 
+     def __hash__(self):
+         return hash(tuple(self._state_value))
+
+     def __eq__(self, other: Self):
+         if self.nbdims != other.nbdims:
+             raise ValueError(f"The two states have two different dimensions. ({self.nbdims} vs. {other.nbdims})")
+         for i in range(self.nbdims):
+             if self.state_value[i] != other.state_value[i]:
+                 return False
+         return True
+
 
  class Discrete2DCartesianState(DiscreteState):
-     def __init__(self, x_lowlim: int, x_hilim: int, y_lowlim: int, y_hilim: int, initial_coordinate: List[int]=None):
+     def __init__(
+         self,
+         x_lowlim: int,
+         x_hilim: int,
+         y_lowlim: int,
+         y_hilim: int,
+         initial_coordinate: list[int]=None,
+         terminals: Optional[dict[DiscreteStateValueType, bool]] = None
+     ):
          self._x_lowlim = x_lowlim
          self._x_hilim = x_hilim
          self._y_lowlim = y_lowlim
@@ -212,14 +268,50 @@ class Discrete2DCartesianState(DiscreteState):
          if initial_coordinate is None:
              initial_coordinate = [self._x_lowlim, self._y_lowlim]
          initial_value = (initial_coordinate[1] - self._y_lowlim) * self._countx + (initial_coordinate[0] - self._x_lowlim)
-         super().__init__(list(range(self._countx*self._county)), initial_values=initial_value)
+         super().__init__(list(range(self._countx*self._county)), initial_value=initial_value, terminals=terminals)
 
      def _encode_coordinates(self, x, y) -> int:
          return (y - self._y_lowlim) * self._countx + (x - self._x_lowlim)
 
-     def encode_coordinates(self, coordinates: List[int]) -> int:
-         assert len(coordinates) == 2
+     def encode_coordinates(self, coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]]) -> int:
+         if isinstance(coordinates, list):
+             assert len(coordinates) == 2
          return self._encode_coordinates(coordinates[0], coordinates[1])
 
-     def decode_coordinates(self, hashcode) -> List[int]:
-         return [hashcode % self._countx, hashcode // self._countx]
+     def decode_coordinates(self, hashcode) -> list[int]:
+         return [hashcode % self._countx + self._x_lowlim, hashcode // self._countx + self._y_lowlim]
+
+     def get_whether_terminal_given_coordinates(
+         self,
+         coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]]
+     ) -> bool:
+         if isinstance(coordinates, list):
+             assert len(coordinates) == 2
+         hashcode = self._encode_coordinates(coordinates[0], coordinates[1])
+         return self._terminal_dict.get(hashcode, False)
+
+     def set_terminal_given_coordinates(
+         self,
+         coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]],
+         terminal_value: bool
+     ) -> None:
+         if isinstance(coordinates, list):
+             assert len(coordinates) == 2
+         hashcode = self._encode_coordinates(coordinates[0], coordinates[1])
+         self._terminal_dict[hashcode] = terminal_value
+
+     @property
+     def x_lowlim(self) -> int:
+         return self._x_lowlim
+
+     @property
+     def x_hilim(self) -> int:
+         return self._x_hilim
+
+     @property
+     def y_lowlim(self) -> int:
+         return self._y_lowlim
+
+     @property
+     def y_hilim(self) -> int:
+         return self._y_hilim
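A brief, hedged sketch of the new terminal-state bookkeeping on the grid state; the grid bounds and the choice of terminal cell are made up, and the example assumes the grid counts cells between the given limits inclusively:

    from pyrlutils.state import Discrete2DCartesianState

    # made-up grid spanning x in [0, 2] and y in [0, 2], with one terminal cell at (1, 1)
    grid = Discrete2DCartesianState(0, 2, 0, 2)
    grid.set_terminal_given_coordinates([1, 1], True)

    print(grid.get_whether_terminal_given_coordinates([1, 1]))   # True
    print(grid.is_terminal)                                      # False: the state starts at (0, 0)

    grid.set_state_value(grid.encode_coordinates([1, 1]))
    print(grid.is_terminal)                                      # True
    print(grid.decode_coordinates(grid.state_value))             # [1, 1]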
pyrlutils/td/td.py ADDED
@@ -0,0 +1,101 @@
+
+ from typing import Annotated
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from .utils import decay_schedule, AbstractTemporalDifferenceLearner, TimeDifferencePathElements
+
+
+ class SingleStepTemporalDifferenceLearner(AbstractTemporalDifferenceLearner):
+     def learn(
+         self,
+         episodes: int
+     ) -> tuple[Annotated[NDArray[np.float64], "1D Array"], Annotated[NDArray[np.float64], "2D Array"]]:
+         V = np.zeros(self.nb_states)
+         V_track = np.zeros((episodes, self.nb_states))
+         alphas = decay_schedule(
+             self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
+         )
+
+         for i in range(episodes):
+             self._state.set_state_value(self.initial_state_index)
+             done = False
+             while not done:
+                 old_state_index = self._state.state_index
+                 old_state_value = self._state.state_value
+                 action_value = self._policy.get_action_value(self._state.state_value)
+                 action_func = self._actions_dict[action_value]
+                 self._state = action_func(self._state)
+                 new_state_index = self._state.state_index
+                 new_state_value = self._state.state_value
+                 reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
+                 done = self._state.is_terminal
+
+                 td_target = reward + self.gamma * V[new_state_index] * (not done)
+                 td_error = td_target - V[old_state_index]
+                 V[old_state_index] = V[old_state_index] + alphas[i] * td_error
+
+             V_track[i, :] = V
+
+         return V, V_track
+
+
+ class MultipleStepTemporalDifferenceLearner(AbstractTemporalDifferenceLearner):
+     def learn(
+         self,
+         episodes: int,
+         n_steps: int=3
+     ) -> tuple[Annotated[NDArray[np.float64], "1D Array"], Annotated[NDArray[np.float64], "2D Array"]]:
+         V = np.zeros(self.nb_states)
+         V_track = np.zeros((episodes, self.nb_states))
+         alphas = decay_schedule(
+             self.init_alpha, self.min_alpha, self.alpha_decay_ratio, episodes
+         )
+         discounts = np.logspace(0, n_steps-1, num=n_steps+1, base=self.gamma, endpoint=False)
+
+         for i in range(episodes):
+             self._state.set_state_value(self.initial_state_index)
+             done = False
+             path = []
+
+             while not done or path is not None:
+                 path = path[1:]  # worth revisiting this line
+
+                 next_state_index = -1
+                 while not done and len(path) < n_steps:
+                     old_state_index = self._state.state_index
+                     old_state_value = self._state.state_value
+                     action_value = self._policy.get_action_value(self._state.state_value)
+                     action_func = self._actions_dict[action_value]
+                     self._state = action_func(self._state)
+                     new_state_index = self._state.state_index
+                     new_state_value = self._state.state_value
+                     reward = self._indrewardfcn(old_state_value, action_value, new_state_value)
+                     done = self._state.is_terminal
+
+                     path.append(
+                         TimeDifferencePathElements(
+                             this_state_index=old_state_index,
+                             reward=reward,
+                             next_state_index=new_state_index,
+                             done=done
+                         )
+                     )
+                     if done:
+                         break
+
+                 n = len(path)
+                 estimated_state_index = path[0].this_state_index
+                 rewards = np.array([this_moment.reward for this_moment in path])
+                 partial_return = discounts[n:] * rewards
+                 bs_val = discounts[-1] * V[next_state_index] * (not done)
+                 ntd_target = np.sum(np.append(partial_return, bs_val))
+                 ntd_error = ntd_target - V[estimated_state_index]
+                 V[estimated_state_index] = V[estimated_state_index] + alphas[i] * ntd_error
+                 if len(path) == 1 and path[0].done:
+                     path = None
+
+             V_track[i, :] = V
+
+         return V, V_track
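A hedged end-to-end sketch of the new single-step TD learner; the two-state MDP, its probabilities, and its rewards are invented, and integer state values are used so that the learner's index-based reset works:

    from pyrlutils.transition import TransitionProbabilityFactory, NextStateTuple
    from pyrlutils.policy import DiscreteDeterminsticPolicy
    from pyrlutils.td.td import SingleStepTemporalDifferenceLearner

    # made-up MDP: state 0 may move to terminal state 1 under the single action 'wait'
    factory = TransitionProbabilityFactory()
    factory.add_state_transitions(0, {'wait': [NextStateTuple(0, 0.7, 0.0, False),
                                               NextStateTuple(1, 0.3, 1.0, False)]})
    factory.add_state_transitions(1, {'wait': [NextStateTuple(1, 1.0, 1.0, True)]})

    _, actions_dict, _ = factory.generate_mdp_objects()
    policy = DiscreteDeterminsticPolicy(actions_dict)
    policy.add_deterministic_rule(0, 'wait')
    policy.add_deterministic_rule(1, 'wait')

    learner = SingleStepTemporalDifferenceLearner(factory, gamma=0.9, policy=policy)
    V, V_track = learner.learn(episodes=100)   # V: state values; V_track: per-episode history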
pyrlutils/td/utils.py ADDED
@@ -0,0 +1,119 @@
+
+ from abc import ABC, abstractmethod
+ from typing import Optional, Annotated
+ from dataclasses import dataclass
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from ..policy import DiscretePolicy
+ from ..transition import TransitionProbabilityFactory
+
+
+ def decay_schedule(
+     init_value: float,
+     min_value: float,
+     decay_ratio: float,
+     max_steps: int,
+     log_start: int=-2,
+     log_base: int=10
+ ) -> Annotated[NDArray[np.float64], "1D Array"]:
+     decay_steps = int(max_steps*decay_ratio)
+     rem_steps = max_steps - decay_steps
+
+     values = np.logspace(log_start, 0, decay_steps, base=log_base, endpoint=True)[::-1]
+     values = (values - values.min()) / (values.max() - values.min())
+     values = (init_value - min_value) * values + min_value
+     values = np.pad(values, (0, rem_steps), 'edge')
+     return values
+
+
+ class AbstractTemporalDifferenceLearner(ABC):
+     def __init__(
+         self,
+         transprobfac: TransitionProbabilityFactory,
+         gamma: float=1.0,
+         init_alpha: float=0.5,
+         min_alpha: float=0.01,
+         alpha_decay_ratio: float=0.3,
+         policy: Optional[DiscretePolicy]=None,
+         initial_state_index: int=0
+     ):
+         self._gamma = gamma
+         self._init_alpha = init_alpha
+         self._min_alpha = min_alpha
+         try:
+             assert 0.0 <= alpha_decay_ratio <= 1.0
+         except AssertionError:
+             raise ValueError("alpha_decay_ratio must be between 0 and 1!")
+         self._alpha_decay_ratio = alpha_decay_ratio
+         self._transprobfac = transprobfac
+         self._state, self._actions_dict, self._indrewardfcn = self._transprobfac.generate_mdp_objects()
+         self._action_names = list(self._actions_dict.keys())
+         self._actions_to_indices = {action_value: idx for idx, action_value in enumerate(self._action_names)}
+         self._policy = policy
+         try:
+             assert 0 <= initial_state_index < self._state.nb_state_values
+         except AssertionError:
+             raise ValueError("Initial state index must be between 0 and {}".format(len(self._state_names)))
+         self._init_state_index = initial_state_index
+
+     @abstractmethod
+     def learn(self, *args, **kwargs) -> tuple[Annotated[NDArray[np.float64], "1D Array"], Annotated[NDArray[np.float64], "2D Array"]]:
+         raise NotImplementedError()
+
+     @property
+     def nb_states(self) -> int:
+         return self._state.nb_state_values
+
+     @property
+     def policy(self) -> DiscretePolicy:
+         return self._policy
+
+     @policy.setter
+     def policy(self, val: DiscretePolicy):
+         self._policy = val
+
+     @property
+     def gamma(self) -> float:
+         return self._gamma
+
+     @gamma.setter
+     def gamma(self, val: float):
+         self._gamma = val
+
+     @property
+     def init_alpha(self) -> float:
+         return self._init_alpha
+
+     @init_alpha.setter
+     def init_alpha(self, val: float):
+         self._init_alpha = val
+
+     @property
+     def min_alpha(self) -> float:
+         return self._min_alpha
+
+     @min_alpha.setter
+     def min_alpha(self, val: float):
+         self._min_alpha = val
+
+     @property
+     def alpha_decay_ratio(self) -> float:
+         return self._alpha_decay_ratio
+
+     @property
+     def initial_state_index(self) -> int:
+         return self._init_state_index
+
+     @initial_state_index.setter
+     def initial_state_index(self, val: int):
+         self._init_state_index = val
+
+
+ @dataclass
+ class TimeDifferencePathElements:
+     this_state_index: int
+     reward: float
+     next_state_index: int
+     done: bool
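decay_schedule is self-contained, so its behaviour is easy to check in isolation; the numbers in this sketch are just an example:

    import numpy as np
    from pyrlutils.td.utils import decay_schedule

    # decay from 0.5 down to 0.01 over the first 30% of 10 steps, then hold flat
    alphas = decay_schedule(init_value=0.5, min_value=0.01, decay_ratio=0.3, max_steps=10)
    print(alphas.shape)         # (10,)
    print(np.round(alphas, 3))  # decreasing over 3 steps, then padded with the final value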
pyrlutils/transition.py CHANGED
@@ -1,6 +1,7 @@
 
- from types import LambdaType
- from typing import Tuple, Dict
+ from types import LambdaType, FunctionType
+ from typing import Union
+ from dataclasses import dataclass
 
  import numpy as np
 
@@ -9,28 +10,12 @@ from .reward import IndividualRewardFunction
  from .action import Action, DiscreteActionValueType
 
 
+ @dataclass
  class NextStateTuple:
-     def __init__(self, next_state_value: DiscreteStateValueType, probability: float, reward: float, terminal: bool):
-         self._next_state_value = next_state_value
-         self._probability = probability
-         self._reward = reward
-         self._terminal = terminal
-
-     @property
-     def next_state_value(self) -> DiscreteStateValueType:
-         return self._next_state_value
-
-     @property
-     def probability(self) -> float:
-         return self._probability
-
-     @property
-     def reward(self) -> float:
-         return self._reward
-
-     @property
-     def terminal(self) -> bool:
-         return self._terminal
+     next_state_value: DiscreteStateValueType
+     probability: float
+     reward: float
+     terminal: bool
 
 
  class TransitionProbabilityFactory:
@@ -40,7 +25,11 @@ class TransitionProbabilityFactory:
          self._all_action_values = []
          self._objects_generated = False
 
-     def add_state_transitions(self, state_value: DiscreteStateValueType, action_values_to_next_state: dict):
+     def add_state_transitions(
+         self,
+         state_value: DiscreteStateValueType,
+         action_values_to_next_state: dict[DiscreteActionValueType, Union[list[NextStateTuple], dict]]
+     ):
          if state_value not in self._all_state_values:
              self._all_state_values.append(state_value)
 
@@ -69,7 +58,10 @@ class TransitionProbabilityFactory:
 
          self._transprobs[state_value] = this_state_transition_dict
 
-     def _get_probs_for_eachstate(self, action_value: DiscreteActionValueType) -> Dict[DiscreteStateValueType, NextStateTuple]:
+     def _get_probs_for_eachstate(
+         self,
+         action_value: DiscreteActionValueType
+     ) -> dict[DiscreteStateValueType, list[NextStateTuple]]:
          state_nexttuples = {}
          for state_value, action_nexttuples_pair in self._transprobs.items():
              for this_action_value, nexttuples in action_nexttuples_pair.items():
@@ -77,7 +69,10 @@ class TransitionProbabilityFactory:
                      state_nexttuples[state_value] = nexttuples
          return state_nexttuples
 
-     def _generate_action_function(self, state_nexttuples: dict) -> LambdaType:
+     def _generate_action_function(
+         self,
+         state_nexttuples: dict[DiscreteStateValueType, list[NextStateTuple]]
+     ) -> Union[FunctionType, LambdaType]:
 
          def _action_function(state: DiscreteState) -> DiscreteState:
              nexttuples = state_nexttuples[state.state_value]
@@ -91,7 +86,11 @@ class TransitionProbabilityFactory:
 
      def _generate_individual_reward_function(self) -> IndividualRewardFunction:
 
-         def _individual_reward_function(state_value, action_value, next_state_value) -> float:
+         def _individual_reward_function(
+             state_value: DiscreteStateValueType,
+             action_value: DiscreteActionValueType,
+             next_state_value: DiscreteStateValueType
+         ) -> float:
              if state_value not in self._transprobs.keys():
                  return 0.
 
@@ -105,15 +104,22 @@ class TransitionProbabilityFactory:
              return reward
 
          class ThisIndividualRewardFunction(IndividualRewardFunction):
-             def __init__(self):
-                 super().__init__()
-
-             def reward(self, state_value, action_value, next_state_value) -> float:
+             def reward(
+                 self,
+                 state_value: DiscreteStateValueType,
+                 action_value: DiscreteActionValueType,
+                 next_state_value: DiscreteStateValueType
+             ) -> float:
                  return _individual_reward_function(state_value, action_value, next_state_value)
 
          return ThisIndividualRewardFunction()
 
-     def get_probability(self, state_value, action_value, new_state_value) -> float:
+     def get_probability(
+         self,
+         state_value: DiscreteStateValueType,
+         action_value: DiscreteActionValueType,
+         new_state_value: DiscreteStateValueType
+     ) -> float:
          if state_value not in self._transprobs.keys():
              return 0.
 
@@ -127,18 +133,21 @@ class TransitionProbabilityFactory:
          return probs
 
      @property
-     def transition_probabilities(self) -> dict:
+     def transition_probabilities(self) -> dict[DiscreteStateValueType, dict[DiscreteActionValueType, list[NextStateTuple]]]:
          return self._transprobs
 
-     def generate_mdp_objects(self) -> Tuple[DiscreteState, Dict[DiscreteActionValueType, Action], IndividualRewardFunction]:
+     def generate_mdp_objects(self) -> tuple[DiscreteState, dict[DiscreteActionValueType, Action], IndividualRewardFunction]:
          state = DiscreteState(self._all_state_values)
          actions_dict = {}
          for action_value in self._all_action_values:
              state_nexttuple = self._get_probs_for_eachstate(action_value)
             actions_dict[action_value] = Action(self._generate_action_function(state_nexttuple))
+             for next_tuples in state_nexttuple.values():
+                 for next_tuple in next_tuples:
+                     state._terminal_dict[next_tuple.next_state_value] = next_tuple.terminal
 
          individual_reward_fcn = self._generate_individual_reward_function()
-
+         self._objects_generated = True
          return state, actions_dict, individual_reward_fcn
 
      @property
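Since NextStateTuple is now a dataclass, it can be built positionally or by keyword and compares by value; a hedged sketch, with invented state/action labels:

    from pyrlutils.transition import TransitionProbabilityFactory, NextStateTuple

    t = NextStateTuple(next_state_value=1, probability=1.0, reward=0.0, terminal=True)
    print(t.reward, t.terminal)                      # 0.0 True
    print(t == NextStateTuple(1, 1.0, 0.0, True))    # True: dataclass equality

    factory = TransitionProbabilityFactory()
    factory.add_state_transitions(0, {'stay': [NextStateTuple(0, 0.5, 0.0, False),
                                               NextStateTuple(1, 0.5, 1.0, True)]})
    state, actions_dict, reward_fcn = factory.generate_mdp_objects()
    print(state.is_terminal)                         # False: initial state 0 is non-terminal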
pyrlutils-0.0.4.dist-info/METADATA → pyrlutils-0.1.0.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: pyrlutils
- Version: 0.0.4
+ Version: 0.1.0
  Summary: Utility and Helpers for Reinformcement Learning
  Author-email: Kwan Yuet Stephen Ho <stephenhky@yahoo.com.hk>
  License: MIT
@@ -11,22 +11,22 @@ Classifier: Topic :: Scientific/Engineering :: Mathematics
  Classifier: License :: OSI Approved :: MIT License
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
  Classifier: Topic :: Software Development :: Version Control :: Git
- Classifier: Programming Language :: Python :: 3.7
- Classifier: Programming Language :: Python :: 3.8
- Classifier: Programming Language :: Python :: 3.9
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
  Classifier: Intended Audience :: Science/Research
  Classifier: Intended Audience :: Developers
- Requires-Python: >=3.7
+ Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: numpy
+ Requires-Dist: typing-extensions
  Provides-Extra: openaigym
  Requires-Dist: gymnasium; extra == "openaigym"
  Provides-Extra: test
  Requires-Dist: unittest; extra == "test"
+ Dynamic: license-file
 
  # PyRLUtils
 
pyrlutils-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,23 @@
+ pyrlutils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ pyrlutils/action.py,sha256=QoBdtcGtK_EkYAjb50bruhoB_XIz0agLpQjdGFnGbRQ,732
+ pyrlutils/policy.py,sha256=A9bj2eVd6XjNNkClSYVJDoxoGuGkyoYVr1DpVdI0wzs,5120
+ pyrlutils/reward.py,sha256=are0swsobMqI1IbrBVBaPMYXWpJnp6lZwAyfgBEm2zg,1211
+ pyrlutils/state.py,sha256=A3XJSjNJrsInXUWsUvb1GE7Oq-CY6DNEB-ulrVa1rR4,11774
+ pyrlutils/transition.py,sha256=_32jxeYbsiKyaHR9Y2XceUQYbb1jslLCQO2AWL61_EU,6260
+ pyrlutils/bandit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ pyrlutils/bandit/algo.py,sha256=X2Pn4DOi-RXWz5CNg1h0RJCoV3VlAwEGHRMjkfbckfw,3969
+ pyrlutils/bandit/reward.py,sha256=l2H_gZk2qqDxZioHe1M28pD8N47fgSR-K0Q6muchVd0,282
+ pyrlutils/dp/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ pyrlutils/dp/valuefcns.py,sha256=0T7vzdKRIKhLMsaq7JgPqONMmq4lWRGPj7xPtuxVtbE,6546
+ pyrlutils/helpers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ pyrlutils/helpers/exceptions.py,sha256=4fPGW839BChfap-Gd7b-75Dz-Ed3foqbJQ1lg15TZ-4,192
+ pyrlutils/openai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ pyrlutils/openai/utils.py,sha256=PJc9WHZM8aM4Z9MlACUxUC8TO7VARp8taatba_ikhew,1056
+ pyrlutils/td/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ pyrlutils/td/td.py,sha256=EnecL84yyUm7rO2idaHgVfvtWW5LYPxEkefHhI1SVPQ,4269
+ pyrlutils/td/utils.py,sha256=PALXGaDLd3PjFh8qDV9DY_MkaBuj3_GpfVWJOb424vE,3571
+ pyrlutils-0.1.0.dist-info/licenses/LICENSE,sha256=bnQPjIcaeBdr2ZofX-_j-nELs8pAx5fQ4Cdfgeaspew,1063
+ pyrlutils-0.1.0.dist-info/METADATA,sha256=qKVydib9iWVw-NXgMnB3y0JtDibVQcvclyc7zP2PYH0,2185
+ pyrlutils-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ pyrlutils-0.1.0.dist-info/top_level.txt,sha256=gOBuxugE2MA4WDXlLhzkQh_rUonZU6nvJnMuomeHMCU,10
+ pyrlutils-0.1.0.dist-info/RECORD,,
pyrlutils-0.0.4.dist-info/WHEEL → pyrlutils-0.1.0.dist-info/WHEEL RENAMED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.8.0)
+ Generator: setuptools (80.9.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
pyrlutils-0.0.4.dist-info/RECORD DELETED
@@ -1,17 +0,0 @@
- pyrlutils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- pyrlutils/action.py,sha256=2kJqNZxsLOV8yOTl-RcpM8b0zu-WXNREJCrl49uZi2c,437
- pyrlutils/policy.py,sha256=Cx4vsIXzFZi_KEgI06S378Y5E6g-AfK90skDYoGsfOI,2794
- pyrlutils/reward.py,sha256=are0swsobMqI1IbrBVBaPMYXWpJnp6lZwAyfgBEm2zg,1211
- pyrlutils/state.py,sha256=w0YJ50FUyNboPoYduLMX1xaBJJHAOaSlsr3Og1dd0dY,7840
- pyrlutils/transition.py,sha256=lgh4YfOi-YjSIyymWfrXe-ugDWpZYK3MvjdeehgcQhk,5816
- pyrlutils/valuefcns.py,sha256=CJxu0EIFgrdbP0n0x6nzs3X08accFsuJW71tv1rMTkQ,6342
- pyrlutils/bandit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- pyrlutils/bandit/algo.py,sha256=X2Pn4DOi-RXWz5CNg1h0RJCoV3VlAwEGHRMjkfbckfw,3969
- pyrlutils/bandit/reward.py,sha256=S_uECjMOg3cmK24J-5uPcckLvtxmU4yllR7JEvMwAQE,249
- pyrlutils/openai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- pyrlutils/openai/utils.py,sha256=ehj1cGlDYjQLno3pKMCS3CzZwbZGSTmjxDlU07aSBFo,1033
- pyrlutils-0.0.4.dist-info/LICENSE,sha256=bnQPjIcaeBdr2ZofX-_j-nELs8pAx5fQ4Cdfgeaspew,1063
- pyrlutils-0.0.4.dist-info/METADATA,sha256=7ncLjVrpqIZpdMFMrRjqRNgfZl9LUKc_SZFkw_CoTFc,2228
- pyrlutils-0.0.4.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
- pyrlutils-0.0.4.dist-info/top_level.txt,sha256=gOBuxugE2MA4WDXlLhzkQh_rUonZU6nvJnMuomeHMCU,10
- pyrlutils-0.0.4.dist-info/RECORD,,