pyrlutils 0.0.4__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of pyrlutils was flagged as possibly problematic.

pyrlutils/action.py CHANGED
@@ -1,5 +1,5 @@
 
-from types import LambdaType
+from types import LambdaType, FunctionType
 from typing import Union
 
 from .state import State
@@ -8,7 +8,7 @@ from .state import State
 DiscreteActionValueType = Union[float, str]
 
 class Action:
-    def __init__(self, actionfunc: LambdaType):
+    def __init__(self, actionfunc: Union[FunctionType, LambdaType]):
         self._actionfunc = actionfunc
 
     def act(self, state: State, *args, **kwargs) -> State:
@@ -17,3 +17,11 @@ class Action:
 
     def __call__(self, state: State) -> State:
         return self.act(state)
+
+    @property
+    def action_function(self) -> Union[FunctionType, LambdaType]:
+        return self._actionfunc
+
+    @action_function.setter
+    def action_function(self, new_func: Union[FunctionType, LambdaType]) -> None:
+        self._actionfunc = new_func
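
The functional change here is twofold: `Action.__init__` now accepts a plain named function as well as a lambda, and the wrapped callable is exposed through a read/write `action_function` property. A minimal usage sketch, assuming a hypothetical two-value toggle state (the state values and the `turn_on` function are illustrative, not part of the package):

    from pyrlutils.action import Action
    from pyrlutils.state import DiscreteState

    state = DiscreteState(['off', 'on'])    # hypothetical two-value state

    def turn_on(state, *args, **kwargs):
        state.set_state_value('on')
        return state

    action = Action(turn_on)    # a named function now satisfies the type hint
    action(state)               # __call__ delegates to act()

    # 0.1.1 lets the wrapped callable be inspected and swapped after construction
    print(action.action_function)
    action.action_function = lambda state: state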
@@ -1,11 +1,12 @@
 
 from abc import ABC, abstractmethod
+from typing import Any
 
 
 class IndividualBanditRewardFunction(ABC):
     @abstractmethod
-    def reward(self, action_value) -> float:
+    def reward(self, action_value: Any) -> float:
         pass
 
-    def __call__(self, action_value) -> float:
+    def __call__(self, action_value: Any) -> float:
         return self.reward(action_value)
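
Since `IndividualBanditRewardFunction` only requires `reward` to be implemented, a concrete subclass stays small. A sketch with an invented Bernoulli payout (the subclass is illustrative, not part of the package):

    import random

    class BernoulliRewardFunction(IndividualBanditRewardFunction):
        # pays out 1.0 with probability p, otherwise 0.0
        def __init__(self, p: float):
            self.p = p

        def reward(self, action_value) -> float:
            return 1.0 if random.random() < self.p else 0.0

    arm = BernoulliRewardFunction(0.3)
    payout = arm('pull')    # __call__ delegates to reward()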
@@ -1,18 +1,23 @@
 
 import random
 from copy import copy
-from typing import Tuple, Dict
 from itertools import product
+from typing import Annotated
 
 import numpy as np
+from numpy.typing import NDArray
 
-from .state import DiscreteStateValueType
-from .transition import TransitionProbabilityFactory
-from .policy import DiscreteDeterminsticPolicy
+from ..state import DiscreteStateValueType
+from ..transition import TransitionProbabilityFactory
+from ..policy import DiscreteDeterminsticPolicy
 
 
 class OptimalPolicyOnValueFunctions:
-    def __init__(self, discount_factor: float, transprobfac: TransitionProbabilityFactory):
+    def __init__(
+        self,
+        discount_factor: float,
+        transprobfac: TransitionProbabilityFactory
+    ):
         try:
             assert 0. <= discount_factor <= 1.
         except AssertionError:
@@ -31,7 +36,7 @@ class OptimalPolicyOnValueFunctions:
         self._theta = 1e-10
         self._policy_evaluation_maxiter = 10000
 
-    def _policy_evaluation(self, policy: DiscreteDeterminsticPolicy) -> np.ndarray:
+    def _policy_evaluation(self, policy: DiscreteDeterminsticPolicy) -> Annotated[NDArray[np.float64], "1D Array"]:
         prev_V = np.zeros(len(self._states_to_indices))
 
         for _ in range(self._policy_evaluation_maxiter):
@@ -55,7 +60,7 @@ class OptimalPolicyOnValueFunctions:
 
         return V
 
-    def _policy_improvement(self, V: np.ndarray) -> DiscreteDeterminsticPolicy:
+    def _policy_improvement(self, V: Annotated[NDArray[np.float64], "1D Array"]) -> DiscreteDeterminsticPolicy:
         Q = np.zeros((len(self._states_to_indices), len(self._actions_to_indices)))
 
         for state_value in self._state_names:
@@ -78,7 +83,7 @@ class OptimalPolicyOnValueFunctions:
             optimal_policy.add_deterministic_rule(state_value, action_value)
         return optimal_policy
 
-    def _policy_iteration(self) -> Tuple[np.ndarray, DiscreteDeterminsticPolicy]:
+    def _policy_iteration(self) -> tuple[Annotated[NDArray[np.float64], "1D Array"], DiscreteDeterminsticPolicy]:
         policy = DiscreteDeterminsticPolicy(self._actions_dict)
         for state_value in self._state_names:
             policy.add_deterministic_rule(state_value, random.choice(self._action_names))
@@ -97,7 +102,7 @@ class OptimalPolicyOnValueFunctions:
         return V, policy
 
 
-    def _value_iteration(self) -> Tuple[np.ndarray, DiscreteDeterminsticPolicy]:
+    def _value_iteration(self) -> tuple[Annotated[NDArray[np.float64], "1D Array"], DiscreteDeterminsticPolicy]:
        V = np.zeros(len(self._state_names))
 
        for _ in range(self._policy_evaluation_maxiter):
@@ -127,7 +132,7 @@ class OptimalPolicyOnValueFunctions:
 
         return V, policy
 
-    def policy_iteration(self) -> Tuple[Dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
+    def policy_iteration(self) -> tuple[dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
         V, policy = self._policy_iteration()
         state_values_dict = {
             self._state_names[i]: V[i]
@@ -135,7 +140,7 @@ class OptimalPolicyOnValueFunctions:
         }
         return state_values_dict, policy
 
-    def value_iteration(self) -> Tuple[Dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
+    def value_iteration(self) -> tuple[dict[DiscreteStateValueType, float], DiscreteDeterminsticPolicy]:
         V, policy = self._value_iteration()
         state_values_dict = {
             self._state_names[i]: V[i]
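
Throughout this release, `typing.Tuple`/`typing.Dict` give way to the builtin generics (Python 3.9+), and bare `np.ndarray` annotations become `Annotated[NDArray[np.float64], "1D Array"]`. The `Annotated` wrapper type-checks exactly like `NDArray[np.float64]`; the string is documentation-only metadata carried alongside the type. A self-contained sketch of the pattern:

    from typing import Annotated, get_type_hints

    import numpy as np
    from numpy.typing import NDArray

    # checked as NDArray[np.float64]; "1D Array" is inert shape documentation
    def make_values(n: int) -> Annotated[NDArray[np.float64], "1D Array"]:
        return np.zeros(n)

    v = make_values(4)
    # the metadata is recoverable at runtime if requested explicitly
    print(get_type_hints(make_values, include_extras=True)['return'])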
@@ -0,0 +1,5 @@
+
+class InvalidRangeError(Exception):
+    def __init__(self, message=None):
+        self.message = "Invalid range error!" if message is None else message
+        super().__init__(self.message)
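
This new exception backs the `from .helpers.exceptions import InvalidRangeError` import that appears in the state.py diff below, replacing the class previously defined inline in state.py. It takes an optional custom message:

    try:
        raise InvalidRangeError()    # falls back to the default message
    except InvalidRangeError as e:
        print(e)                     # "Invalid range error!"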
pyrlutils/openai/utils.py CHANGED
@@ -5,7 +5,7 @@ from ..transition import TransitionProbabilityFactory, NextStateTuple
 
 
 class OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory(TransitionProbabilityFactory):
-    def __init__(self, envname):
+    def __init__(self, envname: str):
         super().__init__()
         self._envname = envname
         self._gymenv = gym.make(envname)
@@ -23,9 +23,9 @@ class OpenAIGymDiscreteEnvironmentTransitionProbabi
         self.add_state_transitions(state_value, new_trans_dict)
 
     @property
-    def envname(self):
+    def envname(self) -> str:
         return self._envname
 
     @property
-    def gymenv(self):
+    def gymenv(self) -> gym.Env:
         return self._gymenv
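
Only type hints change here: `envname` is declared a `str` and `gymenv` a `gym.Env`. A sketch of the intended call pattern, assuming a discrete Gym environment such as FrozenLake is installed (the environment name is an example, not mandated by the package):

    factory = OpenAIGymDiscreteEnvironmentTransitionProbabilityFactory('FrozenLake-v1')
    print(factory.envname)    # 'FrozenLake-v1'
    env = factory.gymenv      # the underlying gym.Env, now typed as such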
pyrlutils/policy.py CHANGED
@@ -1,9 +1,10 @@
 
 from abc import ABC, abstractmethod
-from typing import Union, Dict
+from typing import Union, Annotated
 from warnings import warn
 
 import numpy as np
+from numpy.typing import NDArray
 
 from .state import State, DiscreteState, DiscreteStateValueType
 from .action import Action, DiscreteActionValueType
@@ -12,7 +13,11 @@ from .action import Action, DiscreteActionValueType
 class Policy(ABC):
     @abstractmethod
     def get_action(self, state: State) -> Action:
-        pass
+        raise NotImplemented()
+
+    @abstractmethod
+    def get_action_value(self, state: State) -> DiscreteActionValueType:
+        raise NotImplemented()
 
     def __call__(self, state: State) -> Action:
         return self.get_action(state)
@@ -25,7 +30,7 @@ class Policy(ABC):
 class DeterministicPolicy(Policy):
     @abstractmethod
     def add_deterministic_rule(self, *args, **kwargs):
-        pass
+        raise NotImplemented()
 
     @property
     def is_stochastic(self) -> bool:
@@ -33,16 +38,23 @@ class DeterministicPolicy(Policy):
 
 
 class DiscreteDeterminsticPolicy(DeterministicPolicy):
-    def __init__(self, actions_dict: Dict[DiscreteActionValueType, Action]):
+    def __init__(self, actions_dict: dict[DiscreteActionValueType, Action]):
         self._state_to_action = {}
         self._actions_dict = actions_dict
 
-    def add_deterministic_rule(self, state_value: DiscreteStateValueType, action_value: DiscreteActionValueType):
+    def add_deterministic_rule(
+        self,
+        state_value: DiscreteStateValueType,
+        action_value: DiscreteActionValueType
+    ) -> None:
         if state_value in self._state_to_action:
             warn('State value {} exists in rule; it will be replaced.'.format(state_value))
         self._state_to_action[state_value] = action_value
 
-    def get_action_value(self, state_value: DiscreteStateValueType) -> DiscreteActionValueType:
+    def get_action_value(
+        self,
+        state_value: DiscreteStateValueType
+    ) -> DiscreteActionValueType:
         return self._state_to_action.get(state_value)
 
     def get_action(self, state: DiscreteState) -> Action:
@@ -62,10 +74,16 @@ class DiscreteDeterminsticPolicy(DeterministicPolicy):
         return True
 
 
+class DiscreteContinuousPolicy(DeterministicPolicy):
+    @abstractmethod
+    def get_action(self, state: State) -> Action:
+        raise NotImplemented()
+
+
 class StochasticPolicy(Policy):
     @abstractmethod
     def get_probability(self, *args, **kwargs) -> float:
-        pass
+        raise NotImplemented()
 
     @property
     def is_stochastic(self) -> bool:
@@ -73,12 +91,61 @@ class StochasticPolicy(Policy):
 
 
 class DiscreteStochasticPolicy(StochasticPolicy):
-    @abstractmethod
-    def get_probability(self, state_value: DiscreteStateValueType, action_value: DiscreteActionValueType) -> float:
-        pass
+    def __init__(self, actions_dict: dict[DiscreteActionValueType, Action]):
+        self._state_to_action = {}
+        self._actions_dict = actions_dict
+
+    def add_stochastic_rule(
+        self,
+        state_value: DiscreteStateValueType,
+        action_values: list[DiscreteActionValueType],
+        probs: Union[list[float], Annotated[NDArray[np.float64], "1D Array"]] = None
+    ):
+        if probs is not None:
+            assert len(action_values) == len(probs)
+            probs = np.array(probs)
+        else:
+            probs = np.repeat(1./len(action_values), len(action_values))
+
+        if state_value in self._state_to_action:
+            warn('State value {} exists in rule; it will be replaced.'.format(state_value))
+        self._state_to_action[state_value] = {
+            action_value: prob
+            for action_value, prob in zip(action_values, probs)
+        }
+
+    def get_probability(
+        self,
+        state_value: DiscreteStateValueType,
+        action_value: DiscreteActionValueType
+    ) -> float:
+        if state_value not in self._state_to_action:
+            return 0.0
+        if action_value in self._state_to_action[state_value]:
+            return self._state_to_action[state_value][action_value]
+        else:
+            return 0.0
+
+    def get_action_value(self, state: State) -> DiscreteActionValueType:
+        allowed_actions = list(self._state_to_action[state].keys())
+        probs = np.array(list(self._state_to_action[state].values()))
+        sumprobs = np.sum(probs)
+        return np.random.choice(allowed_actions, p=probs/sumprobs)
+
+    def get_action(self, state: DiscreteState) -> Action:
+        return self._actions_dict[self.get_action_value(state.state_value)]
 
 
 class ContinuousStochasticPolicy(StochasticPolicy):
     @abstractmethod
-    def get_probability(self, state_value: Union[float, np.ndarray], action_value: DiscreteActionValueType, value: Union[float, np.ndarray]) -> float:
-        pass
+    def get_probability(
+        self,
+        state_value: Union[float, Annotated[NDArray[np.float64], "1D Array"]],
+        action_value: DiscreteActionValueType,
+        value: Union[float, Annotated[NDArray[np.float64], "1D Array"]]
+    ) -> float:
+        raise NotImplemented()
+
+
+DiscretePolicy = Union[DiscreteDeterminsticPolicy, DiscreteStochasticPolicy]
+ContinuousPolicy = Union[ContinuousStochasticPolicy]
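
Two things are worth noting. First, the abstract `pass` bodies become `raise NotImplemented()`; since `NotImplemented` is a constant rather than an exception class, calling it actually raises `TypeError` (the conventional spelling is `raise NotImplementedError()`). Second, `DiscreteStochasticPolicy` is now a concrete class with rule-based sampling. A minimal sketch of the new API (the state and action names are illustrative):

    actions = {
        'left': Action(lambda s: s),     # hypothetical no-op actions
        'right': Action(lambda s: s),
    }

    policy = DiscreteStochasticPolicy(actions)
    policy.add_stochastic_rule('s0', ['left', 'right'], probs=[0.25, 0.75])
    policy.add_stochastic_rule('s1', ['left', 'right'])    # uniform by default

    policy.get_probability('s0', 'right')    # 0.75
    policy.get_probability('s2', 'left')     # 0.0 for unknown states
    policy.get_action_value('s0')            # samples 'left' or 'right'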
pyrlutils/state.py CHANGED
@@ -1,81 +1,87 @@
 
-from abc import ABC, abstractmethod
+import sys
+from abc import ABC
 from enum import Enum
-from dataclasses import dataclass
-from typing import Tuple, List, Optional, Union
+from typing import Optional, Union, Annotated, Literal
 
 import numpy as np
+from numpy.typing import NDArray
+if sys.version_info < (3, 11):
+    from typing_extensions import Self
+else:
+    from typing import Self
 
-
-class StateValue(ABC):
-    @property
-    @abstractmethod
-    def value(self):
-        pass
-
-
-@dataclass
-class DiscreteStateValue(StateValue):
-    enum: Enum
-
-    @property
-    def value(self):
-        return self.enum.value
-
-    def name(self):
-        return self.enum.name
-
-
-class ContinuousStateValue(StateValue):
-    _value: float
-
-    @property
-    def value(self) -> float:
-        return self._value
+from .helpers.exceptions import InvalidRangeError
 
 
 class State(ABC):
     @property
     def state_value(self):
-        return self.get_state_value()
-
-    @abstractmethod
-    def set_state_value(self, state_value):
-        pass
+        raise NotImplemented()
 
-    @abstractmethod
-    def get_state_value(self):
-        pass
-
-    @state_value.setter
-    def state_value(self, new_state_value):
-        self.set_state_value(new_state_value)
 
-
-DiscreteStateValueType = Union[float, str, Tuple[int], Enum]
+DiscreteStateValueType = Union[str, int, tuple[int], Enum]
 
 
 class DiscreteState(State):
-    def __init__(self, all_state_values: List[DiscreteStateValueType], initial_values: Optional[List[DiscreteStateValueType]] = None):
+    def __init__(
+        self,
+        all_state_values: list[DiscreteStateValueType],
+        initial_value: Optional[DiscreteStateValueType] = None,
+        terminals: Optional[dict[DiscreteStateValueType, bool]] = None
+    ):
         super().__init__()
         self._all_state_values = all_state_values
-        self._state_value = initial_values if initial_values is not None and initial_values in self._all_state_values else self._all_state_values[0]
+        self._state_values_to_indices = {
+            state_value: idx
+            for idx, state_value in enumerate(self._all_state_values)
+        }
+        if initial_value is not None:
+            self._current_index = self._state_values_to_indices[initial_value]
+        else:
+            self._current_index = 0
+        if terminals is None:
+            self._terminal_dict = {
+                state_value: False
+                for state_value in self._all_state_values
+            }
+        else:
+            self._terminal_dict = terminals.copy()
+            for state_value in self._all_state_values:
+                if self._terminal_dict.get(state_value) is None:
+                    self._terminal_dict[state_value] = False
+
+    def _get_state_value_from_index(self, index: int) -> DiscreteStateValueType:
+        return self._all_state_values[index]
 
     def get_state_value(self) -> DiscreteStateValueType:
-        return self._state_value
+        return self._get_state_value_from_index(self._current_index)
 
-    def set_state_value(self, state_value: DiscreteStateValueType):
+    def set_state_value(self, state_value: DiscreteStateValueType) -> None:
         if state_value in self._all_state_values:
-            self._state_value = state_value
+            self._current_index = self._state_values_to_indices[state_value]
         else:
             raise ValueError('State value {} is invalid.'.format(state_value))
 
-    def get_all_possible_state_values(self) -> List[DiscreteStateValueType]:
+    def get_all_possible_state_values(self) -> list[DiscreteStateValueType]:
         return self._all_state_values
 
+    def query_state_index_from_value(self, value: DiscreteStateValueType) -> int:
+        return self._state_values_to_indices[value]
+
+    @property
+    def state_index(self) -> int:
+        return self._current_index
+
+    @state_index.setter
+    def state_index(self, new_index: int) -> None:
+        if new_index >= len(self._all_state_values):
+            raise ValueError(f"Invalid index {new_index}; it must be less than {self.nb_state_values}.")
+        self._current_index = new_index
+
     @property
     def state_value(self) -> DiscreteStateValueType:
-        return self._state_value
+        return self._all_state_values[self._current_index]
 
     @state_value.setter
     def state_value(self, new_state_value: DiscreteStateValueType):
@@ -85,22 +91,53 @@ class DiscreteState(State):
     def state_space_size(self):
         return len(self._all_state_values)
 
+    @property
+    def nb_state_values(self) -> int:
+        return len(self._all_state_values)
 
-class InvalidRangeError(Exception):
-    def __init__(self, message=None):
-        self.message = "Invalid range error!" if message is None else message
-        super().__init__(self.message)
+    @property
+    def is_terminal(self) -> bool:
+        return self._terminal_dict[self._all_state_values[self._current_index]]
+
+    def __hash__(self):
+        return self._current_index
+
+    def __eq__(self, other: Self) -> bool:
+        return self._current_index == other._current_index
 
 
 class ContinuousState(State):
-    def __init__(self, nbdims: int, ranges: np.array, init_value: Optional[Union[float, np.ndarray]] = None):
+    def __init__(
+        self,
+        nbdims: int,
+        ranges: Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]],
+        init_value: Optional[Union[float, Annotated[NDArray[np.float64], "1D Array"]]] = None
+    ):
+        super().__init__()
         self._nbdims = nbdims
 
+        try:
+            assert isinstance(ranges, np.ndarray)
+        except AssertionError:
+            raise TypeError('Range must be a numpy array.')
+
         try:
             assert (ranges.dtype == np.float64) or (ranges.dtype == np.float32) or (ranges.dtype == np.float16)
         except AssertionError:
             raise TypeError('It has to be floating type numpy.ndarray.')
 
+        try:
+            assert ranges.ndim == 1 or ranges.ndim == 2
+            match ranges.ndim:
+                case 1:
+                    assert ranges.shape[0] == 2
+                case 2:
+                    assert ranges.shape[1] == 2
+                case _:
+                    raise ValueError("Ranges must be of shape (2, ) or (*, 2).")
+        except AssertionError:
+            raise ValueError("Ranges must be of shape (2, ) or (*, 2).")
+
         try:
             assert self._nbdims > 0
         except AssertionError:
@@ -146,50 +183,53 @@ class ContinuousState(State):
                 raise ValueError('Initialized value does not have the right dimension.')
             for i in range(self._nbdims):
                 try:
-                    assert (init_value[i] >= self._ranges[i, 0]) and (init_value[i] <= self.ranges[i, 1])
+                    assert self._ranges[i, 0] <= init_value[i] <= self.ranges[i, 1]
                 except AssertionError:
                    raise InvalidRangeError('Initialized value at dimension {} (value: {}) is not within the permitted range ({} -> {})!'.format(i, init_value[i], self._ranges[i, 0], self._ranges[i, 1]))
        else:
            try:
-                assert (init_value >= self._ranges[0, 0]) and (init_value <= self.ranges[0, 1])
+                assert self._ranges[0, 0] <= init_value <= self.ranges[0, 1]
            except AssertionError:
                raise InvalidRangeError('Initialized value is out of range.')
        self._state_value = init_value
 
-    def set_state_value(self, state_value: Union[float, np.ndarray]):
-        if self.nbdims > 1:
+    def set_state_value(self, state_value: Union[float, Annotated[NDArray[np.float64], "1D Array"]]):
+        if self._nbdims > 1:
            try:
                assert state_value.shape[0] == self._nbdims
            except AssertionError:
                raise ValueError('Given value does not have the right dimension.')
-            for i in range(self.nbdims):
+            for i in range(self._nbdims):
                try:
-                    assert state_value[i] >= self.ranges[i, 0] and state_value[i] <= self.ranges[i, 1]
+                    assert self.ranges[i, 0] <= state_value[i] <= self.ranges[i, 1]
                except AssertionError:
                    raise InvalidRangeError()
        else:
            try:
-                assert state_value >= self.ranges[0, 0] and state_value <= self.ranges[0, 1]
+                assert self.ranges[0, 0] <= state_value <= self.ranges[0, 1]
            except AssertionError:
                raise InvalidRangeError()
 
        self._state_value = state_value
 
-    def get_state_value(self) -> np.ndarray:
+    def get_state_value(self) -> Annotated[NDArray[np.float64], "1D Array"]:
        return self._state_value
 
-    def get_state_value_ranges(self) -> np.ndarray:
+    def get_state_value_ranges(self) -> Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]]:
        return self._ranges
 
-    def get_state_value_range_at_dimension(self, dimension: int) -> np.ndarray:
-        return self._ranges[dimension]
+    def get_state_value_range_at_dimension(self, dimension: int) -> Annotated[NDArray[np.float64], Literal["2"]]:
+        if dimension < self._nbdims:
+            return self._ranges[dimension]
+        else:
+            raise ValueError(f"There are only {self._nbdims} dimensions!")
 
     @property
-    def ranges(self) -> np.ndarray:
+    def ranges(self) -> Union[Annotated[NDArray[np.float64], Literal["2"]], Annotated[NDArray[np.float64], Literal["*", "2"]]]:
        return self.get_state_value_ranges()
 
     @property
-    def state_value(self) -> Union[float, np.ndarray]:
+    def state_value(self) -> Union[float, NDArray[np.float64]]:
        return self.get_state_value()
 
     @state_value.setter
@@ -200,9 +240,28 @@ class ContinuousState(State):
     def nbdims(self) -> int:
         return self._nbdims
 
+    def __hash__(self):
+        return hash(tuple(self._state_value))
+
+    def __eq__(self, other: Self):
+        if self.nbdims != other.nbdims:
+            raise ValueError(f"The two states have two different dimensions. ({self.nbdims} vs. {other.nbdims})")
+        for i in range(self.nbdims):
+            if self.state_value[i] != other.state_value[i]:
+                return False
+        return True
+
 
 class Discrete2DCartesianState(DiscreteState):
-    def __init__(self, x_lowlim: int, x_hilim: int, y_lowlim: int, y_hilim: int, initial_coordinate: List[int]=None):
+    def __init__(
+        self,
+        x_lowlim: int,
+        x_hilim: int,
+        y_lowlim: int,
+        y_hilim: int,
+        initial_coordinate: list[int] = None,
+        terminals: Optional[dict[DiscreteStateValueType, bool]] = None
+    ):
         self._x_lowlim = x_lowlim
         self._x_hilim = x_hilim
         self._y_lowlim = y_lowlim
@@ -212,14 +271,50 @@ class Discrete2DCartesianState(DiscreteState):
         if initial_coordinate is None:
             initial_coordinate = [self._x_lowlim, self._y_lowlim]
         initial_value = (initial_coordinate[1] - self._y_lowlim) * self._countx + (initial_coordinate[0] - self._x_lowlim)
-        super().__init__(list(range(self._countx*self._county)), initial_values=initial_value)
+        super().__init__(list(range(self._countx*self._county)), initial_value=initial_value, terminals=terminals)
 
     def _encode_coordinates(self, x, y) -> int:
         return (y - self._y_lowlim) * self._countx + (x - self._x_lowlim)
 
-    def encode_coordinates(self, coordinates: List[int]) -> int:
-        assert len(coordinates) == 2
+    def encode_coordinates(self, coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]]) -> int:
+        if isinstance(coordinates, list):
+            assert len(coordinates) == 2
         return self._encode_coordinates(coordinates[0], coordinates[1])
 
-    def decode_coordinates(self, hashcode) -> List[int]:
-        return [hashcode % self._countx, hashcode // self._countx]
+    def decode_coordinates(self, hashcode) -> list[int]:
+        return [hashcode % self._countx + self._x_lowlim, hashcode // self._countx + self._y_lowlim]
+
+    def get_whether_terminal_given_coordinates(
+        self,
+        coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]]
+    ) -> bool:
+        if isinstance(coordinates, list):
+            assert len(coordinates) == 2
+        hashcode = self._encode_coordinates(coordinates[0], coordinates[1])
+        return self._terminal_dict.get(hashcode, False)
+
+    def set_terminal_given_coordinates(
+        self,
+        coordinates: Union[list[int], Annotated[NDArray[np.int64], Literal["2"]]],
+        terminal_value: bool
+    ) -> None:
+        if isinstance(coordinates, list):
+            assert len(coordinates) == 2
+        hashcode = self._encode_coordinates(coordinates[0], coordinates[1])
+        self._terminal_dict[hashcode] = terminal_value
+
+    @property
+    def x_lowlim(self) -> int:
+        return self._x_lowlim
+
+    @property
+    def x_hilim(self) -> int:
+        return self._x_hilim
+
+    @property
+    def y_lowlim(self) -> int:
+        return self._y_lowlim
+
+    @property
+    def y_hilim(self) -> int:
+        return self._y_hilim
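
The `DiscreteState` rewrite tracks an index into the state-value list rather than the value itself, adds terminal-state bookkeeping, and makes states hashable and comparable. Note also that `decode_coordinates` now adds the lower limits back, so it round-trips with `encode_coordinates` on grids that do not start at the origin. A sketch of the new behavior on a hypothetical 4x4 grid (the grid bounds and terminal cell are illustrative):

    grid = Discrete2DCartesianState(
        x_lowlim=1, x_hilim=4,
        y_lowlim=1, y_hilim=4,
        initial_coordinate=[1, 1]
    )
    grid.set_terminal_given_coordinates([4, 4], True)

    code = grid.encode_coordinates([4, 4])
    print(grid.decode_coordinates(code))    # [4, 4]: offsets restored in 0.1.1

    grid.state_index = code    # jump to the terminal cell by index
    print(grid.is_terminal)    # True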