gym_csle_stopping_game-0.9.24-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,393 @@
1
+ import random
2
+ from typing import Tuple, Dict, List, Any, Union
3
+ import numpy as np
4
+ import numpy.typing as npt
5
+ import time
6
+ import math
7
+ import csle_common.constants.constants as constants
8
+ from csle_common.dao.simulation_config.base_env import BaseEnv
9
+ from csle_common.dao.simulation_config.simulation_trace import SimulationTrace
10
+ from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
11
+ from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
12
+ from gym_csle_stopping_game.dao.stopping_game_state import StoppingGameState
13
+ import gym_csle_stopping_game.constants.constants as env_constants
14
+
15
+
16
+ class StoppingGameEnv(BaseEnv):
17
+ """
18
+ OpenAI Gym Env for the csle-stopping-game
19
+ """
20
+
21
+ def __init__(self, config: StoppingGameConfig):
22
+ """
23
+ Initializes the environment
24
+
25
+ :param config: the environment configuration
26
+ """
27
+ self.config = config
28
+
29
+ # Initialize environment state
30
+ self.state = StoppingGameState(b1=self.config.b1, L=self.config.L)
31
+ # Setup spaces
32
+ self.attacker_observation_space = self.config.attacker_observation_space()
33
+ self.defender_observation_space = self.config.defender_observation_space()
34
+ self.attacker_action_space = self.config.attacker_action_space()
35
+ self.defender_action_space = self.config.defender_action_space()
36
+
37
+ self.action_space = self.defender_action_space
38
+ self.observation_space = self.defender_observation_space
39
+
40
+ # Setup traces
41
+ self.traces: List[SimulationTrace] = []
42
+ self.trace = SimulationTrace(simulation_env=self.config.env_name)
43
+
44
+ # Reset
45
+ self.reset()
46
+ super().__init__()
47
+
48
+ def step(self, action_profile: Tuple[int, Tuple[npt.NDArray[Any], int]]) \
49
+ -> Tuple[Tuple[npt.NDArray[Any], npt.NDArray[Any]], Tuple[int, int], bool, bool, Dict[str, Any]]:
50
+ """
51
+ Takes a step in the environment by executing the given action
52
+
53
+ :param action_profile: the actions to take (both players' actions)
54
+ :return: (obs, reward, terminated, truncated, info)
55
+ """
56
+
57
+ # Setup initial values
58
+ a1, a2_profile = action_profile
59
+ pi2, a2 = a2_profile
60
+ assert pi2.shape[0] == len(self.config.S)
61
+ assert pi2.shape[1] == len(self.config.A2)
62
+ done = False
63
+ info: Dict[str, Any] = {}
64
+
65
+ o = max(self.config.O)
66
+ if self.state.s == 2:
67
+ done = True
68
+ r = 0
69
+ else:
70
+ # Compute r, s', b', o'
71
+ r = self.config.R[self.state.l - 1][a1][a2][self.state.s]
72
+ self.state.s = StoppingGameUtil.sample_next_state(l=self.state.l, a1=a1, a2=a2, T=self.config.T,
73
+ S=self.config.S, s=self.state.s)
74
+ o = StoppingGameUtil.sample_next_observation(Z=self.config.Z,
75
+ O=self.config.O, s_prime=self.state.s)
76
+ if self.config.compute_beliefs:
77
+ self.state.b = StoppingGameUtil.next_belief(o=o, a1=a1, b=self.state.b, pi2=pi2,
78
+ config=self.config,
79
+ l=self.state.l, a2=a2)
80
+
81
+ # Update stops remaining
82
+ self.state.l = self.state.l - a1
83
+
84
+ # Update time-step
85
+ self.state.t += 1
86
+
87
+ # Populate info dict
88
+ info[env_constants.ENV_METRICS.STOPS_REMAINING] = self.state.l
89
+ info[env_constants.ENV_METRICS.STATE] = self.state.s
90
+ info[env_constants.ENV_METRICS.DEFENDER_ACTION] = a1
91
+ info[env_constants.ENV_METRICS.ATTACKER_ACTION] = a2
92
+ info[env_constants.ENV_METRICS.OBSERVATION] = o
93
+ info[env_constants.ENV_METRICS.TIME_STEP] = self.state.t
94
+
95
+ # Get observations
96
+ attacker_obs = self.state.attacker_observation()
97
+ defender_obs = self.state.defender_observation()
98
+
99
+ # Log trace
100
+ if self.config.save_trace:
101
+ self.trace.defender_rewards.append(r)
102
+ self.trace.attacker_rewards.append(-r)
103
+ self.trace.attacker_actions.append(a2)
104
+ self.trace.defender_actions.append(a1)
105
+ self.trace.infos.append(info)
106
+ self.trace.states.append(self.state.s)
107
+ self.trace.beliefs.append(self.state.b[1])
108
+ self.trace.infrastructure_metrics.append(o)
109
+ if not done:
110
+ self.trace.attacker_observations.append(attacker_obs)
111
+ self.trace.defender_observations.append(defender_obs)
112
+
113
+ # Populate info
114
+ info = self._info(info)
115
+
116
+ return (defender_obs, attacker_obs), (r, -r), done, done, info
117
+
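For reference, a minimal sketch of how the step() method above can be driven, assuming a pre-built StoppingGameConfig instance named config (its construction is not part of this diff) and the three-state, two-action setup used elsewhere in this file. The attacker stage policy pi2 must have one row per state in S and one column per action in A2, matching the assertions at the top of step():

    import numpy as np
    from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv

    env = StoppingGameEnv(config=config)  # config: assumed pre-built StoppingGameConfig
    (defender_obs, attacker_obs), info = env.reset()
    pi2 = np.array([[1.0, 0.0], [1.0, 0.0], [0.5, 0.5]])  # attacker continues in states 0 and 1
    a1, a2 = 0, 0  # defender continues, attacker continues
    (defender_obs, attacker_obs), (r_defender, r_attacker), terminated, truncated, info = \
        env.step(action_profile=(a1, (pi2, a2)))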
118
+ def mean(self, prob_vector):
119
+ """
120
+ Utility function for computing the mean (expected index) of a probability vector
121
+
122
+ :param prob_vector: the vector to take the mean of
123
+ :return: the mean
124
+ """
125
+ m = 0
126
+ for i in range(len(prob_vector)):
127
+ m += prob_vector[i] * i
128
+ return m
129
+
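As a small worked example of mean() above: for prob_vector = [0.2, 0.3, 0.5] it returns 0 * 0.2 + 1 * 0.3 + 2 * 0.5 = 1.3, i.e. the expected index under the given probability vector.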
130
+ def weighted_intrusion_prediction_distance(self, intrusion_start: int, first_stop: int):
131
+ """
132
+ Computes the weighted intrusion start time prediction distance (Wang, Hammar, Stadler, 2022)
133
+
134
+ :param intrusion_start: the intrusion start time
135
+ :param first_stop: the predicted start time
136
+ :return: the weighted distance
137
+ """
138
+ if first_stop <= intrusion_start:
139
+ return 1 - (10 / 10)
140
+ else:
141
+ return 1 - (min(10, (first_stop - (intrusion_start + 1))) / 2) / 10
142
+
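To illustrate the function above with intrusion_start = 5: any first_stop <= 5 yields 1 - (10 / 10) = 0, while first_stop = 9 yields 1 - (min(10, 9 - 6) / 2) / 10 = 1 - 0.15 = 0.85.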
143
+ def _info(self, info: Dict[str, Any]) -> Dict[str, Any]:
144
+ """
145
+ Populates the info dict with episode metrics such as the return, time horizon, stop times, intrusion statistics, and baseline returns
146
+
147
+ :param info: the info dict to update
148
+ :return: the updated info dict
149
+ """
150
+ R = 0
151
+ for i in range(len(self.trace.defender_rewards)):
152
+ R += self.trace.defender_rewards[i] * math.pow(self.config.gamma, i)
153
+ info[env_constants.ENV_METRICS.RETURN] = sum(self.trace.defender_rewards)
154
+ info[env_constants.ENV_METRICS.TIME_HORIZON] = len(self.trace.defender_actions)
155
+ stop = self.config.L
156
+ for i in range(1, self.config.L + 1):
157
+ info[f"{env_constants.ENV_METRICS.STOP}_{i}"] = len(self.trace.states)
158
+ for i in range(len(self.trace.defender_actions)):
159
+ if self.trace.defender_actions[i] == 1:
160
+ info[f"{env_constants.ENV_METRICS.STOP}_{stop}"] = i
161
+ stop -= 1
162
+ intrusion_start = len(self.trace.defender_actions)
163
+ for i in range(len(self.trace.attacker_actions)):
164
+ if self.trace.attacker_actions[i] == 1:
165
+ intrusion_start = i
166
+ break
167
+ intrusion_end = len(self.trace.attacker_actions)
168
+ info[env_constants.ENV_METRICS.INTRUSION_START] = intrusion_start
169
+ info[env_constants.ENV_METRICS.INTRUSION_END] = intrusion_end
170
+ info[env_constants.ENV_METRICS.START_POINT_CORRECT] = \
171
+ int(intrusion_start == (info[f"{env_constants.ENV_METRICS.STOP}_1"] + 1))
172
+ info[env_constants.ENV_METRICS.WEIGHTED_INTRUSION_PREDICTION_DISTANCE] = \
173
+ self.weighted_intrusion_prediction_distance(intrusion_start=intrusion_start,
174
+ first_stop=info[f"{env_constants.ENV_METRICS.STOP}_1"])
175
+ info[env_constants.ENV_METRICS.INTRUSION_LENGTH] = intrusion_end - intrusion_start
176
+ upper_bound_return = 0
177
+ defender_baseline_stop_on_first_alert_return = 0
178
+ upper_bound_stops_remaining = self.config.L
179
+ defender_baseline_stop_on_first_alert_stops_remaining = self.config.L
180
+ for i in range(len(self.trace.states)):
181
+ if defender_baseline_stop_on_first_alert_stops_remaining > 0:
182
+ if self.trace.infrastructure_metrics[i] > 0:
183
+ defender_baseline_stop_on_first_alert_return += \
184
+ self.config.R[int(defender_baseline_stop_on_first_alert_stops_remaining) - 1][1][
185
+ self.trace.attacker_actions[i]][self.trace.states[i]] * math.pow(self.config.gamma, i)
186
+ defender_baseline_stop_on_first_alert_stops_remaining -= 1
187
+ else:
188
+ defender_baseline_stop_on_first_alert_return += \
189
+ self.config.R[int(defender_baseline_stop_on_first_alert_stops_remaining) - 1][0][
190
+ self.trace.attacker_actions[i]][self.trace.states[i]] * math.pow(self.config.gamma, i)
191
+ if upper_bound_stops_remaining > 0:
192
+ if self.trace.states[i] == 0:
193
+ r = self.config.R[int(upper_bound_stops_remaining) - 1][0][self.trace.attacker_actions[i]][
194
+ self.trace.states[i]]
195
+ upper_bound_return += r * math.pow(self.config.gamma, i)
196
+ else:
197
+ r = self.config.R[int(upper_bound_stops_remaining) - 1][1][self.trace.attacker_actions[i]][
198
+ self.trace.states[i]]
199
+ upper_bound_return += r * math.pow(self.config.gamma, i)
200
+ upper_bound_stops_remaining -= 1
201
+ info[env_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN] = upper_bound_return
202
+ info[env_constants.ENV_METRICS.AVERAGE_DEFENDER_BASELINE_STOP_ON_FIRST_ALERT_RETURN] = \
203
+ defender_baseline_stop_on_first_alert_return
204
+ return info
205
+
206
+ def reset(self, seed: Union[None, int] = None, soft: bool = False, options: Union[Dict[str, Any], None] = None) \
207
+ -> Tuple[Tuple[npt.NDArray[Any], npt.NDArray[Any]], Dict[str, Any]]:
208
+ """
209
+ Resets the environment state; this should be called whenever step() returns <done>
210
+
211
+ :param seed: the random seed
212
+ :param soft: boolean flag indicating whether it is a soft reset or not
213
+ :param options: optional configuration parameters
214
+ :return: initial observation and info
215
+ """
216
+ super().reset(seed=seed)
217
+ self.state.reset()
218
+ if len(self.trace.attacker_rewards) > 0:
219
+ self.traces.append(self.trace)
220
+ self.trace = SimulationTrace(simulation_env=self.config.env_name)
221
+ attacker_obs = self.state.attacker_observation()
222
+ defender_obs = self.state.defender_observation()
223
+ if self.config.save_trace:
224
+ self.trace.attacker_observations.append(attacker_obs)
225
+ self.trace.defender_observations.append(defender_obs)
226
+ info: Dict[str, Any] = {}
227
+ info[env_constants.ENV_METRICS.STOPS_REMAINING] = self.state.l
228
+ info[env_constants.ENV_METRICS.STATE] = self.state.s
229
+ info[env_constants.ENV_METRICS.OBSERVATION] = 0
230
+ info[env_constants.ENV_METRICS.TIME_STEP] = self.state.t
231
+ return (defender_obs, attacker_obs), info
232
+
233
+ def render(self, mode: str = 'human'):
234
+ """
235
+ Renders the environment. Rendering is not implemented for this environment
236
+
237
+ :param mode: the rendering mode
238
+ :return: raises NotImplementedError
239
+ """
240
+ raise NotImplementedError("Rendering is not implemented for this environment")
241
+
242
+ def is_defense_action_legal(self, defense_action_id: int) -> bool:
243
+ """
244
+ Checks whether a defender action in the environment is legal or not
245
+
246
+ :param defense_action_id: the id of the action
247
+ :return: True or False
248
+ """
249
+ return True
250
+
251
+ def is_attack_action_legal(self, attack_action_id: int) -> bool:
252
+ """
253
+ Checks whether an attacker action in the environment is legal or not
254
+
255
+ :param attack_action_id: the id of the attacker action
256
+ :return: True or False
257
+ """
258
+ return True
259
+
260
+ def get_traces(self) -> List[SimulationTrace]:
261
+ """
262
+ :return: the list of simulation traces
263
+ """
264
+ return self.traces
265
+
266
+ def reset_traces(self) -> None:
267
+ """
268
+ Resets the list of traces
269
+
270
+ :return: None
271
+ """
272
+ self.traces = []
273
+
274
+ def __checkpoint_traces(self) -> None:
275
+ """
276
+ Checkpoints agent traces
277
+ :return: None
278
+ """
279
+ ts = time.time()
280
+ SimulationTrace.save_traces(traces_save_dir=constants.LOGGING.DEFAULT_LOG_DIR,
281
+ traces=self.traces, traces_file=f"taus{ts}.json")
282
+
283
+ def set_model(self, model) -> None:
284
+ """
285
+ Sets the model. Useful when using RL frameworks where the stage policy is not easy to extract
286
+
287
+ :param model: the model
288
+ :return: None
289
+ """
290
+ self.model = model
291
+
292
+ def set_state(self, state: Union[StoppingGameState, int, Tuple[int, int]]) -> None:
293
+ """
294
+ Sets the state. Allows simulating samples from specific states
295
+
296
+ :param state: the state
297
+ :return: None
298
+ """
299
+ if isinstance(state, StoppingGameState):
300
+ self.state = state
301
+ elif type(state) is int or type(state) is np.int64:
302
+ self.state.s = state
303
+ self.state.l = self.config.L
304
+ elif type(state) is tuple:
305
+ self.state.s = state[0]
306
+ self.state.l = state[1]
307
+ else:
308
+ raise ValueError(f"state: {state} not valid")
309
+
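As an illustration of set_state() above, the three accepted forms are: a full StoppingGameState object, a bare state id (e.g. env.set_state(1) sets s = 1 and resets l to config.L), and an (s, l) tuple (e.g. env.set_state((1, 2)) sets s = 1 and l = 2).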
310
+ def is_state_terminal(self, state: Union[StoppingGameState, int, Tuple[int, int]]) -> bool:
311
+ """
312
+ Checks whether a given state is terminal or not
313
+
314
+ :param state: the state
315
+ :return: True if terminal, else false
316
+ """
317
+ if isinstance(state, StoppingGameState):
318
+ return state.s == 2
319
+ elif type(state) is int or type(state) is np.int64:
320
+ return state == 2
321
+ elif type(state) is tuple:
322
+ return state[0] == 2
323
+ else:
324
+ raise ValueError(f"state: {state} not valid")
325
+
326
+ def get_observation_from_history(self, history: List[int], pi2: npt.NDArray[Any], l: int) -> List[Any]:
327
+ """
328
+ Utility method to get a hidden observation based on a history
329
+
330
+ :param history: the history to get the observation from
331
+ :param pi2: the attacker stage strategy
332
+ :param l: the number of stops remaining
333
+ :return: the observation
334
+ """
335
+ if not history:
336
+ raise ValueError("History must not be empty")
337
+ return [history[-1]]
338
+
339
+ def generate_random_particles(self, o: int, num_particles: int) -> List[int]:
340
+ """
341
+ Generates a random list of state particles from a given observation
342
+
343
+ :param o: the latest observation
344
+ :param num_particles: the number of particles to generate
345
+ :return: the list of random particles
346
+ """
347
+ particles = []
348
+ for i in range(num_particles):
349
+ particles.append(random.choice([0, 1]))
350
+ return particles
351
+
352
+ def manual_play(self) -> None:
353
+ """
354
+ An interactive loop to test the environment manually
355
+
356
+ :return: None
357
+ """
358
+ done = False
359
+ while True:
360
+ raw_input = input("> ")
361
+ raw_input = raw_input.strip()
362
+ if raw_input == "help":
363
+ print("Enter an action id to execute the action, "
364
+ "press R to reset,"
365
+ "press S to print the state, press A to print the actions, "
366
+ "press D to check if done"
367
+ "press H to print the history of actions")
368
+ elif raw_input == "A":
369
+ print(f"Attacker space: {self.action_space}")
370
+ elif raw_input == "S":
371
+ print(self.state)
372
+ elif raw_input == "D":
373
+ print(done)
374
+ elif raw_input == "H":
375
+ print(self.trace)
376
+ elif raw_input == "R":
377
+ print("Resetting the state")
378
+ self.reset()
379
+ else:
380
+ action_profile = raw_input
381
+ parts = action_profile.split(",")
382
+ a1 = int(parts[0])
383
+ a2 = int(parts[1])
384
+ stage_policy = []
385
+ for s in self.config.S:
386
+ if s != 2:
387
+ dist = [0.0, 0.0]
388
+ dist[a2] = 1.0
389
+ stage_policy.append(dist)
390
+ else:
391
+ stage_policy.append([0.5, 0.5])
392
+ pi2 = np.array(stage_policy)
393
+ _, _, done, _, _ = self.step(action_profile=(a1, (pi2, a2)))
@@ -0,0 +1,282 @@
1
+ from typing import Tuple, List, Dict, Any, Union
2
+ import numpy as np
3
+ import numpy.typing as npt
4
+ import torch
5
+ import math
6
+ from csle_common.dao.simulation_config.base_env import BaseEnv
7
+ from csle_common.dao.training.mixed_multi_threshold_stopping_policy import MixedMultiThresholdStoppingPolicy
8
+ from gym_csle_stopping_game.dao.stopping_game_attacker_mdp_config import StoppingGameAttackerMdpConfig
9
+ from csle_common.dao.simulation_config.simulation_trace import SimulationTrace
10
+ from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
11
+ import gym_csle_stopping_game.constants.constants as env_constants
12
+ from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
13
+
14
+
15
+ class StoppingGameMdpAttackerEnv(BaseEnv):
16
+ """
17
+ OpenAI Gym Env for the MDP of the attacker when facing a static defender
18
+ """
19
+
20
+ def __init__(self, config: StoppingGameAttackerMdpConfig):
21
+ """
22
+ Initializes the environment
23
+
24
+ :param config: the configuration of the environment
25
+ """
26
+ self.config = config
27
+ self.stopping_game_env: StoppingGameEnv = StoppingGameEnv(config=self.config.stopping_game_config)
28
+
29
+ # Setup spaces
30
+ self.observation_space = self.config.stopping_game_config.attacker_observation_space()
31
+ self.action_space = self.config.stopping_game_config.attacker_action_space()
32
+
33
+ # Setup static defender
34
+ self.static_defender_strategy = self.config.defender_strategy
35
+
36
+ # Setup Config
37
+ self.viewer: Union[None, Any] = None
38
+ self.metadata = {
39
+ 'render.modes': ['human', 'rgb_array'],
40
+ 'video.frames_per_second': 50 # Video rendering speed
41
+ }
42
+
43
+ self.latest_defender_obs: Union[None, List[Any], npt.NDArray[Any]] = None
44
+ self.latest_attacker_obs: Union[None, List[Any], npt.NDArray[Any]] = None
45
+ self.model: Union[None, Any] = None
46
+
47
+ # Reset
48
+ self.reset()
49
+ super().__init__()
50
+
51
+ def step(self, pi2: Union[npt.NDArray[Any], int, float, np.float64]) \
52
+ -> Tuple[npt.NDArray[Any], int, bool, bool, Dict[str, Any]]:
53
+ """
54
+ Takes a step in the environment by executing the given action
55
+
56
+ :param pi2: attacker stage policy
57
+ :return: (obs, reward, terminated, truncated, info)
58
+ """
59
+ if (type(pi2) is int or type(pi2) is float or type(pi2) is np.int64 or type(pi2) is np.int32 # type: ignore
60
+ or type(pi2) is np.float64): # type: ignore
61
+ a2 = pi2
62
+ if self.latest_attacker_obs is None:
63
+ raise ValueError("Attacker observation is None")
64
+ pi2 = self.calculate_stage_policy(o=list(self.latest_attacker_obs), a2=int(a2))
65
+ else:
66
+ if self.model is not None:
67
+ if self.latest_attacker_obs is None:
68
+ raise ValueError("Attacker observation is None")
69
+ pi2 = self.calculate_stage_policy(o=list(self.latest_attacker_obs))
70
+ a2 = StoppingGameUtil.sample_attacker_action(pi2=pi2, s=self.stopping_game_env.state.s)
71
+ else:
72
+ pi2 = np.array(pi2)
73
+ try:
74
+ if self.latest_attacker_obs is None:
75
+ raise ValueError("Attacker observation is None")
76
+ pi2 = self.calculate_stage_policy(o=list(self.latest_attacker_obs))
77
+ except Exception:
78
+ pass
79
+ a2 = StoppingGameUtil.sample_attacker_action(pi2=pi2, s=self.stopping_game_env.state.s)
80
+
81
+ # a2 = pi2
82
+ # pi2 = np.array([
83
+ # [0.5,0.5],
84
+ # [0.5,0.5],
85
+ # [0.5,0.5]
86
+ # ])
87
+ assert pi2.shape[0] == len(self.config.stopping_game_config.S)
88
+ assert pi2.shape[1] == len(self.config.stopping_game_config.A2)
89
+
90
+ # Get defender action from static strategy
91
+ a1 = self.static_defender_strategy.action(o=self.latest_defender_obs)
92
+
93
+ # Step the game
94
+ o, r, d, _, info = self.stopping_game_env.step((int(a1), (pi2, int(a2))))
95
+ self.latest_defender_obs = o[0]
96
+ self.latest_attacker_obs = o[1]
97
+ attacker_obs = o[1]
98
+
99
+ info[env_constants.ENV_METRICS.RETURN] = -info[env_constants.ENV_METRICS.RETURN]
100
+ info[env_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN] = \
101
+ -info[env_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN]
102
+
103
+ return attacker_obs, r[1], d, d, info
104
+
105
+ def reset(self, seed: Union[int, None] = None, soft: bool = False, options: Union[Dict[str, Any], None] = None) \
106
+ -> Tuple[npt.NDArray[Any], Dict[str, Any]]:
107
+ """
108
+ Resets the environment state; this should be called whenever step() returns <done>
109
+
110
+ :param seed: the random seed
111
+ :param soft: boolean flag indicating whether it is a soft reset or not
112
+ :param options: optional configuration parameters
113
+ :return: initial observation and info dict
114
+ """
115
+ o, _ = self.stopping_game_env.reset()
116
+ self.latest_defender_obs = o[0]
117
+ self.latest_attacker_obs = o[1]
118
+ attacker_obs = o[1]
119
+ info: Dict[str, Any] = {}
120
+ return attacker_obs, info
121
+
122
+ def set_model(self, model) -> None:
123
+ """
124
+ Sets the model. Useful when using RL frameworks where the stage policy is not easy to extract
125
+
126
+ :param model: the model
127
+ :return: None
128
+ """
129
+ self.model = model
130
+
131
+ def set_state(self, state: Any) -> None:
132
+ """
133
+ Sets the state. Allows simulating samples from specific states
134
+
135
+ :param state: the state
136
+ :return: None
137
+ """
138
+ self.stopping_game_env.set_state(state=state)
139
+
140
+ def calculate_stage_policy(self, o: List[Any], a2: int = 0) -> npt.NDArray[Any]:
141
+ """
142
+ Calculates the stage policy of a given model and observation
143
+
144
+ :param o: the observation
+ :param a2: the attacker action (used when no model has been set)
145
+ :return: the stage policy
146
+ """
147
+ if self.model is None:
148
+ stage_policy = []
149
+ for s in self.config.stopping_game_config.S:
150
+ if s != 2:
151
+ dist = [0.0, 0.0]
152
+ dist[a2] = 1.0
153
+ stage_policy.append(dist)
154
+ else:
155
+ stage_policy.append([0.5, 0.5])
156
+ return np.array(stage_policy)
157
+ if isinstance(self.model, MixedMultiThresholdStoppingPolicy):
158
+ return np.array(self.model.stage_policy(o=o))
159
+ else:
160
+ b1 = o[1]
161
+ l = int(o[0])
162
+ stage_policy = []
163
+ for s in self.config.stopping_game_config.S:
164
+ if s != 2:
165
+ o = [l, b1, s]
166
+ stage_policy.append(self._get_attacker_dist(obs=o))
167
+ else:
168
+ stage_policy.append([0.5, 0.5])
169
+ return np.array(stage_policy)
170
+
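As a sketch of what calculate_stage_policy() above returns when no model has been set (the first branch), with the three-state setup used in this package: calculate_stage_policy(o=latest_attacker_obs, a2=1) gives np.array([[0.0, 1.0], [0.0, 1.0], [0.5, 0.5]]), i.e. the chosen action with probability one in the non-terminal states and a uniform distribution in the terminal state.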
171
+ def _get_attacker_dist(self, obs: List[Any]) -> List[float]:
172
+ """
173
+ Utility function for getting the attacker's action distribution based on a given observation
174
+
175
+ :param obs: the given observation
176
+ :return: the action distribution
177
+ """
178
+ np_obs = np.array([obs])
179
+ if self.model is None:
180
+ raise ValueError("Model is None")
181
+ actions, values, log_prob = self.model.policy.forward(obs=torch.tensor(np_obs).to(self.model.device))
182
+ action = actions[0]
183
+ if action == 1:
184
+ stop_prob = math.exp(log_prob)
185
+ else:
186
+ stop_prob = 1 - math.exp(log_prob)
187
+ return [1 - stop_prob, stop_prob]
188
+
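For instance, in _get_attacker_dist() above, if the sampled action is 1 with log_prob = math.log(0.9) the returned distribution is [0.1, 0.9], and if the sampled action is 0 with the same log probability it is [0.9, 0.1].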
189
+ def render(self, mode: str = 'human'):
190
+ """
191
+ Renders the environment. Rendering is not implemented for this environment
192
+
193
+ :param mode: the rendering mode
194
+ :return: raises NotImplementedError
195
+ """
196
+ raise NotImplementedError("Rendering is not implemented for this environment")
197
+
198
+ def is_defense_action_legal(self, defense_action_id: int) -> bool:
199
+ """
200
+ Checks whether a defender action in the environment is legal or not
201
+
202
+ :param defense_action_id: the id of the action
203
+ :return: True or False
204
+ """
205
+ return True
206
+
207
+ def is_attack_action_legal(self, attack_action_id: int) -> bool:
208
+ """
209
+ Checks whether an attacker action in the environment is legal or not
210
+
211
+ :param attack_action_id: the id of the attacker action
212
+ :return: True or False
213
+ """
214
+ return True
215
+
216
+ def get_traces(self) -> List[SimulationTrace]:
217
+ """
218
+ :return: the list of simulation traces
219
+ """
220
+ return self.stopping_game_env.get_traces()
221
+
222
+ def reset_traces(self) -> None:
223
+ """
224
+ Resets the list of traces
225
+
226
+ :return: None
227
+ """
228
+ return self.stopping_game_env.reset_traces()
229
+
230
+ def generate_random_particles(self, o: int, num_particles: int) -> List[int]:
231
+ """
232
+ Generates a random list of state particles from a given observation
233
+
234
+ :param o: the latest observation
235
+ :param num_particles: the number of particles to generate
236
+ :return: the list of random particles
237
+ """
238
+ return self.stopping_game_env.generate_random_particles(o=o, num_particles=num_particles)
239
+
240
+ def get_actions_from_particles(self, particles: List[int], t: int, observation: int,
241
+ verbose: bool = False) -> List[int]:
242
+ """
243
+ Prunes the set of actions based on the current particle set
244
+
245
+ :param particles: the set of particles
246
+ :param t: the current time step
247
+ :param observation: the latest observation
248
+ :param verbose: boolean flag indicating whether logging should be verbose or not
249
+ :return: the list of pruned actions
250
+ """
251
+ return list(self.config.stopping_game_config.A2)
252
+
253
+ def manual_play(self) -> None:
254
+ """
255
+ An interactive loop to test the environment manually
256
+
257
+ :return: None
258
+ """
259
+ done = False
260
+ while True:
261
+ raw_input = input("> ")
262
+ raw_input = raw_input.strip()
263
+ if raw_input == "help":
264
+ print("Enter an action id to execute the action, "
265
+ "press R to reset,"
266
+ "press S to print the state, press A to print the actions, "
267
+ "press D to check if done"
268
+ "press H to print the history of actions")
269
+ elif raw_input == "A":
270
+ print(f"Action space: {self.action_space}")
271
+ elif raw_input == "S":
272
+ print(self.stopping_game_env.state)
273
+ elif raw_input == "D":
274
+ print(done)
275
+ elif raw_input == "H":
276
+ print(self.stopping_game_env.trace)
277
+ elif raw_input == "R":
278
+ print("Resetting the state")
279
+ self.reset()
280
+ else:
281
+ action_idx = int(raw_input)
282
+ _, _, done, _, _ = self.step(pi2=action_idx)
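Finally, a minimal usage sketch for StoppingGameMdpAttackerEnv, assuming a pre-built StoppingGameAttackerMdpConfig named attacker_mdp_config whose defender_strategy exposes an action(o=...) method (as required by step() above). Passing a plain integer attacker action lets the environment derive the stage policy internally via calculate_stage_policy():

    env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)  # attacker_mdp_config is assumed
    attacker_obs, info = env.reset()
    done = False
    while not done:
        a2 = 0  # attacker continues; any action id in A2 works here
        attacker_obs, r, terminated, truncated, info = env.step(pi2=a2)
        done = terminated or truncated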