gym-csle-stopping-game 0.6.1__tar.gz → 0.6.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of gym-csle-stopping-game might be problematic. Click here for more details.

Files changed (31)
  1. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/PKG-INFO +1 -1
  2. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/setup.cfg +5 -5
  3. gym_csle_stopping_game-0.6.3/src/gym_csle_stopping_game/__version__.py +1 -0
  4. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game/envs/stopping_game_env.py +4 -107
  5. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game/envs/stopping_game_pomdp_defender_env.py +0 -27
  6. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game.egg-info/PKG-INFO +1 -1
  7. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game.egg-info/SOURCES.txt +3 -0
  8. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game.egg-info/requires.txt +5 -5
  9. gym_csle_stopping_game-0.6.3/tests/test_stopping_game_env.py +420 -0
  10. gym_csle_stopping_game-0.6.3/tests/test_stopping_game_mdp_attacker_env.py +343 -0
  11. gym_csle_stopping_game-0.6.3/tests/test_stopping_game_pomdp_defender_env.py +336 -0
  12. gym_csle_stopping_game-0.6.1/src/gym_csle_stopping_game/__version__.py +0 -1
  13. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/pyproject.toml +0 -0
  14. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/setup.py +0 -0
  15. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game/__init__.py +0 -0
  16. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game/constants/__init__.py +0 -0
  17. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game/constants/constants.py +0 -0
  18. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game/dao/__init__.py +0 -0
  19. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game/dao/stopping_game_attacker_mdp_config.py +0 -0
  20. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game/dao/stopping_game_config.py +0 -0
  21. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game/dao/stopping_game_defender_pomdp_config.py +0 -0
  22. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game/dao/stopping_game_state.py +0 -0
  23. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game/envs/__init__.py +0 -0
  24. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game/envs/stopping_game_mdp_attacker_env.py +0 -0
  25. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game/util/__init__.py +0 -0
  26. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game/util/stopping_game_util.py +0 -0
  27. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game.egg-info/dependency_links.txt +0 -0
  28. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game.egg-info/not-zip-safe +0 -0
  29. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/src/gym_csle_stopping_game.egg-info/top_level.txt +0 -0
  30. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/tests/test_stopping_game_dao.py +0 -0
  31. {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.3}/tests/test_stopping_game_util.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: gym_csle_stopping_game
3
- Version: 0.6.1
3
+ Version: 0.6.3
4
4
  Summary: OpenAI gym reinforcement learning environment of a Dynkin (Optimal stopping) game in CSLE
5
5
  Author: Kim Hammar
6
6
  Author-email: hammar.kim@gmail.com
@@ -20,11 +20,11 @@ classifiers =
20
20
  [options]
21
21
  install_requires =
22
22
  gymnasium>=0.27.1
23
- csle-base>=0.6.1
24
- csle-common>=0.6.1
25
- csle-attacker>=0.6.1
26
- csle-defender>=0.6.1
27
- csle-collector>=0.6.1
23
+ csle-base>=0.6.3
24
+ csle-common>=0.6.3
25
+ csle-attacker>=0.6.3
26
+ csle-defender>=0.6.3
27
+ csle-collector>=0.6.3
28
28
  python_requires = >=3.8
29
29
  package_dir =
30
30
  =src
@@ -0,0 +1 @@
1
+ __version__ = '0.6.3'
@@ -7,24 +7,10 @@ import math
7
7
  import csle_common.constants.constants as constants
8
8
  from csle_common.dao.simulation_config.base_env import BaseEnv
9
9
  from csle_common.dao.simulation_config.simulation_trace import SimulationTrace
10
- from csle_common.dao.training.policy import Policy
11
- from csle_common.dao.emulation_config.emulation_env_state import EmulationEnvState
12
- from csle_common.dao.emulation_config.emulation_env_config import EmulationEnvConfig
13
- from csle_common.dao.simulation_config.simulation_env_config import SimulationEnvConfig
14
- from csle_common.dao.emulation_config.emulation_simulation_trace import EmulationSimulationTrace
15
- from csle_common.dao.emulation_action.attacker.emulation_attacker_stopping_actions \
16
- import EmulationAttackerStoppingActions
17
- from csle_common.dao.emulation_action.attacker.emulation_attacker_action import EmulationAttackerAction
18
- from csle_common.dao.emulation_action.defender.emulation_defender_stopping_actions \
19
- import EmulationDefenderStoppingActions
20
- from csle_common.metastore.metastore_facade import MetastoreFacade
21
- from csle_common.logging.log import Logger
22
- from csle_system_identification.emulator import Emulator
23
10
  from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
24
11
  from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
25
12
  from gym_csle_stopping_game.dao.stopping_game_state import StoppingGameState
26
13
  import gym_csle_stopping_game.constants.constants as env_constants
27
- from csle_common.dao.emulation_config.emulation_trace import EmulationTrace
28
14
 
29
15
 
30
16
  class StoppingGameEnv(BaseEnv):
@@ -42,7 +28,6 @@ class StoppingGameEnv(BaseEnv):
42
28
 
43
29
  # Initialize environment state
44
30
  self.state = StoppingGameState(b1=self.config.b1, L=self.config.L)
45
-
46
31
  # Setup spaces
47
32
  self.attacker_observation_space = self.config.attacker_observation_space()
48
33
  self.defender_observation_space = self.config.defender_observation_space()
@@ -73,7 +58,7 @@ class StoppingGameEnv(BaseEnv):
73
58
  a1, a2_profile = action_profile
74
59
  pi2, a2 = a2_profile
75
60
  assert pi2.shape[0] == len(self.config.S)
76
- assert pi2.shape[1] == len(self.config.A1)
61
+ assert pi2.shape[1] == len(self.config.A2)
77
62
  done = False
78
63
  info: Dict[str, Any] = {}
79
64
 
@@ -84,8 +69,7 @@ class StoppingGameEnv(BaseEnv):
84
69
  else:
85
70
  # Compute r, s', b',o'
86
71
  r = self.config.R[self.state.l - 1][a1][a2][self.state.s]
87
- self.state.s = StoppingGameUtil.sample_next_state(l=self.state.l, a1=a1, a2=a2,
88
- T=self.config.T,
72
+ self.state.s = StoppingGameUtil.sample_next_state(l=self.state.l, a1=a1, a2=a2, T=self.config.T,
89
73
  S=self.config.S, s=self.state.s)
90
74
  o = StoppingGameUtil.sample_next_observation(Z=self.config.Z,
91
75
  O=self.config.O, s_prime=self.state.s)
@@ -246,95 +230,6 @@ class StoppingGameEnv(BaseEnv):
246
230
  info[env_constants.ENV_METRICS.TIME_STEP] = self.state.t
247
231
  return (defender_obs, attacker_obs), info
248
232
 
249
- @staticmethod
250
- def emulation_evaluation(env: "StoppingGameEnv", n_episodes: int, intrusion_seq: List[EmulationAttackerAction],
251
- defender_policy: Policy,
252
- attacker_policy: Policy,
253
- emulation_env_config: EmulationEnvConfig,
254
- simulation_env_config: SimulationEnvConfig
255
- ) -> List[EmulationSimulationTrace]:
256
- """
257
- Utility function for evaluating a strategy profile in the emulation environment
258
-
259
- :param env: the environment to use for evaluation
260
- :param n_episodes: the number of evaluation episodes
261
- :param intrusion_seq: the intrusion sequence for the evaluation (sequence of attacker actions)
262
- :param defender_policy: the defender policy for the evaluation
263
- :param attacker_policy: the attacker policy for the evaluation
264
- :param emulation_env_config: configuration of the emulation environment for the evaluation
265
- :param simulation_env_config: configuration of the simulation environment for the evaluation
266
- :return: traces with the evaluation results
267
- """
268
- logger = Logger.__call__().get_logger()
269
- traces = []
270
- s = EmulationEnvState(emulation_env_config=emulation_env_config)
271
- s.initialize_defender_machines()
272
- for i in range(n_episodes):
273
- done = False
274
- defender_obs_space = simulation_env_config.joint_observation_space_config.observation_spaces[0]
275
- b = env.state.b1
276
- o, _ = env.reset()
277
- (d_obs, a_obs) = o
278
- t = 0
279
- s.reset()
280
- emulation_trace = EmulationTrace(initial_attacker_observation_state=s.attacker_obs_state,
281
- initial_defender_observation_state=s.defender_obs_state,
282
- emulation_name=emulation_env_config.name)
283
- simulation_trace = SimulationTrace(simulation_env=env.config.env_name)
284
- while not done:
285
- a1 = defender_policy.action(d_obs)
286
- a2 = attacker_policy.action(a_obs)
287
- o, r, done, info, _ = env.step((a1, a2))
288
- (d_obs, a_obs) = o
289
- r_1, r_2 = r
290
- logger.debug(f"a1:{a1}, a2:{a2}, d_obs:{d_obs}, a_obs:{a_obs}, r:{r}, done:{done}, info: {info}")
291
- if a1 == 0:
292
- defender_action = EmulationDefenderStoppingActions.CONTINUE(index=-1)
293
- else:
294
- defender_action = EmulationDefenderStoppingActions.CONTINUE(index=-1)
295
- if env.state.s == 1:
296
- if t >= len(intrusion_seq):
297
- t = 0
298
- attacker_action = intrusion_seq[t]
299
- else:
300
- attacker_action = EmulationAttackerStoppingActions.CONTINUE(index=-1)
301
- emulation_trace, s = Emulator.run_actions(
302
- s=s,
303
- emulation_env_config=emulation_env_config, attacker_action=attacker_action,
304
- defender_action=defender_action, trace=emulation_trace,
305
- sleep_time=emulation_env_config.kafka_config.time_step_len_seconds)
306
- o_components = [s.defender_obs_state.snort_ids_alert_counters.severe_alerts,
307
- s.defender_obs_state.snort_ids_alert_counters.warning_alerts,
308
- s.defender_obs_state.aggregated_host_metrics.num_failed_login_attempts]
309
- o_components_str = ",".join(list(map(lambda x: str(x), o_components)))
310
- logger.debug(f"o_components:{o_components}")
311
- logger.debug(f"observation_id_to_observation_vector_inv:"
312
- f"{defender_obs_space.observation_id_to_observation_vector_inv}")
313
- logger.debug(f"observation_id_to_observation_vector_inv:"
314
- f"{o_components_str in defender_obs_space.observation_id_to_observation_vector_inv}")
315
- emulation_o = 0
316
- if o_components_str in defender_obs_space.observation_id_to_observation_vector_inv:
317
- emulation_o = defender_obs_space.observation_id_to_observation_vector_inv[o_components_str]
318
- logger.debug(f"o:{emulation_o}")
319
- b = StoppingGameUtil.next_belief(o=emulation_o, a1=a1, b=b, pi2=a2, config=env.config,
320
- l=env.state.l, a2=a2)
321
- d_obs[1] = b[1]
322
- a_obs[1] = b[1]
323
- logger.debug(f"b:{b}")
324
- simulation_trace.defender_rewards.append(r_1)
325
- simulation_trace.attacker_rewards.append(r_2)
326
- simulation_trace.attacker_actions.append(a2)
327
- simulation_trace.defender_actions.append(a1)
328
- simulation_trace.infos.append(info)
329
- simulation_trace.states.append(s)
330
- simulation_trace.beliefs.append(b[1])
331
- simulation_trace.infrastructure_metrics.append(emulation_o)
332
-
333
- em_sim_trace = EmulationSimulationTrace(emulation_trace=emulation_trace, simulation_trace=simulation_trace)
334
- MetastoreFacade.save_emulation_simulation_trace(em_sim_trace)
335
- traces.append(em_sim_trace)
336
- return traces
337
-
338
233
  def render(self, mode: str = 'human'):
339
234
  """
340
235
  Renders the environment. Supported rendering modes: (1) human; and (2) rgb_array
@@ -437,6 +332,8 @@ class StoppingGameEnv(BaseEnv):
437
332
  :param l: the number of stops remaining
438
333
  :return: the observation
439
334
  """
335
+ if not history:
336
+ raise ValueError("History must not be empty")
440
337
  return [history[-1]]
441
338
 
442
339
  def generate_random_particles(self, o: int, num_particles: int) -> List[int]:
@@ -4,12 +4,7 @@ import numpy.typing as npt
4
4
  from csle_common.dao.simulation_config.base_env import BaseEnv
5
5
  from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import StoppingGameDefenderPomdpConfig
6
6
  from csle_common.dao.simulation_config.simulation_trace import SimulationTrace
7
- from csle_common.dao.training.policy import Policy
8
- from csle_common.dao.emulation_config.emulation_env_config import EmulationEnvConfig
9
- from csle_common.dao.simulation_config.simulation_env_config import SimulationEnvConfig
10
- from csle_common.dao.emulation_config.emulation_simulation_trace import EmulationSimulationTrace
11
7
  from csle_common.dao.emulation_config.emulation_trace import EmulationTrace
12
- from csle_common.dao.emulation_action.attacker.emulation_attacker_action import EmulationAttackerAction
13
8
  from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
14
9
  from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
15
10
 
@@ -103,28 +98,6 @@ class StoppingGamePomdpDefenderEnv(BaseEnv):
103
98
  defender_obs = o[0]
104
99
  return defender_obs, r[0], d, info
105
100
 
106
- @staticmethod
107
- def emulation_evaluation(env: "StoppingGamePomdpDefenderEnv",
108
- n_episodes: int, intrusion_seq: List[EmulationAttackerAction],
109
- defender_policy: Policy,
110
- emulation_env_config: EmulationEnvConfig, simulation_env_config: SimulationEnvConfig) \
111
- -> List[EmulationSimulationTrace]:
112
- """
113
- Utility function for evaluating policies in the emulation environment
114
-
115
- :param env: the environment to use for evaluation
116
- :param n_episodes: the number of episodes to use for evaluation
117
- :param intrusion_seq: the sequence of intrusion actions to use for evaluation
118
- :param defender_policy: the defender policy to use for evaluation
119
- :param emulation_env_config: the configuration of the emulation environment to use for evaluation
120
- :param simulation_env_config: the configuration of the simulation environment to use for evaluation
121
- :return: traces with the evaluation results
122
- """
123
- return StoppingGameEnv.emulation_evaluation(
124
- env=env.stopping_game_env, n_episodes=n_episodes, intrusion_seq=intrusion_seq,
125
- defender_policy=defender_policy, attacker_policy=env.static_attacker_strategy,
126
- emulation_env_config=emulation_env_config, simulation_env_config=simulation_env_config)
127
-
128
101
  def is_defense_action_legal(self, defense_action_id: int) -> bool:
129
102
  """
130
103
  Checks whether a defender action in the environment is legal or not
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: gym-csle-stopping-game
3
- Version: 0.6.1
3
+ Version: 0.6.3
4
4
  Summary: OpenAI gym reinforcement learning environment of a Dynkin (Optimal stopping) game in CSLE
5
5
  Author: Kim Hammar
6
6
  Author-email: hammar.kim@gmail.com
@@ -23,4 +23,7 @@ src/gym_csle_stopping_game/envs/stopping_game_pomdp_defender_env.py
23
23
  src/gym_csle_stopping_game/util/__init__.py
24
24
  src/gym_csle_stopping_game/util/stopping_game_util.py
25
25
  tests/test_stopping_game_dao.py
26
+ tests/test_stopping_game_env.py
27
+ tests/test_stopping_game_mdp_attacker_env.py
28
+ tests/test_stopping_game_pomdp_defender_env.py
26
29
  tests/test_stopping_game_util.py
@@ -1,9 +1,9 @@
1
1
  gymnasium>=0.27.1
2
- csle-base>=0.6.1
3
- csle-common>=0.6.1
4
- csle-attacker>=0.6.1
5
- csle-defender>=0.6.1
6
- csle-collector>=0.6.1
2
+ csle-base>=0.6.3
3
+ csle-common>=0.6.3
4
+ csle-attacker>=0.6.3
5
+ csle-defender>=0.6.3
6
+ csle-collector>=0.6.3
7
7
 
8
8
  [testing]
9
9
  pytest>=6.0
@@ -0,0 +1,420 @@
1
+ from typing import Dict, Any
2
+ import pytest
3
+ from unittest.mock import patch, MagicMock
4
+ from gymnasium.spaces import Box, Discrete
5
+ import numpy as np
6
+ from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
7
+ from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
8
+ from gym_csle_stopping_game.dao.stopping_game_state import StoppingGameState
9
+ import gym_csle_stopping_game.constants.constants as env_constants
10
+ from csle_common.constants import constants
11
+
12
+
13
+ class TestStoppingGameEnvSuite:
14
+ """
15
+ Test suite for stopping_game_env.py
16
+ """
17
+
18
+ @pytest.fixture(autouse=True)
19
+ def setup_env(self) -> None:
20
+ """
21
+ Sets up the configuration of the stopping game
22
+
23
+ :return: None
24
+ """
25
+ env_name = "test_env"
26
+ T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
27
+ O = np.array([0, 1])
28
+ Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]])
29
+ R = np.zeros((2, 3, 3, 3))
30
+ S = np.array([0, 1, 2])
31
+ A1 = np.array([0, 1, 2])
32
+ A2 = np.array([0, 1, 2])
33
+ L = 2
34
+ R_INT = 1
35
+ R_COST = 2
36
+ R_SLA = 3
37
+ R_ST = 4
38
+ b1 = np.array([0.6, 0.4])
39
+ save_dir = "save_directory"
40
+ checkpoint_traces_freq = 100
41
+ gamma = 0.9
42
+ compute_beliefs = True
43
+ save_trace = True
44
+ self.config = StoppingGameConfig(
45
+ env_name,
46
+ T,
47
+ O,
48
+ Z,
49
+ R,
50
+ S,
51
+ A1,
52
+ A2,
53
+ L,
54
+ R_INT,
55
+ R_COST,
56
+ R_SLA,
57
+ R_ST,
58
+ b1,
59
+ save_dir,
60
+ checkpoint_traces_freq,
61
+ gamma,
62
+ compute_beliefs,
63
+ save_trace,
64
+ )
65
+
66
+ def test_stopping_game_init_(self) -> None:
67
+ """
68
+ Tests the initializing function
69
+
70
+ :return: None
71
+ """
72
+ T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
73
+ O = np.array([0, 1])
74
+ A1 = np.array([0, 1, 2])
75
+ A2 = np.array([0, 1, 2])
76
+ L = 2
77
+ b1 = np.array([0.6, 0.4])
78
+ attacker_observation_space = Box(
79
+ low=np.array([0.0, 0.0, 0.0]),
80
+ high=np.array([float(L), 1.0, 2.0]),
81
+ dtype=np.float64,
82
+ )
83
+ defender_observation_space = Box(
84
+ low=np.array([0.0, 0.0]),
85
+ high=np.array([float(L), 1.0]),
86
+ dtype=np.float64,
87
+ )
88
+ attacker_action_space = Discrete(len(A2))
89
+ defender_action_space = Discrete(len(A1))
90
+
91
+ assert self.config.T.any() == T.any()
92
+ assert self.config.O.any() == O.any()
93
+ assert self.config.b1.any() == b1.any()
94
+ assert self.config.L == L
95
+
96
+ env = StoppingGameEnv(self.config)
97
+ assert env.config == self.config
98
+ assert env.attacker_observation_space.low.any() == attacker_observation_space.low.any()
99
+ assert env.defender_observation_space.low.any() == defender_observation_space.low.any()
100
+ assert env.attacker_action_space.n == attacker_action_space.n
101
+ assert env.defender_action_space.n == defender_action_space.n
102
+ assert env.traces == []
103
+
104
+ with patch("gym_csle_stopping_game.dao.stopping_game_state.StoppingGameState") as MockStoppingGameState:
105
+ MockStoppingGameState(b1=self.config.b1, L=self.config.L)
106
+ with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_initial_state"
107
+ ) as MockSampleInitialState:
108
+ MockSampleInitialState.return_value = 0
109
+ StoppingGameEnv(self.config)
110
+ MockSampleInitialState.assert_called()
111
+ MockStoppingGameState.assert_called_once_with(b1=self.config.b1, L=self.config.L)
112
+
113
+ with patch("csle_common.dao.simulation_config.simulation_trace.SimulationTrace") as MockSimulationTrace:
114
+ MockSimulationTrace(self.config.env_name).return_value
115
+ StoppingGameEnv(self.config)
116
+ MockSimulationTrace.assert_called_once_with(self.config.env_name)
117
+
118
+ def test_mean(self) -> None:
119
+ """
120
+ Tests the utility function for getting the mean of a vector
121
+
122
+ :return: None
123
+ """
124
+ test_cases = [
125
+ ([], 0), # Test case for an empty vector
126
+ ([5], 0), # Test case for a vector with a single element
127
+ ([0.2, 0.3, 0.5], 1.3), # Test case for a vector with multiple elements
128
+ ]
129
+ for prob_vector, expected_mean in test_cases:
130
+ result = StoppingGameEnv(self.config).mean(prob_vector)
131
+ assert result == expected_mean
132
+
133
+ def test_weighted_intrusion_prediction_distance(self) -> None:
134
+ """
135
+ Tests the function of computing the weighed intrusion start time prediction distance
136
+ """
137
+ # Test case when first_stop is before intrusion_start
138
+ result1 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance(5, 3)
139
+ assert result1 == 0
140
+
141
+ # Test case when first_stop is after intrusion_start
142
+ result2 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance(3, 5)
143
+ assert result2 == 0.95
144
+
145
+ # Test case when first_stop is equal to intrusion_start
146
+ result3 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance(3, 3)
147
+ assert result3 == 0
148
+
149
+ def test_reset(self) -> None:
150
+ """
151
+ Tests the reset function for reseting the environment state
152
+
153
+ :return: None
154
+ """
155
+ env = StoppingGameEnv(self.config)
156
+ env.state = MagicMock()
157
+ env.state.l = 10
158
+ env.state.s = "initial_state"
159
+ env.state.t = 0
160
+ env.state.attacker_observation.return_value = np.array([1, 2, 3])
161
+ env.state.defender_observation.return_value = np.array([4, 5, 6])
162
+
163
+ env.trace = MagicMock()
164
+ env.trace.attacker_rewards = [1]
165
+ env.traces = []
166
+ # Call the reset method
167
+ observation, info = env.reset()
168
+ # Assertions
169
+ assert env.state.reset.called, "State's reset method was not called."
170
+ assert env.trace.simulation_env == self.config.env_name, "Trace was not initialized correctly."
171
+ assert observation[0].all() == np.array([4, 5, 6]).all(), "Observation does not match expected values."
172
+ assert info[env_constants.ENV_METRICS.STOPS_REMAINING] == env.state.l, \
173
+ "Stops remaining does not match expected value."
174
+ assert info[env_constants.ENV_METRICS.STATE] == env.state.s, "State info does not match expected value."
175
+ assert info[env_constants.ENV_METRICS.OBSERVATION] == 0, "Observation info does not match expected value."
176
+ assert info[env_constants.ENV_METRICS.TIME_STEP] == env.state.t, "Time step info does not match expected value."
177
+
178
+ # Check if trace was appended correctly
179
+ if len(env.trace.attacker_rewards) > 0:
180
+ assert env.traces[-1] == env.trace, "Trace was not appended correctly."
181
+
182
+ def test_render(self) -> None:
183
+ """
184
+ Tests the function of rendering the environment
185
+
186
+ :return: None
187
+ """
188
+ with pytest.raises(NotImplementedError):
189
+ StoppingGameEnv(self.config).render()
190
+
191
+ def test_is_defense_action_legal(self) -> None:
192
+ """
193
+ Tests the function of checking whether a defender action in the environment is legal or not
194
+
195
+ :return: None
196
+ """
197
+ assert StoppingGameEnv(self.config).is_defense_action_legal(1)
198
+
199
+ def test_is_attack_action_legal(self) -> None:
200
+ """
201
+ Tests the function of checking whether an attacker action in the environment is legal or not
202
+
203
+ :return: None
204
+ """
205
+ assert StoppingGameEnv(self.config).is_attack_action_legal(1)
206
+
207
+ def test_get_traces(self) -> None:
208
+ """
209
+ Tests the function of getting the list of simulation traces
210
+
211
+ :return: None
212
+ """
213
+ assert StoppingGameEnv(self.config).get_traces() == StoppingGameEnv(self.config).traces
214
+
215
+ def test_reset_traces(self) -> None:
216
+ """
217
+ Tests the function of resetting the list of traces
218
+
219
+ :return: None
220
+ """
221
+ env = StoppingGameEnv(self.config)
222
+ env.traces = ["trace1", "trace2"]
223
+ env.reset_traces()
224
+ assert env.traces == []
225
+
226
+ def test_checkpoint_traces(self) -> None:
227
+ """
228
+ Tests the function of checkpointing agent traces
229
+
230
+ :return: None
231
+ """
232
+ env = StoppingGameEnv(self.config)
233
+ fixed_timestamp = 123
234
+ with patch("time.time", return_value=fixed_timestamp):
235
+ with patch(
236
+ "csle_common.dao.simulation_config.simulation_trace.SimulationTrace.save_traces"
237
+ ) as mock_save_traces:
238
+ env.traces = ["trace1", "trace2"]
239
+ env._StoppingGameEnv__checkpoint_traces()
240
+ mock_save_traces.assert_called_once_with(
241
+ traces_save_dir=constants.LOGGING.DEFAULT_LOG_DIR,
242
+ traces=env.traces,
243
+ traces_file=f"taus{fixed_timestamp}.json",
244
+ )
245
+
246
+ def test_set_model(self) -> None:
247
+ """
248
+ Tests the function of setting the model
249
+
250
+ :return: None
251
+ """
252
+ env = StoppingGameEnv(self.config)
253
+ mock_model = MagicMock()
254
+ env.set_model(mock_model)
255
+ assert env.model == mock_model
256
+
257
+ def test_set_state(self) -> None:
258
+ """
259
+ Tests the function of setting the state
260
+
261
+ :return: None
262
+ """
263
+ env = StoppingGameEnv(self.config)
264
+ env.state = MagicMock()
265
+
266
+ mock_state = MagicMock(spec=StoppingGameState)
267
+ env.set_state(mock_state)
268
+ assert env.state == mock_state
269
+
270
+ state_int = 5
271
+ env.set_state(state_int)
272
+ assert env.state.s == state_int
273
+ assert env.state.l == self.config.L
274
+
275
+ state_tuple = (3, 7)
276
+ env.set_state(state_tuple)
277
+ assert env.state.s == state_tuple[0]
278
+ assert env.state.l == state_tuple[1]
279
+
280
+ with pytest.raises(ValueError):
281
+ env.set_state([1, 2, 3]) # type: ignore
282
+
283
+ def test_is_state_terminal(self) -> None:
284
+ """
285
+ Tests the function of checking whether a given state is terminal or not
286
+
287
+ :return: None
288
+ """
289
+ env = StoppingGameEnv(self.config)
290
+ env.state = MagicMock()
291
+
292
+ mock_state = MagicMock(spec=StoppingGameState)
293
+ mock_state.s = 2
294
+ assert env.is_state_terminal(mock_state)
295
+ mock_state.s = 1
296
+ assert not env.is_state_terminal(mock_state)
297
+ state_int = 2
298
+ assert env.is_state_terminal(state_int)
299
+ state_int = 1
300
+ assert not env.is_state_terminal(state_int)
301
+ state_tuple = (2, 5)
302
+ assert env.is_state_terminal(state_tuple)
303
+ state_tuple = (1, 5)
304
+ assert not env.is_state_terminal(state_tuple)
305
+
306
+ with pytest.raises(ValueError):
307
+ env.is_state_terminal([1, 2, 3]) # type: ignore
308
+
309
+ def test_get_observation_from_history(self) -> None:
310
+ """
311
+ Tests the function of getting a hidden observation based on a history
312
+
313
+ :return: None
314
+ """
315
+ env = StoppingGameEnv(self.config)
316
+ history = [1, 2, 3, 4, 5]
317
+ pi2 = np.array([0.1, 0.9])
318
+ l = 3
319
+ observation = env.get_observation_from_history(history, pi2, l)
320
+ assert observation == [5]
321
+
322
+ history = []
323
+ with pytest.raises(ValueError, match="History must not be empty"):
324
+ env.get_observation_from_history(history, pi2, l)
325
+
326
+ def test_generate_random_particles(self) -> None:
327
+ """
328
+ Tests the funtion of generating a random list of state particles from a given observation
329
+
330
+ :return: None
331
+ """
332
+ env = StoppingGameEnv(self.config)
333
+ num_particles = 10
334
+ particles = env.generate_random_particles(o=1, num_particles=num_particles)
335
+ assert len(particles) == num_particles
336
+ assert all(p in [0, 1] for p in particles)
337
+
338
+ num_particles = 0
339
+ particles = env.generate_random_particles(o=1, num_particles=num_particles)
340
+ assert len(particles) == num_particles
341
+
342
+ def test_step(self) -> None:
343
+ """
344
+ Tests the funtion of taking a step in the environment by executing the given action
345
+
346
+ :return: None
347
+ """
348
+ env = StoppingGameEnv(self.config)
349
+ env.state = MagicMock()
350
+ env.state.s = 1
351
+ env.state.l = 2
352
+ env.state.t = 0
353
+ env.state.attacker_observation.return_value = np.array([1, 2, 3])
354
+ env.state.defender_observation.return_value = np.array([4, 5, 6])
355
+ env.state.b = np.array([0.5, 0.5, 0.0])
356
+
357
+ env.trace = MagicMock()
358
+ env.trace.defender_rewards = []
359
+ env.trace.attacker_rewards = []
360
+ env.trace.attacker_actions = []
361
+ env.trace.defender_actions = []
362
+ env.trace.infos = []
363
+ env.trace.states = []
364
+ env.trace.beliefs = []
365
+ env.trace.infrastructure_metrics = []
366
+ env.trace.attacker_observations = []
367
+ env.trace.defender_observations = []
368
+
369
+ with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_state",
370
+ return_value=2):
371
+ with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_observation",
372
+ return_value=1):
373
+ with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.next_belief",
374
+ return_value=np.array([0.3, 0.7, 0.0])):
375
+ action_profile = (
376
+ 1,
377
+ (
378
+ np.array(
379
+ [[0.2, 0.8, 0.0], [0.6, 0.4, 0.0], [0.5, 0.5, 0.0]]
380
+ ),
381
+ 2,
382
+ ),
383
+ )
384
+ observations, rewards, terminated, truncated, info = env.step(
385
+ action_profile
386
+ )
387
+
388
+ assert (observations[0] == np.array([4, 5, 6])).all(), "Incorrect defender observations"
389
+ assert (observations[1] == np.array([1, 2, 3])).all(), "Incorrect attacker observations"
390
+ assert rewards == (0, 0)
391
+ assert not terminated
392
+ assert not truncated
393
+ assert env.trace.defender_rewards[-1] == 0
394
+ assert env.trace.attacker_rewards[-1] == 0
395
+ assert env.trace.attacker_actions[-1] == 2
396
+ assert env.trace.defender_actions[-1] == 1
397
+ assert env.trace.infos[-1] == info
398
+ assert env.trace.states[-1] == 2
399
+ print(env.trace.beliefs)
400
+ assert env.trace.beliefs[-1] == 0.7
401
+ assert env.trace.infrastructure_metrics[-1] == 1
402
+ assert (env.trace.attacker_observations[-1] == np.array([1, 2, 3])).all()
403
+ assert (env.trace.defender_observations[-1] == np.array([4, 5, 6])).all()
404
+
405
+ def test_info(self) -> None:
406
+ """
407
+ Tests the function of adding the cumulative reward and episode length to the info dict
408
+
409
+ :return: None
410
+ """
411
+ env = StoppingGameEnv(self.config)
412
+ env.trace = MagicMock()
413
+ env.trace.defender_rewards = [1, 2]
414
+ env.trace.attacker_actions = [0, 1]
415
+ env.trace.defender_actions = [0, 1]
416
+ env.trace.states = [0, 1]
417
+ env.trace.infrastructure_metrics = [0, 1]
418
+ info: Dict[str, Any] = {}
419
+ updated_info = env._info(info)
420
+ assert updated_info[env_constants.ENV_METRICS.RETURN] == sum(env.trace.defender_rewards)