gym-csle-stopping-game 0.9.24 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gym_csle_stopping_game/__init__.py +23 -0
- gym_csle_stopping_game/__version__.py +1 -0
- gym_csle_stopping_game/constants/__init__.py +0 -0
- gym_csle_stopping_game/constants/constants.py +40 -0
- gym_csle_stopping_game/dao/__init__.py +0 -0
- gym_csle_stopping_game/dao/stopping_game_attacker_mdp_config.py +86 -0
- gym_csle_stopping_game/dao/stopping_game_config.py +165 -0
- gym_csle_stopping_game/dao/stopping_game_defender_pomdp_config.py +92 -0
- gym_csle_stopping_game/dao/stopping_game_state.py +98 -0
- gym_csle_stopping_game/envs/__init__.py +1 -0
- gym_csle_stopping_game/envs/stopping_game_env.py +393 -0
- gym_csle_stopping_game/envs/stopping_game_mdp_attacker_env.py +282 -0
- gym_csle_stopping_game/envs/stopping_game_pomdp_defender_env.py +233 -0
- gym_csle_stopping_game/util/__init__.py +0 -0
- gym_csle_stopping_game/util/stopping_game_util.py +699 -0
- gym_csle_stopping_game-0.9.24.dist-info/METADATA +414 -0
- gym_csle_stopping_game-0.9.24.dist-info/RECORD +19 -0
- gym_csle_stopping_game-0.9.24.dist-info/WHEEL +5 -0
- gym_csle_stopping_game-0.9.24.dist-info/top_level.txt +1 -0
gym_csle_stopping_game/envs/stopping_game_env.py
@@ -0,0 +1,393 @@
import random
from typing import Tuple, Dict, List, Any, Union
import numpy as np
import numpy.typing as npt
import time
import math
import csle_common.constants.constants as constants
from csle_common.dao.simulation_config.base_env import BaseEnv
from csle_common.dao.simulation_config.simulation_trace import SimulationTrace
from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
from gym_csle_stopping_game.dao.stopping_game_state import StoppingGameState
import gym_csle_stopping_game.constants.constants as env_constants


class StoppingGameEnv(BaseEnv):
    """
    OpenAI Gym Env for the csle-stopping-game
    """

    def __init__(self, config: StoppingGameConfig):
        """
        Initializes the environment

        :param config: the environment configuration
        """
        self.config = config

        # Initialize environment state
        self.state = StoppingGameState(b1=self.config.b1, L=self.config.L)

        # Setup spaces
        self.attacker_observation_space = self.config.attacker_observation_space()
        self.defender_observation_space = self.config.defender_observation_space()
        self.attacker_action_space = self.config.attacker_action_space()
        self.defender_action_space = self.config.defender_action_space()

        self.action_space = self.defender_action_space
        self.observation_space = self.defender_observation_space

        # Setup traces
        self.traces: List[SimulationTrace] = []
        self.trace = SimulationTrace(simulation_env=self.config.env_name)

        # Reset
        self.reset()
        super().__init__()

    def step(self, action_profile: Tuple[int, Tuple[npt.NDArray[Any], int]]) \
            -> Tuple[Tuple[npt.NDArray[Any], npt.NDArray[Any]], Tuple[int, int], bool, bool, Dict[str, Any]]:
        """
        Takes a step in the environment by executing the given action

        :param action_profile: the actions to take (both players' actions)
        :return: (obs, reward, terminated, truncated, info)
        """

        # Setup initial values
        a1, a2_profile = action_profile
        pi2, a2 = a2_profile
        assert pi2.shape[0] == len(self.config.S)
        assert pi2.shape[1] == len(self.config.A2)
        done = False
        info: Dict[str, Any] = {}

        o = max(self.config.O)
        if self.state.s == 2:
            done = True
            r = 0
        else:
            # Compute r, s', b', o'
            r = self.config.R[self.state.l - 1][a1][a2][self.state.s]
            self.state.s = StoppingGameUtil.sample_next_state(l=self.state.l, a1=a1, a2=a2, T=self.config.T,
                                                              S=self.config.S, s=self.state.s)
            o = StoppingGameUtil.sample_next_observation(Z=self.config.Z,
                                                         O=self.config.O, s_prime=self.state.s)
            if self.config.compute_beliefs:
                self.state.b = StoppingGameUtil.next_belief(o=o, a1=a1, b=self.state.b, pi2=pi2,
                                                            config=self.config,
                                                            l=self.state.l, a2=a2)

            # Update stops remaining
            self.state.l = self.state.l - a1

        # Update time-step
        self.state.t += 1

        # Populate info dict
        info[env_constants.ENV_METRICS.STOPS_REMAINING] = self.state.l
        info[env_constants.ENV_METRICS.STATE] = self.state.s
        info[env_constants.ENV_METRICS.DEFENDER_ACTION] = a1
        info[env_constants.ENV_METRICS.ATTACKER_ACTION] = a2
        info[env_constants.ENV_METRICS.OBSERVATION] = o
        info[env_constants.ENV_METRICS.TIME_STEP] = self.state.t

        # Get observations
        attacker_obs = self.state.attacker_observation()
        defender_obs = self.state.defender_observation()

        # Log trace
        if self.config.save_trace:
            self.trace.defender_rewards.append(r)
            self.trace.attacker_rewards.append(-r)
            self.trace.attacker_actions.append(a2)
            self.trace.defender_actions.append(a1)
            self.trace.infos.append(info)
            self.trace.states.append(self.state.s)
            self.trace.beliefs.append(self.state.b[1])
            self.trace.infrastructure_metrics.append(o)
            if not done:
                self.trace.attacker_observations.append(attacker_obs)
                self.trace.defender_observations.append(defender_obs)

        # Populate info
        info = self._info(info)

        return (defender_obs, attacker_obs), (r, -r), done, done, info

    def mean(self, prob_vector):
        """
        Utility function for getting the mean of a vector

        :param prob_vector: the vector to take the mean of
        :return: the mean
        """
        m = 0
        for i in range(len(prob_vector)):
            m += prob_vector[i] * i
        return m

    def weighted_intrusion_prediction_distance(self, intrusion_start: int, first_stop: int):
        """
        Computes the weighted intrusion start time prediction distance (Wang, Hammar, Stadler, 2022)

        :param intrusion_start: the intrusion start time
        :param first_stop: the predicted start time
        :return: the weighted distance
        """
        if first_stop <= intrusion_start:
            return 1 - (10 / 10)
        else:
            return 1 - (min(10, (first_stop - (intrusion_start + 1))) / 2) / 10

    def _info(self, info: Dict[str, Any]) -> Dict[str, Any]:
        """
        Adds the cumulative reward and episode length to the info dict

        :param info: the info dict to update
        :return: the updated info dict
        """
        R = 0
        for i in range(len(self.trace.defender_rewards)):
            R += self.trace.defender_rewards[i] * math.pow(self.config.gamma, i)
        info[env_constants.ENV_METRICS.RETURN] = sum(self.trace.defender_rewards)
        info[env_constants.ENV_METRICS.TIME_HORIZON] = len(self.trace.defender_actions)
        stop = self.config.L
        for i in range(1, self.config.L + 1):
            info[f"{env_constants.ENV_METRICS.STOP}_{i}"] = len(self.trace.states)
        for i in range(len(self.trace.defender_actions)):
            if self.trace.defender_actions[i] == 1:
                info[f"{env_constants.ENV_METRICS.STOP}_{stop}"] = i
                stop -= 1
        intrusion_start = len(self.trace.defender_actions)
        for i in range(len(self.trace.attacker_actions)):
            if self.trace.attacker_actions[i] == 1:
                intrusion_start = i
                break
        intrusion_end = len(self.trace.attacker_actions)
        info[env_constants.ENV_METRICS.INTRUSION_START] = intrusion_start
        info[env_constants.ENV_METRICS.INTRUSION_END] = intrusion_end
        info[env_constants.ENV_METRICS.START_POINT_CORRECT] = \
            int(intrusion_start == (info[f"{env_constants.ENV_METRICS.STOP}_1"] + 1))
        info[env_constants.ENV_METRICS.WEIGHTED_INTRUSION_PREDICTION_DISTANCE] = \
            self.weighted_intrusion_prediction_distance(intrusion_start=intrusion_start,
                                                        first_stop=info[f"{env_constants.ENV_METRICS.STOP}_1"])
        info[env_constants.ENV_METRICS.INTRUSION_LENGTH] = intrusion_end - intrusion_start
        upper_bound_return = 0
        defender_baseline_stop_on_first_alert_return = 0
        upper_bound_stops_remaining = self.config.L
        defender_baseline_stop_on_first_alert_stops_remaining = self.config.L
        for i in range(len(self.trace.states)):
            if defender_baseline_stop_on_first_alert_stops_remaining > 0:
                if self.trace.infrastructure_metrics[i] > 0:
                    defender_baseline_stop_on_first_alert_return += \
                        self.config.R[int(defender_baseline_stop_on_first_alert_stops_remaining) - 1][1][
                            self.trace.attacker_actions[i]][self.trace.states[i]] * math.pow(self.config.gamma, i)
                    defender_baseline_stop_on_first_alert_stops_remaining -= 1
                else:
                    defender_baseline_stop_on_first_alert_return += \
                        self.config.R[int(defender_baseline_stop_on_first_alert_stops_remaining) - 1][0][
                            self.trace.attacker_actions[i]][self.trace.states[i]] * math.pow(self.config.gamma, i)
            if upper_bound_stops_remaining > 0:
                if self.trace.states[i] == 0:
                    r = self.config.R[int(upper_bound_stops_remaining) - 1][0][self.trace.attacker_actions[i]][
                        self.trace.states[i]]
                    upper_bound_return += r * math.pow(self.config.gamma, i)
                else:
                    r = self.config.R[int(upper_bound_stops_remaining) - 1][1][self.trace.attacker_actions[i]][
                        self.trace.states[i]]
                    upper_bound_return += r * math.pow(self.config.gamma, i)
                    upper_bound_stops_remaining -= 1
        info[env_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN] = upper_bound_return
        info[env_constants.ENV_METRICS.AVERAGE_DEFENDER_BASELINE_STOP_ON_FIRST_ALERT_RETURN] = \
            defender_baseline_stop_on_first_alert_return
        return info

    def reset(self, seed: Union[None, int] = None, soft: bool = False, options: Union[Dict[str, Any], None] = None) \
            -> Tuple[Tuple[npt.NDArray[Any], npt.NDArray[Any]], Dict[str, Any]]:
        """
        Resets the environment state. This should be called whenever step() returns <done>

        :param seed: the random seed
        :param soft: boolean flag indicating whether it is a soft reset or not
        :param options: optional configuration parameters
        :return: initial observation and info
        """
        super().reset(seed=seed)
        self.state.reset()
        if len(self.trace.attacker_rewards) > 0:
            self.traces.append(self.trace)
        self.trace = SimulationTrace(simulation_env=self.config.env_name)
        attacker_obs = self.state.attacker_observation()
        defender_obs = self.state.defender_observation()
        if self.config.save_trace:
            self.trace.attacker_observations.append(attacker_obs)
            self.trace.defender_observations.append(defender_obs)
        info: Dict[str, Any] = {}
        info[env_constants.ENV_METRICS.STOPS_REMAINING] = self.state.l
        info[env_constants.ENV_METRICS.STATE] = self.state.s
        info[env_constants.ENV_METRICS.OBSERVATION] = 0
        info[env_constants.ENV_METRICS.TIME_STEP] = self.state.t
        return (defender_obs, attacker_obs), info

    def render(self, mode: str = 'human'):
        """
        Renders the environment. Supported rendering modes: (1) human; and (2) rgb_array

        :param mode: the rendering mode
        :return: True (if human mode) otherwise an rgb array
        """
        raise NotImplementedError("Rendering is not implemented for this environment")

    def is_defense_action_legal(self, defense_action_id: int) -> bool:
        """
        Checks whether a defender action in the environment is legal or not

        :param defense_action_id: the id of the action
        :return: True or False
        """
        return True

    def is_attack_action_legal(self, attack_action_id: int) -> bool:
        """
        Checks whether an attacker action in the environment is legal or not

        :param attack_action_id: the id of the attacker action
        :return: True or False
        """
        return True

    def get_traces(self) -> List[SimulationTrace]:
        """
        :return: the list of simulation traces
        """
        return self.traces

    def reset_traces(self) -> None:
        """
        Resets the list of traces

        :return: None
        """
        self.traces = []

    def __checkpoint_traces(self) -> None:
        """
        Checkpoints agent traces

        :return: None
        """
        ts = time.time()
        SimulationTrace.save_traces(traces_save_dir=constants.LOGGING.DEFAULT_LOG_DIR,
                                    traces=self.traces, traces_file=f"taus{ts}.json")

    def set_model(self, model) -> None:
        """
        Sets the model. Useful when using RL frameworks where the stage policy is not easy to extract

        :param model: the model
        :return: None
        """
        self.model = model

    def set_state(self, state: Union[StoppingGameState, int, Tuple[int, int]]) -> None:
        """
        Sets the state. Allows simulating samples from specific states

        :param state: the state
        :return: None
        """
        if isinstance(state, StoppingGameState):
            self.state = state
        elif type(state) is int or type(state) is np.int64:
            self.state.s = state
            self.state.l = self.config.L
        elif type(state) is tuple:
            self.state.s = state[0]
            self.state.l = state[1]
        else:
            raise ValueError(f"state: {state} not valid")

    def is_state_terminal(self, state: Union[StoppingGameState, int, Tuple[int, int]]) -> bool:
        """
        Checks whether a given state is terminal or not

        :param state: the state
        :return: True if terminal, else false
        """
        if isinstance(state, StoppingGameState):
            return state.s == 2
        elif type(state) is int or type(state) is np.int64:
            return state == 2
        elif type(state) is tuple:
            return state[0] == 2
        else:
            raise ValueError(f"state: {state} not valid")

    def get_observation_from_history(self, history: List[int], pi2: npt.NDArray[Any], l: int) -> List[Any]:
        """
        Utility method to get a hidden observation based on a history

        :param history: the history to get the observation from
        :param pi2: the attacker stage strategy
        :param l: the number of stops remaining
        :return: the observation
        """
        if not history:
            raise ValueError("History must not be empty")
        return [history[-1]]

    def generate_random_particles(self, o: int, num_particles: int) -> List[int]:
        """
        Generates a random list of state particles from a given observation

        :param o: the latest observation
        :param num_particles: the number of particles to generate
        :return: the list of random particles
        """
        particles = []
        for i in range(num_particles):
            particles.append(random.choice([0, 1]))
        return particles

    def manual_play(self) -> None:
        """
        An interactive loop to test the environment manually

        :return: None
        """
        done = False
        while True:
            raw_input = input("> ")
            raw_input = raw_input.strip()
            if raw_input == "help":
                print("Enter an action id to execute the action, "
                      "press R to reset, "
                      "press S to print the state, press A to print the actions, "
                      "press D to check if done, "
                      "press H to print the history of actions")
            elif raw_input == "A":
                print(f"Attacker space: {self.action_space}")
            elif raw_input == "S":
                print(self.state)
            elif raw_input == "D":
                print(done)
            elif raw_input == "H":
                print(self.trace)
            elif raw_input == "R":
                print("Resetting the state")
                self.reset()
            else:
                action_profile = raw_input
                parts = action_profile.split(",")
                a1 = int(parts[0])
                a2 = int(parts[1])
                stage_policy = []
                for s in self.config.S:
                    if s != 2:
                        dist = [0.0, 0.0]
                        dist[a2] = 1.0
                        stage_policy.append(dist)
                    else:
                        stage_policy.append([0.5, 0.5])
                pi2 = np.array(stage_policy)
                _, _, done, _, _ = self.step(action_profile=(a1, (pi2, a2)))
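Below is a minimal usage sketch for the joint game environment above. It relies only on the step() signature shown in this file; the `config` object is an assumption (a fully populated StoppingGameConfig constructed elsewhere), and the continue/stop encoding of actions 0/1 and the terminal state 2 follow the conventions visible in the code (the terminal check self.state.s == 2 and the stage policies built in manual_play()).

    import numpy as np
    from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv

    # `config` is assumed to be a fully populated StoppingGameConfig (not shown here)
    env = StoppingGameEnv(config=config)
    (defender_obs, attacker_obs), info = env.reset()

    # Attacker stage policy: one distribution over the attacker's actions per state in S
    pi2 = np.array([[0.9, 0.1],   # state 0 (no intrusion started)
                    [1.0, 0.0],   # state 1 (intrusion ongoing)
                    [0.5, 0.5]])  # state 2 (terminal; value is arbitrary)
    a1, a2 = 0, 1  # defender continues, attacker stops (starts the intrusion)

    # step() takes the joint action profile and returns observations and rewards for both players
    obs, (r_defender, r_attacker), terminated, truncated, info = \
        env.step(action_profile=(a1, (pi2, a2)))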
gym_csle_stopping_game/envs/stopping_game_mdp_attacker_env.py
@@ -0,0 +1,282 @@
from typing import Tuple, List, Dict, Any, Union
import numpy as np
import numpy.typing as npt
import torch
import math
from csle_common.dao.simulation_config.base_env import BaseEnv
from csle_common.dao.training.mixed_multi_threshold_stopping_policy import MixedMultiThresholdStoppingPolicy
from gym_csle_stopping_game.dao.stopping_game_attacker_mdp_config import StoppingGameAttackerMdpConfig
from csle_common.dao.simulation_config.simulation_trace import SimulationTrace
from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
import gym_csle_stopping_game.constants.constants as env_constants
from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv


class StoppingGameMdpAttackerEnv(BaseEnv):
    """
    OpenAI Gym Env for the MDP of the attacker when facing a static defender
    """

    def __init__(self, config: StoppingGameAttackerMdpConfig):
        """
        Initializes the environment

        :param config: the configuration of the environment
        """
        self.config = config
        self.stopping_game_env: StoppingGameEnv = StoppingGameEnv(config=self.config.stopping_game_config)

        # Setup spaces
        self.observation_space = self.config.stopping_game_config.attacker_observation_space()
        self.action_space = self.config.stopping_game_config.attacker_action_space()

        # Setup static defender
        self.static_defender_strategy = self.config.defender_strategy

        # Setup Config
        self.viewer: Union[None, Any] = None
        self.metadata = {
            'render.modes': ['human', 'rgb_array'],
            'video.frames_per_second': 50  # Video rendering speed
        }

        self.latest_defender_obs: Union[None, List[Any], npt.NDArray[Any]] = None
        self.latest_attacker_obs: Union[None, List[Any], npt.NDArray[Any]] = None
        self.model: Union[None, Any] = None

        # Reset
        self.reset()
        super().__init__()

    def step(self, pi2: Union[npt.NDArray[Any], int, float, np.float64]) \
            -> Tuple[npt.NDArray[Any], int, bool, bool, Dict[str, Any]]:
        """
        Takes a step in the environment by executing the given action

        :param pi2: attacker stage policy
        :return: (obs, reward, terminated, truncated, info)
        """
        if (type(pi2) is int or type(pi2) is float or type(pi2) is np.int64 or type(pi2) is np.int32  # type: ignore
                or type(pi2) is np.float64):  # type: ignore
            a2 = pi2
            if self.latest_attacker_obs is None:
                raise ValueError("Attacker observation is None")
            pi2 = self.calculate_stage_policy(o=list(self.latest_attacker_obs), a2=int(a2))
        else:
            if self.model is not None:
                if self.latest_attacker_obs is None:
                    raise ValueError("Attacker observation is None")
                pi2 = self.calculate_stage_policy(o=list(self.latest_attacker_obs))
                a2 = StoppingGameUtil.sample_attacker_action(pi2=pi2, s=self.stopping_game_env.state.s)
            else:
                pi2 = np.array(pi2)
                try:
                    if self.latest_attacker_obs is None:
                        raise ValueError("Attacker observation is None")
                    pi2 = self.calculate_stage_policy(o=list(self.latest_attacker_obs))
                except Exception:
                    pass
                a2 = StoppingGameUtil.sample_attacker_action(pi2=pi2, s=self.stopping_game_env.state.s)

        # a2 = pi2
        # pi2 = np.array([
        #     [0.5,0.5],
        #     [0.5,0.5],
        #     [0.5,0.5]
        # ])
        assert pi2.shape[0] == len(self.config.stopping_game_config.S)
        assert pi2.shape[1] == len(self.config.stopping_game_config.A1)

        # Get defender action from static strategy
        a1 = self.static_defender_strategy.action(o=self.latest_defender_obs)

        # Step the game
        o, r, d, _, info = self.stopping_game_env.step((int(a1), (pi2, int(a2))))
        self.latest_defender_obs = o[0]
        self.latest_attacker_obs = o[1]
        attacker_obs = o[1]

        info[env_constants.ENV_METRICS.RETURN] = -info[env_constants.ENV_METRICS.RETURN]
        info[env_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN] = \
            -info[env_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN]

        return attacker_obs, r[1], d, d, info

    def reset(self, seed: Union[int, None] = None, soft: bool = False, options: Union[Dict[str, Any], None] = None) \
            -> Tuple[npt.NDArray[Any], Dict[str, Any]]:
        """
        Resets the environment state. This should be called whenever step() returns <done>

        :param seed: the random seed
        :param soft: boolean flag indicating whether it is a soft reset or not
        :param options: optional configuration parameters
        :return: initial observation
        """
        o, _ = self.stopping_game_env.reset()
        self.latest_defender_obs = o[0]
        self.latest_attacker_obs = o[1]
        attacker_obs = o[1]
        info: Dict[str, Any] = {}
        return attacker_obs, info

    def set_model(self, model) -> None:
        """
        Sets the model. Useful when using RL frameworks where the stage policy is not easy to extract

        :param model: the model
        :return: None
        """
        self.model = model

    def set_state(self, state: Any) -> None:
        """
        Sets the state. Allows simulating samples from specific states

        :param state: the state
        :return: None
        """
        self.stopping_game_env.set_state(state=state)

    def calculate_stage_policy(self, o: List[Any], a2: int = 0) -> npt.NDArray[Any]:
        """
        Calculates the stage policy of a given model and observation

        :param o: the observation
        :param a2: the attacker action used to build a deterministic stage policy when no model is set
        :return: the stage policy
        """
        if self.model is None:
            stage_policy = []
            for s in self.config.stopping_game_config.S:
                if s != 2:
                    dist = [0.0, 0.0]
                    dist[a2] = 1.0
                    stage_policy.append(dist)
                else:
                    stage_policy.append([0.5, 0.5])
            return np.array(stage_policy)
        if isinstance(self.model, MixedMultiThresholdStoppingPolicy):
            return np.array(self.model.stage_policy(o=o))
        else:
            b1 = o[1]
            l = int(o[0])
            stage_policy = []
            for s in self.config.stopping_game_config.S:
                if s != 2:
                    o = [l, b1, s]
                    stage_policy.append(self._get_attacker_dist(obs=o))
                else:
                    stage_policy.append([0.5, 0.5])
            return np.array(stage_policy)

    def _get_attacker_dist(self, obs: List[Any]) -> List[float]:
        """
        Utility function for getting the attacker's action distribution based on a given observation

        :param obs: the given observation
        :return: the action distribution
        """
        np_obs = np.array([obs])
        if self.model is None:
            raise ValueError("Model is None")
        actions, values, log_prob = self.model.policy.forward(obs=torch.tensor(np_obs).to(self.model.device))
        action = actions[0]
        if action == 1:
            stop_prob = math.exp(log_prob)
        else:
            stop_prob = 1 - math.exp(log_prob)
        return [1 - stop_prob, stop_prob]

    def render(self, mode: str = 'human'):
        """
        Renders the environment. Supported rendering modes: (1) human; and (2) rgb_array

        :param mode: the rendering mode
        :return: True (if human mode) otherwise an rgb array
        """
        raise NotImplementedError("Rendering is not implemented for this environment")

    def is_defense_action_legal(self, defense_action_id: int) -> bool:
        """
        Checks whether a defender action in the environment is legal or not

        :param defense_action_id: the id of the action
        :return: True or False
        """
        return True

    def is_attack_action_legal(self, attack_action_id: int) -> bool:
        """
        Checks whether an attacker action in the environment is legal or not

        :param attack_action_id: the id of the attacker action
        :return: True or False
        """
        return True

    def get_traces(self) -> List[SimulationTrace]:
        """
        :return: the list of simulation traces
        """
        return self.stopping_game_env.get_traces()

    def reset_traces(self) -> None:
        """
        Resets the list of traces

        :return: None
        """
        return self.stopping_game_env.reset_traces()

    def generate_random_particles(self, o: int, num_particles: int) -> List[int]:
        """
        Generates a random list of state particles from a given observation

        :param o: the latest observation
        :param num_particles: the number of particles to generate
        :return: the list of random particles
        """
        return self.stopping_game_env.generate_random_particles(o=o, num_particles=num_particles)

    def get_actions_from_particles(self, particles: List[int], t: int, observation: int,
                                   verbose: bool = False) -> List[int]:
        """
        Prunes the set of actions based on the current particle set

        :param particles: the set of particles
        :param t: the current time step
        :param observation: the latest observation
        :param verbose: boolean flag indicating whether logging should be verbose or not
        :return: the list of pruned actions
        """
        return list(self.config.stopping_game_config.A2)

    def manual_play(self) -> None:
        """
        An interactive loop to test the environment manually

        :return: None
        """
        done = False
        while True:
            raw_input = input("> ")
            raw_input = raw_input.strip()
            if raw_input == "help":
                print("Enter an action id to execute the action, "
                      "press R to reset, "
                      "press S to print the state, press A to print the actions, "
                      "press D to check if done, "
                      "press H to print the history of actions")
            elif raw_input == "A":
                print(f"Action space: {self.action_space}")
            elif raw_input == "S":
                print(self.stopping_game_env.state)
            elif raw_input == "D":
                print(done)
            elif raw_input == "H":
                print(self.stopping_game_env.trace)
            elif raw_input == "R":
                print("Resetting the state")
                self.reset()
            else:
                action_idx = int(raw_input)
                _, _, done, _, _ = self.step(pi2=action_idx)
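A corresponding sketch for the attacker MDP wrapper above: when step() receives a plain integer it is interpreted as the attacker action a2 and expanded into a deterministic stage policy via calculate_stage_policy(), as shown in this file. The `attacker_mdp_config` object is an assumption (a StoppingGameAttackerMdpConfig with a static defender strategy already attached).

    from gym_csle_stopping_game.envs.stopping_game_mdp_attacker_env import StoppingGameMdpAttackerEnv

    # `attacker_mdp_config` is assumed to be a populated StoppingGameAttackerMdpConfig
    env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
    attacker_obs, info = env.reset()

    # Integer actions are expanded into a stage policy internally; 1 = stop (start the intrusion)
    attacker_obs, r, terminated, truncated, info = env.step(pi2=1)
    # r is the attacker's reward, i.e. the negative of the defender's reward in the underlying game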