gym-csle-stopping-game 0.6.1__tar.gz → 0.6.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of gym-csle-stopping-game might be problematic. Click here for more details.
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/PKG-INFO +1 -1
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/setup.cfg +5 -5
- gym_csle_stopping_game-0.6.2/src/gym_csle_stopping_game/__version__.py +1 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/envs/stopping_game_env.py +4 -4
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game.egg-info/PKG-INFO +1 -1
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game.egg-info/SOURCES.txt +3 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game.egg-info/requires.txt +5 -5
- gym_csle_stopping_game-0.6.2/tests/test_stopping_game_env.py +428 -0
- gym_csle_stopping_game-0.6.2/tests/test_stopping_game_mdp_attacker_env.py +343 -0
- gym_csle_stopping_game-0.6.2/tests/test_stopping_game_pomdp_defender_env.py +336 -0
- gym_csle_stopping_game-0.6.1/src/gym_csle_stopping_game/__version__.py +0 -1
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/pyproject.toml +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/setup.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/__init__.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/constants/__init__.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/constants/constants.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/dao/__init__.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/dao/stopping_game_attacker_mdp_config.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/dao/stopping_game_config.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/dao/stopping_game_defender_pomdp_config.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/dao/stopping_game_state.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/envs/__init__.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/envs/stopping_game_mdp_attacker_env.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/envs/stopping_game_pomdp_defender_env.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/util/__init__.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/util/stopping_game_util.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game.egg-info/dependency_links.txt +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game.egg-info/not-zip-safe +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game.egg-info/top_level.txt +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/tests/test_stopping_game_dao.py +0 -0
- {gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/tests/test_stopping_game_util.py +0 -0
|
@@ -20,11 +20,11 @@ classifiers =
|
|
|
20
20
|
[options]
|
|
21
21
|
install_requires =
|
|
22
22
|
gymnasium>=0.27.1
|
|
23
|
-
csle-base>=0.6.
|
|
24
|
-
csle-common>=0.6.
|
|
25
|
-
csle-attacker>=0.6.
|
|
26
|
-
csle-defender>=0.6.
|
|
27
|
-
csle-collector>=0.6.
|
|
23
|
+
csle-base>=0.6.2
|
|
24
|
+
csle-common>=0.6.2
|
|
25
|
+
csle-attacker>=0.6.2
|
|
26
|
+
csle-defender>=0.6.2
|
|
27
|
+
csle-collector>=0.6.2
|
|
28
28
|
python_requires = >=3.8
|
|
29
29
|
package_dir =
|
|
30
30
|
=src
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = '0.6.2'
|
|
@@ -42,7 +42,6 @@ class StoppingGameEnv(BaseEnv):
|
|
|
42
42
|
|
|
43
43
|
# Initialize environment state
|
|
44
44
|
self.state = StoppingGameState(b1=self.config.b1, L=self.config.L)
|
|
45
|
-
|
|
46
45
|
# Setup spaces
|
|
47
46
|
self.attacker_observation_space = self.config.attacker_observation_space()
|
|
48
47
|
self.defender_observation_space = self.config.defender_observation_space()
|
|
@@ -73,7 +72,7 @@ class StoppingGameEnv(BaseEnv):
|
|
|
73
72
|
a1, a2_profile = action_profile
|
|
74
73
|
pi2, a2 = a2_profile
|
|
75
74
|
assert pi2.shape[0] == len(self.config.S)
|
|
76
|
-
assert pi2.shape[1] == len(self.config.
|
|
75
|
+
assert pi2.shape[1] == len(self.config.A2)
|
|
77
76
|
done = False
|
|
78
77
|
info: Dict[str, Any] = {}
|
|
79
78
|
|
|
@@ -84,8 +83,7 @@ class StoppingGameEnv(BaseEnv):
|
|
|
84
83
|
else:
|
|
85
84
|
# Compute r, s', b',o'
|
|
86
85
|
r = self.config.R[self.state.l - 1][a1][a2][self.state.s]
|
|
87
|
-
self.state.s = StoppingGameUtil.sample_next_state(l=self.state.l, a1=a1, a2=a2,
|
|
88
|
-
T=self.config.T,
|
|
86
|
+
self.state.s = StoppingGameUtil.sample_next_state(l=self.state.l, a1=a1, a2=a2, T=self.config.T,
|
|
89
87
|
S=self.config.S, s=self.state.s)
|
|
90
88
|
o = StoppingGameUtil.sample_next_observation(Z=self.config.Z,
|
|
91
89
|
O=self.config.O, s_prime=self.state.s)
|
|
@@ -437,6 +435,8 @@ class StoppingGameEnv(BaseEnv):
|
|
|
437
435
|
:param l: the number of stops remaining
|
|
438
436
|
:return: the observation
|
|
439
437
|
"""
|
|
438
|
+
if not history:
|
|
439
|
+
raise ValueError("History must not be empty")
|
|
440
440
|
return [history[-1]]
|
|
441
441
|
|
|
442
442
|
def generate_random_particles(self, o: int, num_particles: int) -> List[int]:
|
|
@@ -23,4 +23,7 @@ src/gym_csle_stopping_game/envs/stopping_game_pomdp_defender_env.py
|
|
|
23
23
|
src/gym_csle_stopping_game/util/__init__.py
|
|
24
24
|
src/gym_csle_stopping_game/util/stopping_game_util.py
|
|
25
25
|
tests/test_stopping_game_dao.py
|
|
26
|
+
tests/test_stopping_game_env.py
|
|
27
|
+
tests/test_stopping_game_mdp_attacker_env.py
|
|
28
|
+
tests/test_stopping_game_pomdp_defender_env.py
|
|
26
29
|
tests/test_stopping_game_util.py
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
gymnasium>=0.27.1
|
|
2
|
-
csle-base>=0.6.
|
|
3
|
-
csle-common>=0.6.
|
|
4
|
-
csle-attacker>=0.6.
|
|
5
|
-
csle-defender>=0.6.
|
|
6
|
-
csle-collector>=0.6.
|
|
2
|
+
csle-base>=0.6.2
|
|
3
|
+
csle-common>=0.6.2
|
|
4
|
+
csle-attacker>=0.6.2
|
|
5
|
+
csle-defender>=0.6.2
|
|
6
|
+
csle-collector>=0.6.2
|
|
7
7
|
|
|
8
8
|
[testing]
|
|
9
9
|
pytest>=6.0
|
|
@@ -0,0 +1,428 @@
|
|
|
1
|
+
from typing import Dict, Any
|
|
2
|
+
import pytest
|
|
3
|
+
from unittest.mock import patch, MagicMock
|
|
4
|
+
from gym.spaces import Box, Discrete
|
|
5
|
+
import numpy as np
|
|
6
|
+
from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
|
|
7
|
+
from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
|
|
8
|
+
from gym_csle_stopping_game.dao.stopping_game_state import StoppingGameState
|
|
9
|
+
import gym_csle_stopping_game.constants.constants as env_constants
|
|
10
|
+
from csle_common.constants import constants
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TestStoppingGameEnvSuite:
|
|
14
|
+
"""
|
|
15
|
+
Test suite for stopping_game_env.py
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
@pytest.fixture(autouse=True)
|
|
19
|
+
def setup_env(self) -> None:
|
|
20
|
+
"""
|
|
21
|
+
Sets up the configuration of the stopping game
|
|
22
|
+
|
|
23
|
+
:return: None
|
|
24
|
+
"""
|
|
25
|
+
env_name = "test_env"
|
|
26
|
+
T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
|
|
27
|
+
O = np.array([0, 1])
|
|
28
|
+
Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]])
|
|
29
|
+
R = np.zeros((2, 3, 3, 3))
|
|
30
|
+
S = np.array([0, 1, 2])
|
|
31
|
+
A1 = np.array([0, 1, 2])
|
|
32
|
+
A2 = np.array([0, 1, 2])
|
|
33
|
+
L = 2
|
|
34
|
+
R_INT = 1
|
|
35
|
+
R_COST = 2
|
|
36
|
+
R_SLA = 3
|
|
37
|
+
R_ST = 4
|
|
38
|
+
b1 = np.array([0.6, 0.4])
|
|
39
|
+
save_dir = "save_directory"
|
|
40
|
+
checkpoint_traces_freq = 100
|
|
41
|
+
gamma = 0.9
|
|
42
|
+
compute_beliefs = True
|
|
43
|
+
save_trace = True
|
|
44
|
+
self.config = StoppingGameConfig(
|
|
45
|
+
env_name,
|
|
46
|
+
T,
|
|
47
|
+
O,
|
|
48
|
+
Z,
|
|
49
|
+
R,
|
|
50
|
+
S,
|
|
51
|
+
A1,
|
|
52
|
+
A2,
|
|
53
|
+
L,
|
|
54
|
+
R_INT,
|
|
55
|
+
R_COST,
|
|
56
|
+
R_SLA,
|
|
57
|
+
R_ST,
|
|
58
|
+
b1,
|
|
59
|
+
save_dir,
|
|
60
|
+
checkpoint_traces_freq,
|
|
61
|
+
gamma,
|
|
62
|
+
compute_beliefs,
|
|
63
|
+
save_trace,
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
def test_stopping_game_init_(self) -> None:
|
|
67
|
+
"""
|
|
68
|
+
Tests the initializing function
|
|
69
|
+
|
|
70
|
+
:return: None
|
|
71
|
+
"""
|
|
72
|
+
T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
|
|
73
|
+
O = np.array([0, 1])
|
|
74
|
+
A1 = np.array([0, 1, 2])
|
|
75
|
+
A2 = np.array([0, 1, 2])
|
|
76
|
+
L = 2
|
|
77
|
+
b1 = np.array([0.6, 0.4])
|
|
78
|
+
attacker_observation_space = Box(
|
|
79
|
+
low=np.array([0.0, 0.0, 0.0]),
|
|
80
|
+
high=np.array([float(L), 1.0, 2.0]),
|
|
81
|
+
dtype=np.float64,
|
|
82
|
+
)
|
|
83
|
+
defender_observation_space = Box(
|
|
84
|
+
low=np.array([0.0, 0.0]),
|
|
85
|
+
high=np.array([float(L), 1.0]),
|
|
86
|
+
dtype=np.float64,
|
|
87
|
+
)
|
|
88
|
+
attacker_action_space = Discrete(len(A2))
|
|
89
|
+
defender_action_space = Discrete(len(A1))
|
|
90
|
+
|
|
91
|
+
assert self.config.T.any() == T.any()
|
|
92
|
+
assert self.config.O.any() == O.any()
|
|
93
|
+
assert self.config.b1.any() == b1.any()
|
|
94
|
+
assert self.config.L == L
|
|
95
|
+
|
|
96
|
+
env = StoppingGameEnv(self.config)
|
|
97
|
+
assert env.config == self.config
|
|
98
|
+
assert env.attacker_observation_space.low.any() == attacker_observation_space.low.any()
|
|
99
|
+
assert env.defender_observation_space.low.any() == defender_observation_space.low.any()
|
|
100
|
+
assert env.attacker_action_space.n == attacker_action_space.n
|
|
101
|
+
assert env.defender_action_space.n == defender_action_space.n
|
|
102
|
+
assert env.traces == []
|
|
103
|
+
|
|
104
|
+
with patch("gym_csle_stopping_game.dao.stopping_game_state.StoppingGameState") as MockStoppingGameState:
|
|
105
|
+
MockStoppingGameState(b1=self.config.b1, L=self.config.L)
|
|
106
|
+
with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_initial_state"
|
|
107
|
+
) as MockSampleInitialState:
|
|
108
|
+
MockSampleInitialState.return_value = 0
|
|
109
|
+
StoppingGameEnv(self.config)
|
|
110
|
+
MockSampleInitialState.assert_called()
|
|
111
|
+
MockStoppingGameState.assert_called_once_with(b1=self.config.b1, L=self.config.L)
|
|
112
|
+
|
|
113
|
+
with patch("csle_common.dao.simulation_config.simulation_trace.SimulationTrace") as MockSimulationTrace:
|
|
114
|
+
MockSimulationTrace(self.config.env_name).return_value
|
|
115
|
+
StoppingGameEnv(self.config)
|
|
116
|
+
MockSimulationTrace.assert_called_once_with(self.config.env_name)
|
|
117
|
+
|
|
118
|
+
def test_mean(self) -> None:
|
|
119
|
+
"""
|
|
120
|
+
Tests the utility function for getting the mean of a vector
|
|
121
|
+
|
|
122
|
+
:return: None
|
|
123
|
+
"""
|
|
124
|
+
test_cases = [
|
|
125
|
+
([], 0), # Test case for an empty vector
|
|
126
|
+
([5], 0), # Test case for a vector with a single element
|
|
127
|
+
([0.2, 0.3, 0.5], 1.3), # Test case for a vector with multiple elements
|
|
128
|
+
]
|
|
129
|
+
for prob_vector, expected_mean in test_cases:
|
|
130
|
+
result = StoppingGameEnv(self.config).mean(prob_vector)
|
|
131
|
+
assert result == expected_mean
|
|
132
|
+
|
|
133
|
+
def test_weighted_intrusion_prediction_distance(self) -> None:
|
|
134
|
+
"""
|
|
135
|
+
Tests the function of computing the weighed intrusion start time prediction distance
|
|
136
|
+
"""
|
|
137
|
+
# Test case when first_stop is before intrusion_start
|
|
138
|
+
result1 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance(5, 3)
|
|
139
|
+
assert result1 == 0
|
|
140
|
+
|
|
141
|
+
# Test case when first_stop is after intrusion_start
|
|
142
|
+
result2 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance(3, 5)
|
|
143
|
+
assert result2 == 0.95
|
|
144
|
+
|
|
145
|
+
# Test case when first_stop is equal to intrusion_start
|
|
146
|
+
result3 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance(3, 3)
|
|
147
|
+
assert result3 == 0
|
|
148
|
+
|
|
149
|
+
def test_reset(self) -> None:
|
|
150
|
+
"""
|
|
151
|
+
Tests the reset function for reseting the environment state
|
|
152
|
+
|
|
153
|
+
:return: None
|
|
154
|
+
"""
|
|
155
|
+
env = StoppingGameEnv(self.config)
|
|
156
|
+
env.state = MagicMock()
|
|
157
|
+
env.state.l = 10
|
|
158
|
+
env.state.s = "initial_state"
|
|
159
|
+
env.state.t = 0
|
|
160
|
+
env.state.attacker_observation.return_value = np.array([1, 2, 3])
|
|
161
|
+
env.state.defender_observation.return_value = np.array([4, 5, 6])
|
|
162
|
+
|
|
163
|
+
env.trace = MagicMock()
|
|
164
|
+
env.trace.attacker_rewards = [1]
|
|
165
|
+
env.traces = []
|
|
166
|
+
# Call the reset method
|
|
167
|
+
observation, info = env.reset()
|
|
168
|
+
# Assertions
|
|
169
|
+
assert env.state.reset.called, "State's reset method was not called."
|
|
170
|
+
assert env.trace.simulation_env == self.config.env_name, "Trace was not initialized correctly."
|
|
171
|
+
assert observation[0].all() == np.array([4, 5, 6]).all(), "Observation does not match expected values."
|
|
172
|
+
assert info[env_constants.ENV_METRICS.STOPS_REMAINING] == env.state.l, \
|
|
173
|
+
"Stops remaining does not match expected value."
|
|
174
|
+
assert info[env_constants.ENV_METRICS.STATE] == env.state.s, "State info does not match expected value."
|
|
175
|
+
assert info[env_constants.ENV_METRICS.OBSERVATION] == 0, "Observation info does not match expected value."
|
|
176
|
+
assert info[env_constants.ENV_METRICS.TIME_STEP] == env.state.t, "Time step info does not match expected value."
|
|
177
|
+
|
|
178
|
+
# Check if trace was appended correctly
|
|
179
|
+
if len(env.trace.attacker_rewards) > 0:
|
|
180
|
+
assert env.traces[-1] == env.trace, "Trace was not appended correctly."
|
|
181
|
+
|
|
182
|
+
def test_render(self) -> None:
|
|
183
|
+
"""
|
|
184
|
+
Tests the function of rendering the environment
|
|
185
|
+
|
|
186
|
+
:return: None
|
|
187
|
+
"""
|
|
188
|
+
with pytest.raises(NotImplementedError):
|
|
189
|
+
StoppingGameEnv(self.config).render()
|
|
190
|
+
|
|
191
|
+
def test_is_defense_action_legal(self) -> None:
|
|
192
|
+
"""
|
|
193
|
+
Tests the function of checking whether a defender action in the environment is legal or not
|
|
194
|
+
|
|
195
|
+
:return: None
|
|
196
|
+
"""
|
|
197
|
+
assert StoppingGameEnv(self.config).is_defense_action_legal(1)
|
|
198
|
+
|
|
199
|
+
def test_is_attack_action_legal(self) -> None:
|
|
200
|
+
"""
|
|
201
|
+
Tests the function of checking whether an attacker action in the environment is legal or not
|
|
202
|
+
|
|
203
|
+
:return: None
|
|
204
|
+
"""
|
|
205
|
+
assert StoppingGameEnv(self.config).is_attack_action_legal(1)
|
|
206
|
+
|
|
207
|
+
def test_get_traces(self) -> None:
|
|
208
|
+
"""
|
|
209
|
+
Tests the function of getting the list of simulation traces
|
|
210
|
+
|
|
211
|
+
:return: None
|
|
212
|
+
"""
|
|
213
|
+
assert StoppingGameEnv(self.config).get_traces() == StoppingGameEnv(self.config).traces
|
|
214
|
+
|
|
215
|
+
def test_reset_traces(self) -> None:
|
|
216
|
+
"""
|
|
217
|
+
Tests the function of resetting the list of traces
|
|
218
|
+
|
|
219
|
+
:return: None
|
|
220
|
+
"""
|
|
221
|
+
env = StoppingGameEnv(self.config)
|
|
222
|
+
env.traces = ["trace1", "trace2"]
|
|
223
|
+
env.reset_traces()
|
|
224
|
+
assert env.traces == []
|
|
225
|
+
|
|
226
|
+
def test_checkpoint_traces(self) -> None:
|
|
227
|
+
"""
|
|
228
|
+
Tests the function of checkpointing agent traces
|
|
229
|
+
|
|
230
|
+
:return: None
|
|
231
|
+
"""
|
|
232
|
+
env = StoppingGameEnv(self.config)
|
|
233
|
+
fixed_timestamp = 123
|
|
234
|
+
with patch("time.time", return_value=fixed_timestamp):
|
|
235
|
+
with patch(
|
|
236
|
+
"csle_common.dao.simulation_config.simulation_trace.SimulationTrace.save_traces"
|
|
237
|
+
) as mock_save_traces:
|
|
238
|
+
env.traces = ["trace1", "trace2"]
|
|
239
|
+
env._StoppingGameEnv__checkpoint_traces()
|
|
240
|
+
mock_save_traces.assert_called_once_with(
|
|
241
|
+
traces_save_dir=constants.LOGGING.DEFAULT_LOG_DIR,
|
|
242
|
+
traces=env.traces,
|
|
243
|
+
traces_file=f"taus{fixed_timestamp}.json",
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
def test_set_model(self) -> None:
|
|
247
|
+
"""
|
|
248
|
+
Tests the function of setting the model
|
|
249
|
+
|
|
250
|
+
:return: None
|
|
251
|
+
"""
|
|
252
|
+
env = StoppingGameEnv(self.config)
|
|
253
|
+
mock_model = MagicMock()
|
|
254
|
+
env.set_model(mock_model)
|
|
255
|
+
assert env.model == mock_model
|
|
256
|
+
|
|
257
|
+
def test_set_state(self) -> None:
|
|
258
|
+
"""
|
|
259
|
+
Tests the function of setting the state
|
|
260
|
+
|
|
261
|
+
:return: None
|
|
262
|
+
"""
|
|
263
|
+
env = StoppingGameEnv(self.config)
|
|
264
|
+
env.state = MagicMock()
|
|
265
|
+
|
|
266
|
+
mock_state = MagicMock(spec=StoppingGameState)
|
|
267
|
+
env.set_state(mock_state)
|
|
268
|
+
assert env.state == mock_state
|
|
269
|
+
|
|
270
|
+
state_int = 5
|
|
271
|
+
env.set_state(state_int)
|
|
272
|
+
assert env.state.s == state_int
|
|
273
|
+
assert env.state.l == self.config.L
|
|
274
|
+
|
|
275
|
+
state_tuple = (3, 7)
|
|
276
|
+
env.set_state(state_tuple)
|
|
277
|
+
assert env.state.s == state_tuple[0]
|
|
278
|
+
assert env.state.l == state_tuple[1]
|
|
279
|
+
|
|
280
|
+
with pytest.raises(ValueError):
|
|
281
|
+
env.set_state([1, 2, 3]) # type: ignore
|
|
282
|
+
|
|
283
|
+
def test_is_state_terminal(self) -> None:
|
|
284
|
+
"""
|
|
285
|
+
Tests the function of checking whether a given state is terminal or not
|
|
286
|
+
|
|
287
|
+
:return: None
|
|
288
|
+
"""
|
|
289
|
+
env = StoppingGameEnv(self.config)
|
|
290
|
+
env.state = MagicMock()
|
|
291
|
+
|
|
292
|
+
mock_state = MagicMock(spec=StoppingGameState)
|
|
293
|
+
mock_state.s = 2
|
|
294
|
+
assert env.is_state_terminal(mock_state)
|
|
295
|
+
mock_state.s = 1
|
|
296
|
+
assert not env.is_state_terminal(mock_state)
|
|
297
|
+
state_int = 2
|
|
298
|
+
assert env.is_state_terminal(state_int)
|
|
299
|
+
state_int = 1
|
|
300
|
+
assert not env.is_state_terminal(state_int)
|
|
301
|
+
state_tuple = (2, 5)
|
|
302
|
+
assert env.is_state_terminal(state_tuple)
|
|
303
|
+
state_tuple = (1, 5)
|
|
304
|
+
assert not env.is_state_terminal(state_tuple)
|
|
305
|
+
|
|
306
|
+
with pytest.raises(ValueError):
|
|
307
|
+
env.is_state_terminal([1, 2, 3]) # type: ignore
|
|
308
|
+
|
|
309
|
+
def test_get_observation_from_history(self) -> None:
|
|
310
|
+
"""
|
|
311
|
+
Tests the function of getting a hidden observation based on a history
|
|
312
|
+
|
|
313
|
+
:return: None
|
|
314
|
+
"""
|
|
315
|
+
env = StoppingGameEnv(self.config)
|
|
316
|
+
history = [1, 2, 3, 4, 5]
|
|
317
|
+
pi2 = np.array([0.1, 0.9])
|
|
318
|
+
l = 3
|
|
319
|
+
observation = env.get_observation_from_history(history, pi2, l)
|
|
320
|
+
assert observation == [5]
|
|
321
|
+
|
|
322
|
+
history = []
|
|
323
|
+
with pytest.raises(ValueError, match="History must not be empty"):
|
|
324
|
+
env.get_observation_from_history(history, pi2, l)
|
|
325
|
+
|
|
326
|
+
def test_generate_random_particles(self) -> None:
|
|
327
|
+
"""
|
|
328
|
+
Tests the funtion of generating a random list of state particles from a given observation
|
|
329
|
+
|
|
330
|
+
:return: None
|
|
331
|
+
"""
|
|
332
|
+
env = StoppingGameEnv(self.config)
|
|
333
|
+
num_particles = 10
|
|
334
|
+
particles = env.generate_random_particles(o=1, num_particles=num_particles)
|
|
335
|
+
assert len(particles) == num_particles
|
|
336
|
+
assert all(p in [0, 1] for p in particles)
|
|
337
|
+
|
|
338
|
+
num_particles = 0
|
|
339
|
+
particles = env.generate_random_particles(o=1, num_particles=num_particles)
|
|
340
|
+
assert len(particles) == num_particles
|
|
341
|
+
|
|
342
|
+
def test_step(self) -> None:
|
|
343
|
+
"""
|
|
344
|
+
Tests the funtion of taking a step in the environment by executing the given action
|
|
345
|
+
|
|
346
|
+
:return: None
|
|
347
|
+
"""
|
|
348
|
+
env = StoppingGameEnv(self.config)
|
|
349
|
+
env.state = MagicMock()
|
|
350
|
+
env.state.s = 1
|
|
351
|
+
env.state.l = 2
|
|
352
|
+
env.state.t = 0
|
|
353
|
+
env.state.attacker_observation.return_value = np.array([1, 2, 3])
|
|
354
|
+
env.state.defender_observation.return_value = np.array([4, 5, 6])
|
|
355
|
+
env.state.b = np.array([0.5, 0.5, 0.0])
|
|
356
|
+
|
|
357
|
+
env.trace = MagicMock()
|
|
358
|
+
env.trace.defender_rewards = []
|
|
359
|
+
env.trace.attacker_rewards = []
|
|
360
|
+
env.trace.attacker_actions = []
|
|
361
|
+
env.trace.defender_actions = []
|
|
362
|
+
env.trace.infos = []
|
|
363
|
+
env.trace.states = []
|
|
364
|
+
env.trace.beliefs = []
|
|
365
|
+
env.trace.infrastructure_metrics = []
|
|
366
|
+
env.trace.attacker_observations = []
|
|
367
|
+
env.trace.defender_observations = []
|
|
368
|
+
|
|
369
|
+
with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_state",
|
|
370
|
+
return_value=2):
|
|
371
|
+
with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_observation",
|
|
372
|
+
return_value=1):
|
|
373
|
+
with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.next_belief",
|
|
374
|
+
return_value=np.array([0.3, 0.7, 0.0])):
|
|
375
|
+
action_profile = (
|
|
376
|
+
1,
|
|
377
|
+
(
|
|
378
|
+
np.array(
|
|
379
|
+
[[0.2, 0.8, 0.0], [0.6, 0.4, 0.0], [0.5, 0.5, 0.0]]
|
|
380
|
+
),
|
|
381
|
+
2,
|
|
382
|
+
),
|
|
383
|
+
)
|
|
384
|
+
observations, rewards, terminated, truncated, info = env.step(
|
|
385
|
+
action_profile
|
|
386
|
+
)
|
|
387
|
+
|
|
388
|
+
assert (observations[0] == np.array([4, 5, 6])).all(), "Incorrect defender observations"
|
|
389
|
+
assert (observations[1] == np.array([1, 2, 3])).all(), "Incorrect attacker observations"
|
|
390
|
+
assert rewards == (0, 0)
|
|
391
|
+
assert not terminated
|
|
392
|
+
assert not truncated
|
|
393
|
+
assert env.trace.defender_rewards[-1] == 0
|
|
394
|
+
assert env.trace.attacker_rewards[-1] == 0
|
|
395
|
+
assert env.trace.attacker_actions[-1] == 2
|
|
396
|
+
assert env.trace.defender_actions[-1] == 1
|
|
397
|
+
assert env.trace.infos[-1] == info
|
|
398
|
+
assert env.trace.states[-1] == 2
|
|
399
|
+
print(env.trace.beliefs)
|
|
400
|
+
assert env.trace.beliefs[-1] == 0.7
|
|
401
|
+
assert env.trace.infrastructure_metrics[-1] == 1
|
|
402
|
+
assert (env.trace.attacker_observations[-1] == np.array([1, 2, 3])).all()
|
|
403
|
+
assert (env.trace.defender_observations[-1] == np.array([4, 5, 6])).all()
|
|
404
|
+
|
|
405
|
+
def test_info(self) -> None:
|
|
406
|
+
"""
|
|
407
|
+
Tests the function of adding the cumulative reward and episode length to the info dict
|
|
408
|
+
|
|
409
|
+
:return: None
|
|
410
|
+
"""
|
|
411
|
+
env = StoppingGameEnv(self.config)
|
|
412
|
+
env.trace = MagicMock()
|
|
413
|
+
env.trace.defender_rewards = [1, 2]
|
|
414
|
+
env.trace.attacker_actions = [0, 1]
|
|
415
|
+
env.trace.defender_actions = [0, 1]
|
|
416
|
+
env.trace.states = [0, 1]
|
|
417
|
+
env.trace.infrastructure_metrics = [0, 1]
|
|
418
|
+
info: Dict[str, Any] = {}
|
|
419
|
+
updated_info = env._info(info)
|
|
420
|
+
assert updated_info[env_constants.ENV_METRICS.RETURN] == sum(env.trace.defender_rewards)
|
|
421
|
+
|
|
422
|
+
def test_emulation_evaluation(self) -> None:
|
|
423
|
+
"""
|
|
424
|
+
Tests the function for evaluating a strategy profile in the emulation environment
|
|
425
|
+
|
|
426
|
+
:return: None
|
|
427
|
+
"""
|
|
428
|
+
StoppingGameEnv(self.config)
|
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
from gym_csle_stopping_game.envs.stopping_game_mdp_attacker_env import (
|
|
2
|
+
StoppingGameMdpAttackerEnv,
|
|
3
|
+
)
|
|
4
|
+
from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
|
|
5
|
+
from gym_csle_stopping_game.dao.stopping_game_attacker_mdp_config import (
|
|
6
|
+
StoppingGameAttackerMdpConfig,
|
|
7
|
+
)
|
|
8
|
+
from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
|
|
9
|
+
from csle_common.dao.training.policy import Policy
|
|
10
|
+
import pytest
|
|
11
|
+
from unittest.mock import MagicMock
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TestStoppingGameMdpAttackerEnvSuite:
|
|
16
|
+
"""
|
|
17
|
+
Test suite for stopping_game_mdp_attacker_env.py
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
@pytest.fixture(autouse=True)
|
|
21
|
+
def setup_env(self) -> None:
|
|
22
|
+
"""
|
|
23
|
+
Sets up the configuration of the stopping game
|
|
24
|
+
|
|
25
|
+
:return: None
|
|
26
|
+
"""
|
|
27
|
+
env_name = "test_env"
|
|
28
|
+
T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
|
|
29
|
+
O = np.array([0, 1])
|
|
30
|
+
Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]])
|
|
31
|
+
R = np.zeros((2, 3, 3, 3))
|
|
32
|
+
S = np.array([0, 1, 2])
|
|
33
|
+
A1 = np.array([0, 1, 2])
|
|
34
|
+
A2 = np.array([0, 1, 2])
|
|
35
|
+
L = 2
|
|
36
|
+
R_INT = 1
|
|
37
|
+
R_COST = 2
|
|
38
|
+
R_SLA = 3
|
|
39
|
+
R_ST = 4
|
|
40
|
+
b1 = np.array([0.6, 0.4])
|
|
41
|
+
save_dir = "save_directory"
|
|
42
|
+
checkpoint_traces_freq = 100
|
|
43
|
+
gamma = 0.9
|
|
44
|
+
compute_beliefs = True
|
|
45
|
+
save_trace = True
|
|
46
|
+
self.config = StoppingGameConfig(
|
|
47
|
+
env_name,
|
|
48
|
+
T,
|
|
49
|
+
O,
|
|
50
|
+
Z,
|
|
51
|
+
R,
|
|
52
|
+
S,
|
|
53
|
+
A1,
|
|
54
|
+
A2,
|
|
55
|
+
L,
|
|
56
|
+
R_INT,
|
|
57
|
+
R_COST,
|
|
58
|
+
R_SLA,
|
|
59
|
+
R_ST,
|
|
60
|
+
b1,
|
|
61
|
+
save_dir,
|
|
62
|
+
checkpoint_traces_freq,
|
|
63
|
+
gamma,
|
|
64
|
+
compute_beliefs,
|
|
65
|
+
save_trace,
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
def test_init_(self) -> None:
|
|
69
|
+
"""
|
|
70
|
+
Tests the initializing function
|
|
71
|
+
|
|
72
|
+
:return: None
|
|
73
|
+
"""
|
|
74
|
+
# Mock the defender strategy
|
|
75
|
+
defender_strategy = MagicMock(spec=Policy)
|
|
76
|
+
# Create the attacker MDP configuration
|
|
77
|
+
attacker_mdp_config = StoppingGameAttackerMdpConfig(
|
|
78
|
+
env_name="test_env",
|
|
79
|
+
stopping_game_config=self.config,
|
|
80
|
+
defender_strategy=defender_strategy,
|
|
81
|
+
stopping_game_name="csle-stopping-game-v1",
|
|
82
|
+
)
|
|
83
|
+
# Initialize the StoppingGameMdpAttackerEnv
|
|
84
|
+
env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
|
|
85
|
+
assert env.config == attacker_mdp_config
|
|
86
|
+
assert env.observation_space == self.config.attacker_observation_space()
|
|
87
|
+
assert env.action_space == self.config.attacker_action_space()
|
|
88
|
+
assert env.static_defender_strategy == defender_strategy
|
|
89
|
+
# print(env.latest_defender_obs)
|
|
90
|
+
# assert not env.latest_defender_obs
|
|
91
|
+
# assert not env.latest_attacker_obs
|
|
92
|
+
assert not env.model
|
|
93
|
+
assert not env.viewer
|
|
94
|
+
|
|
95
|
+
def test_reset(self) -> None:
|
|
96
|
+
"""
|
|
97
|
+
Tests the function for reseting the environment state
|
|
98
|
+
|
|
99
|
+
:return: None
|
|
100
|
+
"""
|
|
101
|
+
defender_strategy = MagicMock(spec=Policy)
|
|
102
|
+
attacker_mdp_config = StoppingGameAttackerMdpConfig(
|
|
103
|
+
env_name="test_env",
|
|
104
|
+
stopping_game_config=self.config,
|
|
105
|
+
defender_strategy=defender_strategy,
|
|
106
|
+
stopping_game_name="csle-stopping-game-v1",
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
|
|
110
|
+
attacker_obs, info = env.reset()
|
|
111
|
+
assert env.latest_defender_obs.all() == np.array([2, 0.4]).all() # type: ignore
|
|
112
|
+
assert info == {}
|
|
113
|
+
|
|
114
|
+
def test_set_model(self) -> None:
|
|
115
|
+
"""
|
|
116
|
+
Tests the function for setting the model
|
|
117
|
+
|
|
118
|
+
:return: None
|
|
119
|
+
"""
|
|
120
|
+
defender_strategy = MagicMock(spec=Policy)
|
|
121
|
+
attacker_mdp_config = StoppingGameAttackerMdpConfig(
|
|
122
|
+
env_name="test_env",
|
|
123
|
+
stopping_game_config=self.config,
|
|
124
|
+
defender_strategy=defender_strategy,
|
|
125
|
+
stopping_game_name="csle-stopping-game-v1",
|
|
126
|
+
)
|
|
127
|
+
env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
|
|
128
|
+
mock_model = MagicMock()
|
|
129
|
+
env.set_model(mock_model)
|
|
130
|
+
assert env.model == mock_model
|
|
131
|
+
|
|
132
|
+
def test_set_state(self) -> None:
|
|
133
|
+
"""
|
|
134
|
+
Tests the function for setting the state
|
|
135
|
+
|
|
136
|
+
:return: None
|
|
137
|
+
"""
|
|
138
|
+
defender_strategy = MagicMock(spec=Policy)
|
|
139
|
+
attacker_mdp_config = StoppingGameAttackerMdpConfig(
|
|
140
|
+
env_name="test_env",
|
|
141
|
+
stopping_game_config=self.config,
|
|
142
|
+
defender_strategy=defender_strategy,
|
|
143
|
+
stopping_game_name="csle-stopping-game-v1",
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
|
|
147
|
+
assert not env.set_state(1) # type: ignore
|
|
148
|
+
|
|
149
|
+
def test_calculate_stage_policy(self) -> None:
|
|
150
|
+
"""
|
|
151
|
+
Tests the function for calculating the stage policy of a given model and observation
|
|
152
|
+
|
|
153
|
+
:return: None
|
|
154
|
+
"""
|
|
155
|
+
defender_strategy = MagicMock(spec=Policy)
|
|
156
|
+
attacker_mdp_config = StoppingGameAttackerMdpConfig(
|
|
157
|
+
env_name="test_env",
|
|
158
|
+
stopping_game_config=self.config,
|
|
159
|
+
defender_strategy=defender_strategy,
|
|
160
|
+
stopping_game_name="csle-stopping-game-v1",
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
|
|
164
|
+
env.model = None
|
|
165
|
+
observation = [1, 0.5]
|
|
166
|
+
stage_policy = env.calculate_stage_policy(o=observation)
|
|
167
|
+
expected_stage_policy = np.array([[1.0, 0.0], [1.0, 0.0], [0.5, 0.5]])
|
|
168
|
+
assert stage_policy.all() == expected_stage_policy.all()
|
|
169
|
+
|
|
170
|
+
def test_get_attacker_dist(self) -> None:
|
|
171
|
+
"""
|
|
172
|
+
Tests the function for getting the attacker's action distribution based on a given observation
|
|
173
|
+
|
|
174
|
+
:return: None
|
|
175
|
+
"""
|
|
176
|
+
defender_strategy = MagicMock(spec=Policy)
|
|
177
|
+
attacker_mdp_config = StoppingGameAttackerMdpConfig(
|
|
178
|
+
env_name="test_env",
|
|
179
|
+
stopping_game_config=self.config,
|
|
180
|
+
defender_strategy=defender_strategy,
|
|
181
|
+
stopping_game_name="csle-stopping-game-v1",
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
|
|
185
|
+
env.model = None
|
|
186
|
+
observation = [1, 0.5, 0]
|
|
187
|
+
with pytest.raises(ValueError, match="Model is None"):
|
|
188
|
+
env._get_attacker_dist(observation)
|
|
189
|
+
|
|
190
|
+
def test_render(self) -> None:
|
|
191
|
+
"""
|
|
192
|
+
Tests the function for rendering the environment
|
|
193
|
+
|
|
194
|
+
:return: None
|
|
195
|
+
"""
|
|
196
|
+
defender_strategy = MagicMock(spec=Policy)
|
|
197
|
+
attacker_mdp_config = StoppingGameAttackerMdpConfig(
|
|
198
|
+
env_name="test_env",
|
|
199
|
+
stopping_game_config=self.config,
|
|
200
|
+
defender_strategy=defender_strategy,
|
|
201
|
+
stopping_game_name="csle-stopping-game-v1",
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
|
|
205
|
+
with pytest.raises(NotImplementedError):
|
|
206
|
+
env.render("human")
|
|
207
|
+
|
|
208
|
+
def test_is_defense_action_legal(self) -> None:
    """
    Verifies that the attacker MDP environment reports defender action 1 as legal

    :return: None
    """
    mdp_config = StoppingGameAttackerMdpConfig(
        env_name="test_env",
        stopping_game_config=self.config,
        defender_strategy=MagicMock(spec=Policy),
        stopping_game_name="csle-stopping-game-v1",
    )
    is_legal = StoppingGameMdpAttackerEnv(config=mdp_config).is_defense_action_legal(1)
    assert is_legal

def test_is_attack_action_legal(self) -> None:
    """
    Verifies that the attacker MDP environment reports attacker action 1 as legal

    :return: None
    """
    mdp_config = StoppingGameAttackerMdpConfig(
        env_name="test_env",
        stopping_game_config=self.config,
        defender_strategy=MagicMock(spec=Policy),
        stopping_game_name="csle-stopping-game-v1",
    )
    is_legal = StoppingGameMdpAttackerEnv(config=mdp_config).is_attack_action_legal(1)
    assert is_legal

def test_get_traces(self) -> None:
    """
    Verifies that a newly built attacker MDP environment exposes the same traces
    as a newly built StoppingGameEnv with the same configuration

    :return: None
    """
    attacker_env = StoppingGameMdpAttackerEnv(
        config=StoppingGameAttackerMdpConfig(
            env_name="test_env",
            stopping_game_config=self.config,
            defender_strategy=MagicMock(spec=Policy),
            stopping_game_name="csle-stopping-game-v1",
        )
    )
    observed_traces = attacker_env.get_traces()
    assert observed_traces == StoppingGameEnv(self.config).traces

def test_reset_traces(self) -> None:
    """
    Tests the function of resetting the list of traces

    :return: None
    """
    defender_strategy = MagicMock(spec=Policy)
    attacker_mdp_config = StoppingGameAttackerMdpConfig(
        env_name="test_env",
        stopping_game_config=self.config,
        defender_strategy=defender_strategy,
        stopping_game_name="csle-stopping-game-v1",
    )

    env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
    # NOTE(review): this sets an attribute on the wrapper env; whether reset_traces() clears this
    # exact list depends on the env implementation — confirm
    env.traces = ["trace1", "trace2"]
    env.reset_traces()
    # Inspect the env under test: a freshly constructed StoppingGameEnv never observes the
    # reset_traces() call above, so asserting on it could not detect a broken reset
    assert env.get_traces() == []

def test_generate_random_particles(self) -> None:
    """
    Tests the function of generating a random list of state particles from a given observation,
    for both a non-empty and an empty particle request

    :return: None
    """
    attacker_env = StoppingGameMdpAttackerEnv(
        config=StoppingGameAttackerMdpConfig(
            env_name="test_env",
            stopping_game_config=self.config,
            defender_strategy=MagicMock(spec=Policy),
            stopping_game_name="csle-stopping-game-v1",
        )
    )
    for requested in (10, 0):
        sampled = attacker_env.generate_random_particles(o=1, num_particles=requested)
        assert len(sampled) == requested
        # Vacuously true when no particles were requested
        assert all(particle in [0, 1] for particle in sampled)

def test_get_actions_from_particles(self) -> None:
    """
    Tests the function for pruning the set of actions based on the current particle set

    :return: None
    """
    mdp_config = StoppingGameAttackerMdpConfig(
        env_name="test_env",
        stopping_game_config=self.config,
        defender_strategy=MagicMock(spec=Policy),
        stopping_game_name="csle-stopping-game-v1",
    )
    attacker_env = StoppingGameMdpAttackerEnv(config=mdp_config)
    particle_set, time_step, obs = [1, 2, 3], 0, 0
    pruned_actions = attacker_env.get_actions_from_particles(particle_set, time_step, obs)
    assert pruned_actions == [0, 1, 2]

def test_step(self) -> None:
    """
    Verifies that stepping the attacker MDP environment with a 1x2 stage policy
    trips an AssertionError

    :return: None
    """
    attacker_env = StoppingGameMdpAttackerEnv(
        config=StoppingGameAttackerMdpConfig(
            env_name="test_env",
            stopping_game_config=self.config,
            defender_strategy=MagicMock(spec=Policy),
            stopping_game_name="csle-stopping-game-v1",
        )
    )
    single_row_policy = np.array([[0.5, 0.5]])
    with pytest.raises(AssertionError):
        attacker_env.step(single_row_policy)
|
|
@@ -0,0 +1,336 @@
|
|
|
1
|
+
from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import StoppingGamePomdpDefenderEnv
|
|
2
|
+
from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
|
|
3
|
+
from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import StoppingGameDefenderPomdpConfig
|
|
4
|
+
from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
|
|
5
|
+
from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
|
|
6
|
+
from csle_common.dao.training.policy import Policy
|
|
7
|
+
from csle_common.dao.training.random_policy import RandomPolicy
|
|
8
|
+
from csle_common.dao.training.player_type import PlayerType
|
|
9
|
+
import pytest
|
|
10
|
+
from unittest.mock import MagicMock
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestStoppingGamePomdpDefenderEnvSuite:
    """
    Test suite for stopping_game_pomdp_defender_env.py
    """

    @pytest.fixture(autouse=True)
    def setup_env(self) -> None:
        """
        Sets up the configuration of the stopping game shared by every test in this suite

        :return: None
        """
        env_name = "test_env"
        # NOTE(review): the transition tensor is built with L=3 while the config below uses L=2 —
        # confirm this mismatch is intentional
        T = StoppingGameUtil.transition_tensor(L=3, p=0)
        O = StoppingGameUtil.observation_space(n=100)
        Z = StoppingGameUtil.observation_tensor(n=100)
        R = np.zeros((2, 3, 3, 3))
        S = StoppingGameUtil.state_space()
        A1 = StoppingGameUtil.defender_actions()
        A2 = StoppingGameUtil.attacker_actions()
        L = 2
        R_INT = 1
        R_COST = 2
        R_SLA = 3
        R_ST = 4
        b1 = StoppingGameUtil.b1()
        save_dir = "save_directory"
        checkpoint_traces_freq = 100
        gamma = 0.9
        compute_beliefs = True
        save_trace = True
        self.config = StoppingGameConfig(
            env_name,
            T,
            O,
            Z,
            R,
            S,
            A1,
            A2,
            L,
            R_INT,
            R_COST,
            R_SLA,
            R_ST,
            b1,
            save_dir,
            checkpoint_traces_freq,
            gamma,
            compute_beliefs,
            save_trace,
        )

    def _make_config(self, attacker_strategy: Policy) -> StoppingGameDefenderPomdpConfig:
        """
        Builds the defender POMDP configuration used throughout this suite

        :param attacker_strategy: the static attacker strategy to embed in the configuration
        :return: a defender POMDP configuration wrapping self.config
        """
        return StoppingGameDefenderPomdpConfig(
            env_name="test_env",
            stopping_game_config=self.config,
            attacker_strategy=attacker_strategy,
            stopping_game_name="csle-stopping-game-v1",
        )

    def _make_env(self) -> StoppingGamePomdpDefenderEnv:
        """
        Builds a defender POMDP environment whose attacker strategy is a MagicMock spec'd on Policy

        :return: the environment under test
        """
        return StoppingGamePomdpDefenderEnv(config=self._make_config(MagicMock(spec=Policy)))

    def test_init_(self) -> None:
        """
        Tests the initializing function

        :return: None
        """
        attacker_strategy = MagicMock(spec=Policy)
        defender_pomdp_config = self._make_config(attacker_strategy)
        env = StoppingGamePomdpDefenderEnv(config=defender_pomdp_config)
        assert env.config == defender_pomdp_config
        assert env.observation_space == self.config.defender_observation_space()
        assert env.action_space == self.config.defender_action_space()
        assert env.static_attacker_strategy == attacker_strategy
        assert not env.viewer

    def test_reset(self) -> None:
        """
        Tests the function for resetting the environment state

        :return: None
        """
        env = self._make_env()
        _, info = env.reset()
        assert info

    def test_render(self) -> None:
        """
        Tests the function for rendering the environment

        :return: None
        """
        env = self._make_env()
        with pytest.raises(NotImplementedError):
            env.render("human")

    def test_is_defense_action_legal(self) -> None:
        """
        Tests the function of checking whether a defender action in the environment is legal or not

        :return: None
        """
        env = self._make_env()
        assert env.is_defense_action_legal(1)

    def test_is_attack_action_legal(self) -> None:
        """
        Tests the function of checking whether an attacker action in the environment is legal or not

        :return: None
        """
        env = self._make_env()
        assert env.is_attack_action_legal(1)

    def test_get_traces(self) -> None:
        """
        Tests the function of getting the list of simulation traces

        :return: None
        """
        env = self._make_env()
        assert env.get_traces() == StoppingGameEnv(self.config).traces

    def test_reset_traces(self) -> None:
        """
        Tests the function of resetting the list of traces

        :return: None
        """
        env = self._make_env()
        # NOTE(review): this sets an attribute on the wrapper env; whether reset_traces() clears this
        # exact list depends on the env implementation — confirm
        env.traces = ["trace1", "trace2"]
        env.reset_traces()
        # Inspect the env under test: a freshly constructed StoppingGameEnv never observes the
        # reset_traces() call above, so asserting on it could not detect a broken reset
        assert env.get_traces() == []

    def test_set_model(self) -> None:
        """
        Tests the function for setting the model

        :return: None
        """
        env = self._make_env()
        mock_model = MagicMock()
        env.set_model(mock_model)
        assert env.model == mock_model

    def test_set_state(self) -> None:
        """
        Tests the function for setting the state

        :return: None
        """
        env = self._make_env()
        assert env.set_state(1) is None  # type: ignore

    def test_get_observation_from_history(self) -> None:
        """
        Tests the function for getting a defender observation (belief) from a history

        :return: None
        """
        env = self._make_env()
        history = [1, 2, 3]
        config_l = self.config.L
        pi2 = env.static_attacker_strategy.stage_policy(o=0)
        observed = env.get_observation_from_history(history)
        assert observed == StoppingGameEnv(self.config).get_observation_from_history(history, pi2, config_l)

    def test_is_state_terminal(self) -> None:
        """
        Tests the function for checking whether a state is terminal or not

        :return: None
        """
        env = self._make_env()
        assert env.is_state_terminal(1) == StoppingGameEnv(self.config).is_state_terminal(1)

    def test_generate_random_particles(self) -> None:
        """
        Tests the function of generating a random list of state particles from a given observation,
        for both a non-empty and an empty particle request

        :return: None
        """
        env = self._make_env()
        for num_particles in (10, 0):
            particles = env.generate_random_particles(o=1, num_particles=num_particles)
            assert len(particles) == num_particles
            # Vacuously true when no particles were requested
            assert all(p in [0, 1] for p in particles)

    def test_get_actions_from_particles(self) -> None:
        """
        Tests the function for pruning the set of actions based on the current particle set

        :return: None
        """
        env = self._make_env()
        particles = [1, 2, 3]
        t = 0
        observation = 0
        expected_actions = [0, 1]
        assert env.get_actions_from_particles(particles, t, observation) == expected_actions

    def test_step(self) -> None:
        """
        Tests the function for taking a step in the environment by executing the given action

        :return: None
        """
        # Every row of the attacker's stage strategy is the same [0.9, 0.1] distribution
        attacker_stage_strategy = np.array([[0.9, 0.1], [0.9, 0.1], [0.9, 0.1]])
        attacker_strategy = RandomPolicy(actions=list(self.config.A2), player_type=PlayerType.ATTACKER,
                                         stage_policy_tensor=list(attacker_stage_strategy))
        env = StoppingGamePomdpDefenderEnv(config=self._make_config(attacker_strategy))
        a1 = 1
        env.reset()
        defender_obs, reward, terminated, truncated, info = env.step(a1)
        assert len(defender_obs) == 2
        assert isinstance(defender_obs[0], float)  # type: ignore
        assert isinstance(defender_obs[1], float)  # type: ignore
        assert isinstance(reward, float)  # type: ignore
        assert isinstance(terminated, bool)  # type: ignore
        assert isinstance(truncated, bool)  # type: ignore
        assert isinstance(info, dict)  # type: ignore
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
__version__ = '0.6.1'
|
|
File without changes
|
|
File without changes
|
{gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/src/gym_csle_stopping_game/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/tests/test_stopping_game_dao.py
RENAMED
|
File without changes
|
{gym_csle_stopping_game-0.6.1 → gym_csle_stopping_game-0.6.2}/tests/test_stopping_game_util.py
RENAMED
|
File without changes
|