gym-csle-stopping-game 0.6.1.tar.gz → 0.6.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of gym-csle-stopping-game might be problematic; see the registry's advisory for this release for more details.

Files changed (31):
  1. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/PKG-INFO +1 -1
  2. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/setup.cfg +5 -5
  3. gym_csle_stopping_game-0.6.2/src/gym_csle_stopping_game/__version__.py +1 -0
  4. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/envs/stopping_game_env.py +4 -4
  5. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game.egg-info/PKG-INFO +1 -1
  6. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game.egg-info/SOURCES.txt +3 -0
  7. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game.egg-info/requires.txt +5 -5
  8. gym_csle_stopping_game-0.6.2/tests/test_stopping_game_env.py +428 -0
  9. gym_csle_stopping_game-0.6.2/tests/test_stopping_game_mdp_attacker_env.py +343 -0
  10. gym_csle_stopping_game-0.6.2/tests/test_stopping_game_pomdp_defender_env.py +336 -0
  11. gym_csle_stopping_game-0.6.1/src/gym_csle_stopping_game/__version__.py +0 -1
  12. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/pyproject.toml +0 -0
  13. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/setup.py +0 -0
  14. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/__init__.py +0 -0
  15. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/constants/__init__.py +0 -0
  16. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/constants/constants.py +0 -0
  17. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/dao/__init__.py +0 -0
  18. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/dao/stopping_game_attacker_mdp_config.py +0 -0
  19. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/dao/stopping_game_config.py +0 -0
  20. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/dao/stopping_game_defender_pomdp_config.py +0 -0
  21. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/dao/stopping_game_state.py +0 -0
  22. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/envs/__init__.py +0 -0
  23. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/envs/stopping_game_mdp_attacker_env.py +0 -0
  24. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/envs/stopping_game_pomdp_defender_env.py +0 -0
  25. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/util/__init__.py +0 -0
  26. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/util/stopping_game_util.py +0 -0
  27. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game.egg-info/dependency_links.txt +0 -0
  28. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game.egg-info/not-zip-safe +0 -0
  29. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game.egg-info/top_level.txt +0 -0
  30. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/tests/test_stopping_game_dao.py +0 -0
  31. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/tests/test_stopping_game_util.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: gym_csle_stopping_game
3
- Version: 0.6.1
3
+ Version: 0.6.2
4
4
  Summary: OpenAI gym reinforcement learning environment of a Dynkin (Optimal stopping) game in CSLE
5
5
  Author: Kim Hammar
6
6
  Author-email: hammar.kim@gmail.com
@@ -20,11 +20,11 @@ classifiers =
20
20
  [options]
21
21
  install_requires =
22
22
  gymnasium>=0.27.1
23
- csle-base>=0.6.1
24
- csle-common>=0.6.1
25
- csle-attacker>=0.6.1
26
- csle-defender>=0.6.1
27
- csle-collector>=0.6.1
23
+ csle-base>=0.6.2
24
+ csle-common>=0.6.2
25
+ csle-attacker>=0.6.2
26
+ csle-defender>=0.6.2
27
+ csle-collector>=0.6.2
28
28
  python_requires = >=3.8
29
29
  package_dir =
30
30
  =src
@@ -0,0 +1 @@
1
+ __version__ = '0.6.2'
@@ -42,7 +42,6 @@ class StoppingGameEnv(BaseEnv):
42
42
 
43
43
  # Initialize environment state
44
44
  self.state = StoppingGameState(b1=self.config.b1, L=self.config.L)
45
-
46
45
  # Setup spaces
47
46
  self.attacker_observation_space = self.config.attacker_observation_space()
48
47
  self.defender_observation_space = self.config.defender_observation_space()
@@ -73,7 +72,7 @@ class StoppingGameEnv(BaseEnv):
73
72
  a1, a2_profile = action_profile
74
73
  pi2, a2 = a2_profile
75
74
  assert pi2.shape[0] == len(self.config.S)
76
- assert pi2.shape[1] == len(self.config.A1)
75
+ assert pi2.shape[1] == len(self.config.A2)
77
76
  done = False
78
77
  info: Dict[str, Any] = {}
79
78
 
@@ -84,8 +83,7 @@ class StoppingGameEnv(BaseEnv):
84
83
  else:
85
84
  # Compute r, s', b',o'
86
85
  r = self.config.R[self.state.l - 1][a1][a2][self.state.s]
87
- self.state.s = StoppingGameUtil.sample_next_state(l=self.state.l, a1=a1, a2=a2,
88
- T=self.config.T,
86
+ self.state.s = StoppingGameUtil.sample_next_state(l=self.state.l, a1=a1, a2=a2, T=self.config.T,
89
87
  S=self.config.S, s=self.state.s)
90
88
  o = StoppingGameUtil.sample_next_observation(Z=self.config.Z,
91
89
  O=self.config.O, s_prime=self.state.s)
@@ -437,6 +435,8 @@ class StoppingGameEnv(BaseEnv):
437
435
  :param l: the number of stops remaining
438
436
  :return: the observation
439
437
  """
438
+ if not history:
439
+ raise ValueError("History must not be empty")
440
440
  return [history[-1]]
441
441
 
442
442
  def generate_random_particles(self, o: int, num_particles: int) -> List[int]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: gym-csle-stopping-game
3
- Version: 0.6.1
3
+ Version: 0.6.2
4
4
  Summary: OpenAI gym reinforcement learning environment of a Dynkin (Optimal stopping) game in CSLE
5
5
  Author: Kim Hammar
6
6
  Author-email: hammar.kim@gmail.com
@@ -23,4 +23,7 @@ src/gym_csle_stopping_game/envs/stopping_game_pomdp_defender_env.py
23
23
  src/gym_csle_stopping_game/util/__init__.py
24
24
  src/gym_csle_stopping_game/util/stopping_game_util.py
25
25
  tests/test_stopping_game_dao.py
26
+ tests/test_stopping_game_env.py
27
+ tests/test_stopping_game_mdp_attacker_env.py
28
+ tests/test_stopping_game_pomdp_defender_env.py
26
29
  tests/test_stopping_game_util.py
@@ -1,9 +1,9 @@
1
1
  gymnasium>=0.27.1
2
- csle-base>=0.6.1
3
- csle-common>=0.6.1
4
- csle-attacker>=0.6.1
5
- csle-defender>=0.6.1
6
- csle-collector>=0.6.1
2
+ csle-base>=0.6.2
3
+ csle-common>=0.6.2
4
+ csle-attacker>=0.6.2
5
+ csle-defender>=0.6.2
6
+ csle-collector>=0.6.2
7
7
 
8
8
  [testing]
9
9
  pytest>=6.0
@@ -0,0 +1,428 @@
1
+ from typing import Dict, Any
2
+ import pytest
3
+ from unittest.mock import patch, MagicMock
4
+ from gym.spaces import Box, Discrete
5
+ import numpy as np
6
+ from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
7
+ from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
8
+ from gym_csle_stopping_game.dao.stopping_game_state import StoppingGameState
9
+ import gym_csle_stopping_game.constants.constants as env_constants
10
+ from csle_common.constants import constants
11
+
12
+
13
+ class TestStoppingGameEnvSuite:
14
+ """
15
+ Test suite for stopping_game_env.py
16
+ """
17
+
18
+ @pytest.fixture(autouse=True)
19
+ def setup_env(self) -> None:
20
+ """
21
+ Sets up the configuration of the stopping game
22
+
23
+ :return: None
24
+ """
25
+ env_name = "test_env"
26
+ T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
27
+ O = np.array([0, 1])
28
+ Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]])
29
+ R = np.zeros((2, 3, 3, 3))
30
+ S = np.array([0, 1, 2])
31
+ A1 = np.array([0, 1, 2])
32
+ A2 = np.array([0, 1, 2])
33
+ L = 2
34
+ R_INT = 1
35
+ R_COST = 2
36
+ R_SLA = 3
37
+ R_ST = 4
38
+ b1 = np.array([0.6, 0.4])
39
+ save_dir = "save_directory"
40
+ checkpoint_traces_freq = 100
41
+ gamma = 0.9
42
+ compute_beliefs = True
43
+ save_trace = True
44
+ self.config = StoppingGameConfig(
45
+ env_name,
46
+ T,
47
+ O,
48
+ Z,
49
+ R,
50
+ S,
51
+ A1,
52
+ A2,
53
+ L,
54
+ R_INT,
55
+ R_COST,
56
+ R_SLA,
57
+ R_ST,
58
+ b1,
59
+ save_dir,
60
+ checkpoint_traces_freq,
61
+ gamma,
62
+ compute_beliefs,
63
+ save_trace,
64
+ )
65
+
66
+ def test_stopping_game_init_(self) -> None:
67
+ """
68
+ Tests the initializing function
69
+
70
+ :return: None
71
+ """
72
+ T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
73
+ O = np.array([0, 1])
74
+ A1 = np.array([0, 1, 2])
75
+ A2 = np.array([0, 1, 2])
76
+ L = 2
77
+ b1 = np.array([0.6, 0.4])
78
+ attacker_observation_space = Box(
79
+ low=np.array([0.0, 0.0, 0.0]),
80
+ high=np.array([float(L), 1.0, 2.0]),
81
+ dtype=np.float64,
82
+ )
83
+ defender_observation_space = Box(
84
+ low=np.array([0.0, 0.0]),
85
+ high=np.array([float(L), 1.0]),
86
+ dtype=np.float64,
87
+ )
88
+ attacker_action_space = Discrete(len(A2))
89
+ defender_action_space = Discrete(len(A1))
90
+
91
+ assert self.config.T.any() == T.any()
92
+ assert self.config.O.any() == O.any()
93
+ assert self.config.b1.any() == b1.any()
94
+ assert self.config.L == L
95
+
96
+ env = StoppingGameEnv(self.config)
97
+ assert env.config == self.config
98
+ assert env.attacker_observation_space.low.any() == attacker_observation_space.low.any()
99
+ assert env.defender_observation_space.low.any() == defender_observation_space.low.any()
100
+ assert env.attacker_action_space.n == attacker_action_space.n
101
+ assert env.defender_action_space.n == defender_action_space.n
102
+ assert env.traces == []
103
+
104
+ with patch("gym_csle_stopping_game.dao.stopping_game_state.StoppingGameState") as MockStoppingGameState:
105
+ MockStoppingGameState(b1=self.config.b1, L=self.config.L)
106
+ with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_initial_state"
107
+ ) as MockSampleInitialState:
108
+ MockSampleInitialState.return_value = 0
109
+ StoppingGameEnv(self.config)
110
+ MockSampleInitialState.assert_called()
111
+ MockStoppingGameState.assert_called_once_with(b1=self.config.b1, L=self.config.L)
112
+
113
+ with patch("csle_common.dao.simulation_config.simulation_trace.SimulationTrace") as MockSimulationTrace:
114
+ MockSimulationTrace(self.config.env_name).return_value
115
+ StoppingGameEnv(self.config)
116
+ MockSimulationTrace.assert_called_once_with(self.config.env_name)
117
+
118
+ def test_mean(self) -> None:
119
+ """
120
+ Tests the utility function for getting the mean of a vector
121
+
122
+ :return: None
123
+ """
124
+ test_cases = [
125
+ ([], 0), # Test case for an empty vector
126
+ ([5], 0), # Test case for a vector with a single element
127
+ ([0.2, 0.3, 0.5], 1.3), # Test case for a vector with multiple elements
128
+ ]
129
+ for prob_vector, expected_mean in test_cases:
130
+ result = StoppingGameEnv(self.config).mean(prob_vector)
131
+ assert result == expected_mean
132
+
133
+ def test_weighted_intrusion_prediction_distance(self) -> None:
134
+ """
135
+ Tests the function of computing the weighed intrusion start time prediction distance
136
+ """
137
+ # Test case when first_stop is before intrusion_start
138
+ result1 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance(5, 3)
139
+ assert result1 == 0
140
+
141
+ # Test case when first_stop is after intrusion_start
142
+ result2 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance(3, 5)
143
+ assert result2 == 0.95
144
+
145
+ # Test case when first_stop is equal to intrusion_start
146
+ result3 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance(3, 3)
147
+ assert result3 == 0
148
+
149
+ def test_reset(self) -> None:
150
+ """
151
+ Tests the reset function for reseting the environment state
152
+
153
+ :return: None
154
+ """
155
+ env = StoppingGameEnv(self.config)
156
+ env.state = MagicMock()
157
+ env.state.l = 10
158
+ env.state.s = "initial_state"
159
+ env.state.t = 0
160
+ env.state.attacker_observation.return_value = np.array([1, 2, 3])
161
+ env.state.defender_observation.return_value = np.array([4, 5, 6])
162
+
163
+ env.trace = MagicMock()
164
+ env.trace.attacker_rewards = [1]
165
+ env.traces = []
166
+ # Call the reset method
167
+ observation, info = env.reset()
168
+ # Assertions
169
+ assert env.state.reset.called, "State's reset method was not called."
170
+ assert env.trace.simulation_env == self.config.env_name, "Trace was not initialized correctly."
171
+ assert observation[0].all() == np.array([4, 5, 6]).all(), "Observation does not match expected values."
172
+ assert info[env_constants.ENV_METRICS.STOPS_REMAINING] == env.state.l, \
173
+ "Stops remaining does not match expected value."
174
+ assert info[env_constants.ENV_METRICS.STATE] == env.state.s, "State info does not match expected value."
175
+ assert info[env_constants.ENV_METRICS.OBSERVATION] == 0, "Observation info does not match expected value."
176
+ assert info[env_constants.ENV_METRICS.TIME_STEP] == env.state.t, "Time step info does not match expected value."
177
+
178
+ # Check if trace was appended correctly
179
+ if len(env.trace.attacker_rewards) > 0:
180
+ assert env.traces[-1] == env.trace, "Trace was not appended correctly."
181
+
182
+ def test_render(self) -> None:
183
+ """
184
+ Tests the function of rendering the environment
185
+
186
+ :return: None
187
+ """
188
+ with pytest.raises(NotImplementedError):
189
+ StoppingGameEnv(self.config).render()
190
+
191
+ def test_is_defense_action_legal(self) -> None:
192
+ """
193
+ Tests the function of checking whether a defender action in the environment is legal or not
194
+
195
+ :return: None
196
+ """
197
+ assert StoppingGameEnv(self.config).is_defense_action_legal(1)
198
+
199
+ def test_is_attack_action_legal(self) -> None:
200
+ """
201
+ Tests the function of checking whether an attacker action in the environment is legal or not
202
+
203
+ :return: None
204
+ """
205
+ assert StoppingGameEnv(self.config).is_attack_action_legal(1)
206
+
207
+ def test_get_traces(self) -> None:
208
+ """
209
+ Tests the function of getting the list of simulation traces
210
+
211
+ :return: None
212
+ """
213
+ assert StoppingGameEnv(self.config).get_traces() == StoppingGameEnv(self.config).traces
214
+
215
+ def test_reset_traces(self) -> None:
216
+ """
217
+ Tests the function of resetting the list of traces
218
+
219
+ :return: None
220
+ """
221
+ env = StoppingGameEnv(self.config)
222
+ env.traces = ["trace1", "trace2"]
223
+ env.reset_traces()
224
+ assert env.traces == []
225
+
226
+ def test_checkpoint_traces(self) -> None:
227
+ """
228
+ Tests the function of checkpointing agent traces
229
+
230
+ :return: None
231
+ """
232
+ env = StoppingGameEnv(self.config)
233
+ fixed_timestamp = 123
234
+ with patch("time.time", return_value=fixed_timestamp):
235
+ with patch(
236
+ "csle_common.dao.simulation_config.simulation_trace.SimulationTrace.save_traces"
237
+ ) as mock_save_traces:
238
+ env.traces = ["trace1", "trace2"]
239
+ env._StoppingGameEnv__checkpoint_traces()
240
+ mock_save_traces.assert_called_once_with(
241
+ traces_save_dir=constants.LOGGING.DEFAULT_LOG_DIR,
242
+ traces=env.traces,
243
+ traces_file=f"taus{fixed_timestamp}.json",
244
+ )
245
+
246
+ def test_set_model(self) -> None:
247
+ """
248
+ Tests the function of setting the model
249
+
250
+ :return: None
251
+ """
252
+ env = StoppingGameEnv(self.config)
253
+ mock_model = MagicMock()
254
+ env.set_model(mock_model)
255
+ assert env.model == mock_model
256
+
257
+ def test_set_state(self) -> None:
258
+ """
259
+ Tests the function of setting the state
260
+
261
+ :return: None
262
+ """
263
+ env = StoppingGameEnv(self.config)
264
+ env.state = MagicMock()
265
+
266
+ mock_state = MagicMock(spec=StoppingGameState)
267
+ env.set_state(mock_state)
268
+ assert env.state == mock_state
269
+
270
+ state_int = 5
271
+ env.set_state(state_int)
272
+ assert env.state.s == state_int
273
+ assert env.state.l == self.config.L
274
+
275
+ state_tuple = (3, 7)
276
+ env.set_state(state_tuple)
277
+ assert env.state.s == state_tuple[0]
278
+ assert env.state.l == state_tuple[1]
279
+
280
+ with pytest.raises(ValueError):
281
+ env.set_state([1, 2, 3]) # type: ignore
282
+
283
+ def test_is_state_terminal(self) -> None:
284
+ """
285
+ Tests the function of checking whether a given state is terminal or not
286
+
287
+ :return: None
288
+ """
289
+ env = StoppingGameEnv(self.config)
290
+ env.state = MagicMock()
291
+
292
+ mock_state = MagicMock(spec=StoppingGameState)
293
+ mock_state.s = 2
294
+ assert env.is_state_terminal(mock_state)
295
+ mock_state.s = 1
296
+ assert not env.is_state_terminal(mock_state)
297
+ state_int = 2
298
+ assert env.is_state_terminal(state_int)
299
+ state_int = 1
300
+ assert not env.is_state_terminal(state_int)
301
+ state_tuple = (2, 5)
302
+ assert env.is_state_terminal(state_tuple)
303
+ state_tuple = (1, 5)
304
+ assert not env.is_state_terminal(state_tuple)
305
+
306
+ with pytest.raises(ValueError):
307
+ env.is_state_terminal([1, 2, 3]) # type: ignore
308
+
309
+ def test_get_observation_from_history(self) -> None:
310
+ """
311
+ Tests the function of getting a hidden observation based on a history
312
+
313
+ :return: None
314
+ """
315
+ env = StoppingGameEnv(self.config)
316
+ history = [1, 2, 3, 4, 5]
317
+ pi2 = np.array([0.1, 0.9])
318
+ l = 3
319
+ observation = env.get_observation_from_history(history, pi2, l)
320
+ assert observation == [5]
321
+
322
+ history = []
323
+ with pytest.raises(ValueError, match="History must not be empty"):
324
+ env.get_observation_from_history(history, pi2, l)
325
+
326
+ def test_generate_random_particles(self) -> None:
327
+ """
328
+ Tests the funtion of generating a random list of state particles from a given observation
329
+
330
+ :return: None
331
+ """
332
+ env = StoppingGameEnv(self.config)
333
+ num_particles = 10
334
+ particles = env.generate_random_particles(o=1, num_particles=num_particles)
335
+ assert len(particles) == num_particles
336
+ assert all(p in [0, 1] for p in particles)
337
+
338
+ num_particles = 0
339
+ particles = env.generate_random_particles(o=1, num_particles=num_particles)
340
+ assert len(particles) == num_particles
341
+
342
+ def test_step(self) -> None:
343
+ """
344
+ Tests the funtion of taking a step in the environment by executing the given action
345
+
346
+ :return: None
347
+ """
348
+ env = StoppingGameEnv(self.config)
349
+ env.state = MagicMock()
350
+ env.state.s = 1
351
+ env.state.l = 2
352
+ env.state.t = 0
353
+ env.state.attacker_observation.return_value = np.array([1, 2, 3])
354
+ env.state.defender_observation.return_value = np.array([4, 5, 6])
355
+ env.state.b = np.array([0.5, 0.5, 0.0])
356
+
357
+ env.trace = MagicMock()
358
+ env.trace.defender_rewards = []
359
+ env.trace.attacker_rewards = []
360
+ env.trace.attacker_actions = []
361
+ env.trace.defender_actions = []
362
+ env.trace.infos = []
363
+ env.trace.states = []
364
+ env.trace.beliefs = []
365
+ env.trace.infrastructure_metrics = []
366
+ env.trace.attacker_observations = []
367
+ env.trace.defender_observations = []
368
+
369
+ with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_state",
370
+ return_value=2):
371
+ with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_observation",
372
+ return_value=1):
373
+ with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.next_belief",
374
+ return_value=np.array([0.3, 0.7, 0.0])):
375
+ action_profile = (
376
+ 1,
377
+ (
378
+ np.array(
379
+ [[0.2, 0.8, 0.0], [0.6, 0.4, 0.0], [0.5, 0.5, 0.0]]
380
+ ),
381
+ 2,
382
+ ),
383
+ )
384
+ observations, rewards, terminated, truncated, info = env.step(
385
+ action_profile
386
+ )
387
+
388
+ assert (observations[0] == np.array([4, 5, 6])).all(), "Incorrect defender observations"
389
+ assert (observations[1] == np.array([1, 2, 3])).all(), "Incorrect attacker observations"
390
+ assert rewards == (0, 0)
391
+ assert not terminated
392
+ assert not truncated
393
+ assert env.trace.defender_rewards[-1] == 0
394
+ assert env.trace.attacker_rewards[-1] == 0
395
+ assert env.trace.attacker_actions[-1] == 2
396
+ assert env.trace.defender_actions[-1] == 1
397
+ assert env.trace.infos[-1] == info
398
+ assert env.trace.states[-1] == 2
399
+ print(env.trace.beliefs)
400
+ assert env.trace.beliefs[-1] == 0.7
401
+ assert env.trace.infrastructure_metrics[-1] == 1
402
+ assert (env.trace.attacker_observations[-1] == np.array([1, 2, 3])).all()
403
+ assert (env.trace.defender_observations[-1] == np.array([4, 5, 6])).all()
404
+
405
+ def test_info(self) -> None:
406
+ """
407
+ Tests the function of adding the cumulative reward and episode length to the info dict
408
+
409
+ :return: None
410
+ """
411
+ env = StoppingGameEnv(self.config)
412
+ env.trace = MagicMock()
413
+ env.trace.defender_rewards = [1, 2]
414
+ env.trace.attacker_actions = [0, 1]
415
+ env.trace.defender_actions = [0, 1]
416
+ env.trace.states = [0, 1]
417
+ env.trace.infrastructure_metrics = [0, 1]
418
+ info: Dict[str, Any] = {}
419
+ updated_info = env._info(info)
420
+ assert updated_info[env_constants.ENV_METRICS.RETURN] == sum(env.trace.defender_rewards)
421
+
422
+ def test_emulation_evaluation(self) -> None:
423
+ """
424
+ Tests the function for evaluating a strategy profile in the emulation environment
425
+
426
+ :return: None
427
+ """
428
+ StoppingGameEnv(self.config)
@@ -0,0 +1,343 @@
1
+ from gym_csle_stopping_game.envs.stopping_game_mdp_attacker_env import (
2
+ StoppingGameMdpAttackerEnv,
3
+ )
4
+ from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
5
+ from gym_csle_stopping_game.dao.stopping_game_attacker_mdp_config import (
6
+ StoppingGameAttackerMdpConfig,
7
+ )
8
+ from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
9
+ from csle_common.dao.training.policy import Policy
10
+ import pytest
11
+ from unittest.mock import MagicMock
12
+ import numpy as np
13
+
14
+
15
+ class TestStoppingGameMdpAttackerEnvSuite:
16
+ """
17
+ Test suite for stopping_game_mdp_attacker_env.py
18
+ """
19
+
20
+ @pytest.fixture(autouse=True)
21
+ def setup_env(self) -> None:
22
+ """
23
+ Sets up the configuration of the stopping game
24
+
25
+ :return: None
26
+ """
27
+ env_name = "test_env"
28
+ T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
29
+ O = np.array([0, 1])
30
+ Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]])
31
+ R = np.zeros((2, 3, 3, 3))
32
+ S = np.array([0, 1, 2])
33
+ A1 = np.array([0, 1, 2])
34
+ A2 = np.array([0, 1, 2])
35
+ L = 2
36
+ R_INT = 1
37
+ R_COST = 2
38
+ R_SLA = 3
39
+ R_ST = 4
40
+ b1 = np.array([0.6, 0.4])
41
+ save_dir = "save_directory"
42
+ checkpoint_traces_freq = 100
43
+ gamma = 0.9
44
+ compute_beliefs = True
45
+ save_trace = True
46
+ self.config = StoppingGameConfig(
47
+ env_name,
48
+ T,
49
+ O,
50
+ Z,
51
+ R,
52
+ S,
53
+ A1,
54
+ A2,
55
+ L,
56
+ R_INT,
57
+ R_COST,
58
+ R_SLA,
59
+ R_ST,
60
+ b1,
61
+ save_dir,
62
+ checkpoint_traces_freq,
63
+ gamma,
64
+ compute_beliefs,
65
+ save_trace,
66
+ )
67
+
68
+ def test_init_(self) -> None:
69
+ """
70
+ Tests the initializing function
71
+
72
+ :return: None
73
+ """
74
+ # Mock the defender strategy
75
+ defender_strategy = MagicMock(spec=Policy)
76
+ # Create the attacker MDP configuration
77
+ attacker_mdp_config = StoppingGameAttackerMdpConfig(
78
+ env_name="test_env",
79
+ stopping_game_config=self.config,
80
+ defender_strategy=defender_strategy,
81
+ stopping_game_name="csle-stopping-game-v1",
82
+ )
83
+ # Initialize the StoppingGameMdpAttackerEnv
84
+ env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
85
+ assert env.config == attacker_mdp_config
86
+ assert env.observation_space == self.config.attacker_observation_space()
87
+ assert env.action_space == self.config.attacker_action_space()
88
+ assert env.static_defender_strategy == defender_strategy
89
+ # print(env.latest_defender_obs)
90
+ # assert not env.latest_defender_obs
91
+ # assert not env.latest_attacker_obs
92
+ assert not env.model
93
+ assert not env.viewer
94
+
95
+ def test_reset(self) -> None:
96
+ """
97
+ Tests the function for reseting the environment state
98
+
99
+ :return: None
100
+ """
101
+ defender_strategy = MagicMock(spec=Policy)
102
+ attacker_mdp_config = StoppingGameAttackerMdpConfig(
103
+ env_name="test_env",
104
+ stopping_game_config=self.config,
105
+ defender_strategy=defender_strategy,
106
+ stopping_game_name="csle-stopping-game-v1",
107
+ )
108
+
109
+ env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
110
+ attacker_obs, info = env.reset()
111
+ assert env.latest_defender_obs.all() == np.array([2, 0.4]).all() # type: ignore
112
+ assert info == {}
113
+
114
+ def test_set_model(self) -> None:
115
+ """
116
+ Tests the function for setting the model
117
+
118
+ :return: None
119
+ """
120
+ defender_strategy = MagicMock(spec=Policy)
121
+ attacker_mdp_config = StoppingGameAttackerMdpConfig(
122
+ env_name="test_env",
123
+ stopping_game_config=self.config,
124
+ defender_strategy=defender_strategy,
125
+ stopping_game_name="csle-stopping-game-v1",
126
+ )
127
+ env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
128
+ mock_model = MagicMock()
129
+ env.set_model(mock_model)
130
+ assert env.model == mock_model
131
+
132
+ def test_set_state(self) -> None:
133
+ """
134
+ Tests the function for setting the state
135
+
136
+ :return: None
137
+ """
138
+ defender_strategy = MagicMock(spec=Policy)
139
+ attacker_mdp_config = StoppingGameAttackerMdpConfig(
140
+ env_name="test_env",
141
+ stopping_game_config=self.config,
142
+ defender_strategy=defender_strategy,
143
+ stopping_game_name="csle-stopping-game-v1",
144
+ )
145
+
146
+ env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
147
+ assert not env.set_state(1) # type: ignore
148
+
149
+ def test_calculate_stage_policy(self) -> None:
150
+ """
151
+ Tests the function for calculating the stage policy of a given model and observation
152
+
153
+ :return: None
154
+ """
155
+ defender_strategy = MagicMock(spec=Policy)
156
+ attacker_mdp_config = StoppingGameAttackerMdpConfig(
157
+ env_name="test_env",
158
+ stopping_game_config=self.config,
159
+ defender_strategy=defender_strategy,
160
+ stopping_game_name="csle-stopping-game-v1",
161
+ )
162
+
163
+ env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
164
+ env.model = None
165
+ observation = [1, 0.5]
166
+ stage_policy = env.calculate_stage_policy(o=observation)
167
+ expected_stage_policy = np.array([[1.0, 0.0], [1.0, 0.0], [0.5, 0.5]])
168
+ assert stage_policy.all() == expected_stage_policy.all()
169
+
170
+ def test_get_attacker_dist(self) -> None:
171
+ """
172
+ Tests the function for getting the attacker's action distribution based on a given observation
173
+
174
+ :return: None
175
+ """
176
+ defender_strategy = MagicMock(spec=Policy)
177
+ attacker_mdp_config = StoppingGameAttackerMdpConfig(
178
+ env_name="test_env",
179
+ stopping_game_config=self.config,
180
+ defender_strategy=defender_strategy,
181
+ stopping_game_name="csle-stopping-game-v1",
182
+ )
183
+
184
+ env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
185
+ env.model = None
186
+ observation = [1, 0.5, 0]
187
+ with pytest.raises(ValueError, match="Model is None"):
188
+ env._get_attacker_dist(observation)
189
+
190
+ def test_render(self) -> None:
191
+ """
192
+ Tests the function for rendering the environment
193
+
194
+ :return: None
195
+ """
196
+ defender_strategy = MagicMock(spec=Policy)
197
+ attacker_mdp_config = StoppingGameAttackerMdpConfig(
198
+ env_name="test_env",
199
+ stopping_game_config=self.config,
200
+ defender_strategy=defender_strategy,
201
+ stopping_game_name="csle-stopping-game-v1",
202
+ )
203
+
204
+ env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
205
+ with pytest.raises(NotImplementedError):
206
+ env.render("human")
207
+
208
+ def test_is_defense_action_legal(self) -> None:
209
+ """
210
+ Tests the function of checking whether a defender action in the environment is legal or not
211
+
212
+ :return: None
213
+ """
214
+ defender_strategy = MagicMock(spec=Policy)
215
+ attacker_mdp_config = StoppingGameAttackerMdpConfig(
216
+ env_name="test_env",
217
+ stopping_game_config=self.config,
218
+ defender_strategy=defender_strategy,
219
+ stopping_game_name="csle-stopping-game-v1",
220
+ )
221
+
222
+ env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
223
+ assert env.is_defense_action_legal(1)
224
+
225
+ def test_is_attack_action_legal(self) -> None:
226
+ """
227
+ Tests the function of checking whether an attacker action in the environment is legal or not
228
+
229
+ :return: None
230
+ """
231
+ defender_strategy = MagicMock(spec=Policy)
232
+ attacker_mdp_config = StoppingGameAttackerMdpConfig(
233
+ env_name="test_env",
234
+ stopping_game_config=self.config,
235
+ defender_strategy=defender_strategy,
236
+ stopping_game_name="csle-stopping-game-v1",
237
+ )
238
+
239
+ env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
240
+ assert env.is_attack_action_legal(1)
241
+
242
+ def test_get_traces(self) -> None:
243
+ """
244
+ Tests the function of getting the list of simulation traces
245
+
246
+ :return: None
247
+ """
248
+ defender_strategy = MagicMock(spec=Policy)
249
+ attacker_mdp_config = StoppingGameAttackerMdpConfig(
250
+ env_name="test_env",
251
+ stopping_game_config=self.config,
252
+ defender_strategy=defender_strategy,
253
+ stopping_game_name="csle-stopping-game-v1",
254
+ )
255
+
256
+ env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
257
+ assert env.get_traces() == StoppingGameEnv(self.config).traces
258
+
259
+ def test_reset_traces(self) -> None:
260
+ """
261
+ Tests the function of resetting the list of traces
262
+
263
+ :return: None
264
+ """
265
+ defender_strategy = MagicMock(spec=Policy)
266
+ attacker_mdp_config = StoppingGameAttackerMdpConfig(
267
+ env_name="test_env",
268
+ stopping_game_config=self.config,
269
+ defender_strategy=defender_strategy,
270
+ stopping_game_name="csle-stopping-game-v1",
271
+ )
272
+
273
+ env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
274
+ env.traces = ["trace1", "trace2"]
275
+ env.reset_traces()
276
+ assert StoppingGameEnv(self.config).traces == []
277
+
278
+ def test_generate_random_particles(self) -> None:
279
+ """
280
+ Tests the funtion of generating a random list of state particles from a given observation
281
+
282
+ :return: None
283
+ """
284
+ defender_strategy = MagicMock(spec=Policy)
285
+ attacker_mdp_config = StoppingGameAttackerMdpConfig(
286
+ env_name="test_env",
287
+ stopping_game_config=self.config,
288
+ defender_strategy=defender_strategy,
289
+ stopping_game_name="csle-stopping-game-v1",
290
+ )
291
+
292
+ env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
293
+ num_particles = 10
294
+ particles = env.generate_random_particles(o=1, num_particles=num_particles)
295
+ assert len(particles) == num_particles
296
+ assert all(p in [0, 1] for p in particles)
297
+
298
+ num_particles = 0
299
+ particles = env.generate_random_particles(o=1, num_particles=num_particles)
300
+ assert len(particles) == num_particles
301
+
302
+ def test_get_actions_from_particles(self) -> None:
303
+ """
304
+ Tests the function for pruning the set of actions based on the current particle set
305
+
306
+ :return: None
307
+ """
308
+ defender_strategy = MagicMock(spec=Policy)
309
+ attacker_mdp_config = StoppingGameAttackerMdpConfig(
310
+ env_name="test_env",
311
+ stopping_game_config=self.config,
312
+ defender_strategy=defender_strategy,
313
+ stopping_game_name="csle-stopping-game-v1",
314
+ )
315
+
316
+ env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
317
+ particles = [1, 2, 3]
318
+ t = 0
319
+ observation = 0
320
+ expected_actions = [0, 1, 2]
321
+ assert (
322
+ env.get_actions_from_particles(particles, t, observation)
323
+ == expected_actions
324
+ )
325
+
326
+ def test_step(self) -> None:
327
+ """
328
+ Tests the function for taking a step in the environment by executing the given action
329
+
330
+ :return: None
331
+ """
332
+ defender_strategy = MagicMock(spec=Policy)
333
+ attacker_mdp_config = StoppingGameAttackerMdpConfig(
334
+ env_name="test_env",
335
+ stopping_game_config=self.config,
336
+ defender_strategy=defender_strategy,
337
+ stopping_game_name="csle-stopping-game-v1",
338
+ )
339
+
340
+ env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
341
+ pi2 = np.array([[0.5, 0.5]])
342
+ with pytest.raises(AssertionError):
343
+ env.step(pi2)
@@ -0,0 +1,336 @@
1
+ from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import StoppingGamePomdpDefenderEnv
2
+ from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
3
+ from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import StoppingGameDefenderPomdpConfig
4
+ from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
5
+ from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
6
+ from csle_common.dao.training.policy import Policy
7
+ from csle_common.dao.training.random_policy import RandomPolicy
8
+ from csle_common.dao.training.player_type import PlayerType
9
+ import pytest
10
+ from unittest.mock import MagicMock
11
+ import numpy as np
12
+
13
+
14
class TestStoppingGamePomdpDefenderEnvSuite:
    """
    Test suite for stopping_game_pomdp_defender_env.py
    """

    @pytest.fixture(autouse=True)
    def setup_env(self) -> None:
        """
        Sets up the configuration of the stopping game

        :return: None
        """
        env_name = "test_env"
        # NOTE(review): the transition tensor is built with L=3 while the config below uses L=2;
        # confirm this mismatch is intentional
        T = StoppingGameUtil.transition_tensor(L=3, p=0)
        O = StoppingGameUtil.observation_space(n=100)
        Z = StoppingGameUtil.observation_tensor(n=100)
        R = np.zeros((2, 3, 3, 3))
        S = StoppingGameUtil.state_space()
        A1 = StoppingGameUtil.defender_actions()
        A2 = StoppingGameUtil.attacker_actions()
        L = 2
        R_INT = 1
        R_COST = 2
        R_SLA = 3
        R_ST = 4
        b1 = StoppingGameUtil.b1()
        save_dir = "save_directory"
        checkpoint_traces_freq = 100
        gamma = 0.9
        compute_beliefs = True
        save_trace = True
        self.config = StoppingGameConfig(
            env_name,
            T,
            O,
            Z,
            R,
            S,
            A1,
            A2,
            L,
            R_INT,
            R_COST,
            R_SLA,
            R_ST,
            b1,
            save_dir,
            checkpoint_traces_freq,
            gamma,
            compute_beliefs,
            save_trace,
        )

    def _defender_pomdp_config(self, attacker_strategy: Policy) -> StoppingGameDefenderPomdpConfig:
        """
        Builds the defender POMDP configuration shared by the tests in this suite

        :param attacker_strategy: the static attacker strategy to embed in the configuration
        :return: a defender POMDP configuration wrapping the stopping game configuration of the fixture
        """
        return StoppingGameDefenderPomdpConfig(
            env_name="test_env",
            stopping_game_config=self.config,
            attacker_strategy=attacker_strategy,
            stopping_game_name="csle-stopping-game-v1",
        )

    def _make_env(self, attacker_strategy: Policy) -> StoppingGamePomdpDefenderEnv:
        """
        Builds a defender POMDP environment for the given static attacker strategy

        :param attacker_strategy: the static attacker strategy
        :return: the defender POMDP environment
        """
        return StoppingGamePomdpDefenderEnv(config=self._defender_pomdp_config(attacker_strategy))

    def test_init_(self) -> None:
        """
        Tests the initializing function

        :return: None
        """
        # Mock the attacker strategy
        attacker_strategy = MagicMock(spec=Policy)
        # Create the defender POMDP configuration
        defender_pomdp_config = self._defender_pomdp_config(attacker_strategy)
        # Initialize the StoppingGamePomdpDefenderEnv
        env = StoppingGamePomdpDefenderEnv(config=defender_pomdp_config)
        assert env.config == defender_pomdp_config
        assert env.observation_space == self.config.defender_observation_space()
        assert env.action_space == self.config.defender_action_space()
        assert env.static_attacker_strategy == attacker_strategy
        assert not env.viewer

    def test_reset(self) -> None:
        """
        Tests the function for resetting the environment state

        :return: None
        """
        env = self._make_env(MagicMock(spec=Policy))
        _, info = env.reset()
        assert info

    def test_render(self) -> None:
        """
        Tests the function for rendering the environment

        :return: None
        """
        env = self._make_env(MagicMock(spec=Policy))
        with pytest.raises(NotImplementedError):
            env.render("human")

    def test_is_defense_action_legal(self) -> None:
        """
        Tests the function of checking whether a defender action in the environment is legal or not

        :return: None
        """
        env = self._make_env(MagicMock(spec=Policy))
        assert env.is_defense_action_legal(1)

    def test_is_attack_action_legal(self) -> None:
        """
        Tests the function of checking whether an attacker action in the environment is legal or not

        :return: None
        """
        env = self._make_env(MagicMock(spec=Policy))
        assert env.is_attack_action_legal(1)

    def test_get_traces(self) -> None:
        """
        Tests the function of getting the list of simulation traces

        :return: None
        """
        env = self._make_env(MagicMock(spec=Policy))
        assert env.get_traces() == StoppingGameEnv(self.config).traces

    def test_reset_traces(self) -> None:
        """
        Tests the function of resetting the list of traces

        :return: None
        """
        env = self._make_env(MagicMock(spec=Policy))
        env.traces = ["trace1", "trace2"]
        env.reset_traces()
        # Check the environment under test rather than a freshly constructed StoppingGameEnv,
        # whose trace list is empty regardless of what reset_traces() does.
        assert env.get_traces() == []

    def test_set_model(self) -> None:
        """
        Tests the function for setting the model

        :return: None
        """
        env = self._make_env(MagicMock(spec=Policy))
        mock_model = MagicMock()
        env.set_model(mock_model)
        assert env.model == mock_model

    def test_set_state(self) -> None:
        """
        Tests the function for setting the state

        :return: None
        """
        env = self._make_env(MagicMock(spec=Policy))
        assert env.set_state(1) is None  # type: ignore

    def test_get_observation_from_history(self) -> None:
        """
        Tests the function for getting a defender observation (belief) from a history

        :return: None
        """
        env = self._make_env(MagicMock(spec=Policy))
        history = [1, 2, 3]
        pi2 = env.static_attacker_strategy.stage_policy(o=0)
        observation = env.get_observation_from_history(history)
        expected = StoppingGameEnv(self.config).get_observation_from_history(history, pi2, self.config.L)
        assert observation == expected

    def test_is_state_terminal(self) -> None:
        """
        Tests the function for checking whether a state is terminal or not

        :return: None
        """
        env = self._make_env(MagicMock(spec=Policy))
        assert env.is_state_terminal(1) == StoppingGameEnv(self.config).is_state_terminal(1)

    def test_generate_random_particles(self) -> None:
        """
        Tests the function of generating a random list of state particles from a given observation

        :return: None
        """
        env = self._make_env(MagicMock(spec=Policy))
        num_particles = 10
        particles = env.generate_random_particles(o=1, num_particles=num_particles)
        assert len(particles) == num_particles
        assert all(p in [0, 1] for p in particles)

        # Requesting zero particles must yield an empty result
        num_particles = 0
        particles = env.generate_random_particles(o=1, num_particles=num_particles)
        assert len(particles) == num_particles

    def test_get_actions_from_particles(self) -> None:
        """
        Tests the function for pruning the set of actions based on the current particle set

        :return: None
        """
        env = self._make_env(MagicMock(spec=Policy))
        particles = [1, 2, 3]
        t = 0
        observation = 0
        expected_actions = [0, 1]
        assert env.get_actions_from_particles(particles, t, observation) == expected_actions

    def test_step(self) -> None:
        """
        Tests the function for taking a step in the environment by executing the given action

        :return: None
        """
        # Row i of the stage strategy holds the attacker's action distribution in state i;
        # rows 0-2 all put probability 0.9 on action 0 and 0.1 on action 1.
        attacker_stage_strategy = np.zeros((3, 2))
        attacker_stage_strategy[0][0] = 0.9
        attacker_stage_strategy[0][1] = 0.1
        attacker_stage_strategy[1][0] = 0.9
        attacker_stage_strategy[1][1] = 0.1
        attacker_stage_strategy[2] = attacker_stage_strategy[1]
        attacker_strategy = RandomPolicy(actions=list(self.config.A2), player_type=PlayerType.ATTACKER,
                                         stage_policy_tensor=list(attacker_stage_strategy))
        env = self._make_env(attacker_strategy)
        a1 = 1
        env.reset()
        defender_obs, reward, terminated, truncated, info = env.step(a1)
        assert len(defender_obs) == 2
        assert isinstance(defender_obs[0], float)  # type: ignore
        assert isinstance(defender_obs[1], float)  # type: ignore
        assert isinstance(reward, float)  # type: ignore
        assert isinstance(terminated, bool)  # type: ignore
        assert isinstance(truncated, bool)  # type: ignore
        assert isinstance(info, dict)  # type: ignore
@@ -1 +0,0 @@
1
- __version__ = '0.6.1'