gym-csle-stopping-game 0.6.3.tar.gz → 0.6.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of gym-csle-stopping-game has been flagged as potentially problematic; see the registry's advisory page for this release for more details.

Files changed (31)
  1. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/PKG-INFO +1 -1
  2. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/setup.cfg +5 -5
  3. gym_csle_stopping_game-0.6.5/src/gym_csle_stopping_game/__version__.py +1 -0
  4. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game/dao/stopping_game_config.py +1 -1
  5. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game/dao/stopping_game_state.py +1 -1
  6. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game/envs/stopping_game_mdp_attacker_env.py +1 -1
  7. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game/util/stopping_game_util.py +54 -5
  8. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game.egg-info/PKG-INFO +1 -1
  9. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game.egg-info/requires.txt +5 -5
  10. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/tests/test_stopping_game_env.py +18 -49
  11. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/tests/test_stopping_game_mdp_attacker_env.py +36 -19
  12. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/tests/test_stopping_game_pomdp_defender_env.py +24 -12
  13. gym_csle_stopping_game-0.6.3/src/gym_csle_stopping_game/__version__.py +0 -1
  14. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/pyproject.toml +0 -0
  15. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/setup.py +0 -0
  16. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game/__init__.py +0 -0
  17. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game/constants/__init__.py +0 -0
  18. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game/constants/constants.py +0 -0
  19. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game/dao/__init__.py +0 -0
  20. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game/dao/stopping_game_attacker_mdp_config.py +0 -0
  21. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game/dao/stopping_game_defender_pomdp_config.py +0 -0
  22. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game/envs/__init__.py +0 -0
  23. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game/envs/stopping_game_env.py +0 -0
  24. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game/envs/stopping_game_pomdp_defender_env.py +0 -0
  25. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game/util/__init__.py +0 -0
  26. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game.egg-info/SOURCES.txt +0 -0
  27. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game.egg-info/dependency_links.txt +0 -0
  28. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game.egg-info/not-zip-safe +0 -0
  29. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/src/gym_csle_stopping_game.egg-info/top_level.txt +0 -0
  30. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/tests/test_stopping_game_dao.py +0 -0
  31. {gym_csle_stopping_game-0.6.3 → gym_csle_stopping_game-0.6.5}/tests/test_stopping_game_util.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: gym_csle_stopping_game
3
- Version: 0.6.3
3
+ Version: 0.6.5
4
4
  Summary: OpenAI gym reinforcement learning environment of a Dynkin (Optimal stopping) game in CSLE
5
5
  Author: Kim Hammar
6
6
  Author-email: hammar.kim@gmail.com
@@ -20,11 +20,11 @@ classifiers =
20
20
  [options]
21
21
  install_requires =
22
22
  gymnasium>=0.27.1
23
- csle-base>=0.6.3
24
- csle-common>=0.6.3
25
- csle-attacker>=0.6.3
26
- csle-defender>=0.6.3
27
- csle-collector>=0.6.3
23
+ csle-base>=0.6.5
24
+ csle-common>=0.6.5
25
+ csle-attacker>=0.6.5
26
+ csle-defender>=0.6.5
27
+ csle-collector>=0.6.5
28
28
  python_requires = >=3.8
29
29
  package_dir =
30
30
  =src
@@ -0,0 +1 @@
1
+ __version__ = '0.6.5'
@@ -14,7 +14,7 @@ class StoppingGameConfig(SimulationEnvInputConfig):
14
14
  T: npt.NDArray[Any], O: npt.NDArray[np.int_], Z: npt.NDArray[Any],
15
15
  R: npt.NDArray[Any], S: npt.NDArray[np.int_], A1: npt.NDArray[np.int_],
16
16
  A2: npt.NDArray[np.int_], L: int, R_INT: int, R_COST: int, R_SLA: int, R_ST: int,
17
- b1: npt.NDArray[np.float_],
17
+ b1: npt.NDArray[np.float64],
18
18
  save_dir: str, checkpoint_traces_freq: int, gamma: float = 1, compute_beliefs: bool = True,
19
19
  save_trace: bool = True) -> None:
20
20
  """
@@ -10,7 +10,7 @@ class StoppingGameState(JSONSerializable):
10
10
  Represents the state of the optimal stopping game
11
11
  """
12
12
 
13
- def __init__(self, b1: npt.NDArray[np.float_], L: int) -> None:
13
+ def __init__(self, b1: npt.NDArray[np.float64], L: int) -> None:
14
14
  """
15
15
  Intializes the state
16
16
 
@@ -48,7 +48,7 @@ class StoppingGameMdpAttackerEnv(BaseEnv):
48
48
  self.reset()
49
49
  super().__init__()
50
50
 
51
- def step(self, pi2: Union[npt.NDArray[Any], int, float, np.int_, np.float_]) \
51
+ def step(self, pi2: Union[npt.NDArray[Any], int, float, np.int_, np.float64]) \
52
52
  -> Tuple[npt.NDArray[Any], int, bool, bool, Dict[str, Any]]:
53
53
  """
54
54
  Takes a step in the environment by executing the given action
@@ -11,7 +11,7 @@ class StoppingGameUtil:
11
11
  """
12
12
 
13
13
  @staticmethod
14
- def b1() -> npt.NDArray[np.float_]:
14
+ def b1() -> npt.NDArray[np.float64]:
15
15
  """
16
16
  Gets the initial belief
17
17
 
@@ -233,7 +233,7 @@ class StoppingGameUtil:
233
233
  return int(np.random.choice(np.arange(0, len(S)), p=state_probs))
234
234
 
235
235
  @staticmethod
236
- def sample_initial_state(b1: npt.NDArray[np.float_]) -> int:
236
+ def sample_initial_state(b1: npt.NDArray[np.float64]) -> int:
237
237
  """
238
238
  Samples the initial state
239
239
 
@@ -264,7 +264,7 @@ class StoppingGameUtil:
264
264
  return int(o)
265
265
 
266
266
  @staticmethod
267
- def bayes_filter(s_prime: int, o: int, a1: int, b: npt.NDArray[np.float_], pi2: npt.NDArray[Any], l: int,
267
+ def bayes_filter(s_prime: int, o: int, a1: int, b: npt.NDArray[np.float64], pi2: npt.NDArray[Any], l: int,
268
268
  config: StoppingGameConfig) -> float:
269
269
  """
270
270
  A Bayesian filter to compute the belief of player 1
@@ -302,8 +302,8 @@ class StoppingGameUtil:
302
302
  return float(b_prime_s_prime)
303
303
 
304
304
  @staticmethod
305
- def next_belief(o: int, a1: int, b: npt.NDArray[np.float_], pi2: npt.NDArray[Any],
306
- config: StoppingGameConfig, l: int, a2: int = 0, s: int = 0) -> npt.NDArray[np.float_]:
305
+ def next_belief(o: int, a1: int, b: npt.NDArray[np.float64], pi2: npt.NDArray[Any],
306
+ config: StoppingGameConfig, l: int, a2: int = 0, s: int = 0) -> npt.NDArray[np.float64]:
307
307
  """
308
308
  Computes the next belief using a Bayesian filter
309
309
 
@@ -337,3 +337,52 @@ class StoppingGameUtil:
337
337
  :return: a2 is the attacker action
338
338
  """
339
339
  return int(np.random.choice(np.arange(0, len(pi2[s])), p=pi2[s]))
340
+
341
+ @staticmethod
342
+ def pomdp_solver_file(config: StoppingGameConfig, discount_factor: float, pi2: npt.NDArray[Any]) -> str:
343
+ """
344
+ Gets the POMDP environment specification based on the format at http://www.pomdp.org/code/index.html,
345
+ for the defender's local problem against a static attacker
346
+
347
+ :param config: the POMDP config
348
+ :param discount_factor: the discount factor
349
+ :param pi2: the attacker strategy
350
+ :return: the file content as a string
351
+ """
352
+ file_str = ""
353
+ file_str = file_str + f"discount: {discount_factor}\n\n"
354
+ file_str = file_str + "values: reward\n\n"
355
+ file_str = file_str + f"states: {len(config.S)}\n\n"
356
+ file_str = file_str + f"actions: {len(config.A1)}\n\n"
357
+ file_str = file_str + f"observations: {len(config.O)}\n\n"
358
+ initial_belief_str = " ".join(list(map(lambda x: str(x), config.b1)))
359
+ file_str = file_str + f"start: {initial_belief_str}\n\n\n"
360
+ num_transitions = 0
361
+ for s in config.S:
362
+ for a1 in config.A1:
363
+ probs = []
364
+ for s_prime in range(len(config.S)):
365
+ num_transitions += 1
366
+ prob = 0
367
+ for a2 in config.A2:
368
+ prob += config.T[0][a1][a2][s][s_prime] * pi2[s][a2]
369
+ file_str = file_str + f"T: {a1} : {s} : {s_prime} {prob:.80f}\n"
370
+ probs.append(prob)
371
+ assert round(sum(probs), 3) == 1
372
+ file_str = file_str + "\n\n"
373
+ for a1 in config.A1:
374
+ for s_prime in config.S:
375
+ probs = []
376
+ for o in range(len(config.O)):
377
+ prob = config.Z[0][0][s_prime][o]
378
+ file_str = file_str + f"O : {a1} : {s_prime} : {o} {prob:.80f}\n"
379
+ probs.append(prob)
380
+ assert round(sum(probs), 3) == 1
381
+ file_str = file_str + "\n\n"
382
+ for s in config.S:
383
+ for a1 in config.A1:
384
+ for s_prime in config.S:
385
+ for o in config.O:
386
+ r = config.R[0][a1][0][s]
387
+ file_str = file_str + f"R: {a1} : {s} : {s_prime} : {o} {r:.80f}\n"
388
+ return file_str
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: gym-csle-stopping-game
3
- Version: 0.6.3
3
+ Version: 0.6.5
4
4
  Summary: OpenAI gym reinforcement learning environment of a Dynkin (Optimal stopping) game in CSLE
5
5
  Author: Kim Hammar
6
6
  Author-email: hammar.kim@gmail.com
@@ -1,9 +1,9 @@
1
1
  gymnasium>=0.27.1
2
- csle-base>=0.6.3
3
- csle-common>=0.6.3
4
- csle-attacker>=0.6.3
5
- csle-defender>=0.6.3
6
- csle-collector>=0.6.3
2
+ csle-base>=0.6.5
3
+ csle-common>=0.6.5
4
+ csle-attacker>=0.6.5
5
+ csle-defender>=0.6.5
6
+ csle-collector>=0.6.5
7
7
 
8
8
  [testing]
9
9
  pytest>=6.0
@@ -3,6 +3,7 @@ import pytest
3
3
  from unittest.mock import patch, MagicMock
4
4
  from gymnasium.spaces import Box, Discrete
5
5
  import numpy as np
6
+ from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
6
7
  from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
7
8
  from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
8
9
  from gym_csle_stopping_game.dao.stopping_game_state import StoppingGameState
@@ -23,19 +24,19 @@ class TestStoppingGameEnvSuite:
23
24
  :return: None
24
25
  """
25
26
  env_name = "test_env"
26
- T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
27
- O = np.array([0, 1])
28
- Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]])
27
+ T = StoppingGameUtil.transition_tensor(L=3, p=0)
28
+ O = StoppingGameUtil.observation_space(n=100)
29
+ Z = StoppingGameUtil.observation_tensor(n=100)
29
30
  R = np.zeros((2, 3, 3, 3))
30
- S = np.array([0, 1, 2])
31
- A1 = np.array([0, 1, 2])
32
- A2 = np.array([0, 1, 2])
31
+ S = StoppingGameUtil.state_space()
32
+ A1 = StoppingGameUtil.defender_actions()
33
+ A2 = StoppingGameUtil.attacker_actions()
33
34
  L = 2
34
35
  R_INT = 1
35
36
  R_COST = 2
36
37
  R_SLA = 3
37
38
  R_ST = 4
38
- b1 = np.array([0.6, 0.4])
39
+ b1 = StoppingGameUtil.b1()
39
40
  save_dir = "save_directory"
40
41
  checkpoint_traces_freq = 100
41
42
  gamma = 0.9
@@ -69,12 +70,12 @@ class TestStoppingGameEnvSuite:
69
70
 
70
71
  :return: None
71
72
  """
72
- T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
73
- O = np.array([0, 1])
74
- A1 = np.array([0, 1, 2])
75
- A2 = np.array([0, 1, 2])
73
+ T = StoppingGameUtil.transition_tensor(L=3, p=0)
74
+ O = StoppingGameUtil.observation_space(n=100)
75
+ A1 = StoppingGameUtil.defender_actions()
76
+ A2 = StoppingGameUtil.attacker_actions()
76
77
  L = 2
77
- b1 = np.array([0.6, 0.4])
78
+ b1 = StoppingGameUtil.b1()
78
79
  attacker_observation_space = Box(
79
80
  low=np.array([0.0, 0.0, 0.0]),
80
81
  high=np.array([float(L), 1.0, 2.0]),
@@ -304,7 +305,7 @@ class TestStoppingGameEnvSuite:
304
305
  assert not env.is_state_terminal(state_tuple)
305
306
 
306
307
  with pytest.raises(ValueError):
307
- env.is_state_terminal([1, 2, 3]) # type: ignore
308
+ env.is_state_terminal([1, 2, 3]) # type: ignore
308
309
 
309
310
  def test_get_observation_from_history(self) -> None:
310
311
  """
@@ -346,26 +347,6 @@ class TestStoppingGameEnvSuite:
346
347
  :return: None
347
348
  """
348
349
  env = StoppingGameEnv(self.config)
349
- env.state = MagicMock()
350
- env.state.s = 1
351
- env.state.l = 2
352
- env.state.t = 0
353
- env.state.attacker_observation.return_value = np.array([1, 2, 3])
354
- env.state.defender_observation.return_value = np.array([4, 5, 6])
355
- env.state.b = np.array([0.5, 0.5, 0.0])
356
-
357
- env.trace = MagicMock()
358
- env.trace.defender_rewards = []
359
- env.trace.attacker_rewards = []
360
- env.trace.attacker_actions = []
361
- env.trace.defender_actions = []
362
- env.trace.infos = []
363
- env.trace.states = []
364
- env.trace.beliefs = []
365
- env.trace.infrastructure_metrics = []
366
- env.trace.attacker_observations = []
367
- env.trace.defender_observations = []
368
-
369
350
  with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_state",
370
351
  return_value=2):
371
352
  with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_observation",
@@ -376,7 +357,7 @@ class TestStoppingGameEnvSuite:
376
357
  1,
377
358
  (
378
359
  np.array(
379
- [[0.2, 0.8, 0.0], [0.6, 0.4, 0.0], [0.5, 0.5, 0.0]]
360
+ [[0.2, 0.8], [0.6, 0.4], [0.5, 0.5]]
380
361
  ),
381
362
  2,
382
363
  ),
@@ -384,24 +365,12 @@ class TestStoppingGameEnvSuite:
384
365
  observations, rewards, terminated, truncated, info = env.step(
385
366
  action_profile
386
367
  )
387
-
388
- assert (observations[0] == np.array([4, 5, 6])).all(), "Incorrect defender observations"
389
- assert (observations[1] == np.array([1, 2, 3])).all(), "Incorrect attacker observations"
368
+ assert observations[0].all() == np.array([1, 0.7]).all(), "Incorrect defender observations"
369
+ assert observations[1].all() == np.array([1, 2, 3]).all(), "Incorrect attacker observations"
390
370
  assert rewards == (0, 0)
391
371
  assert not terminated
392
372
  assert not truncated
393
- assert env.trace.defender_rewards[-1] == 0
394
- assert env.trace.attacker_rewards[-1] == 0
395
- assert env.trace.attacker_actions[-1] == 2
396
- assert env.trace.defender_actions[-1] == 1
397
- assert env.trace.infos[-1] == info
398
- assert env.trace.states[-1] == 2
399
- print(env.trace.beliefs)
400
- assert env.trace.beliefs[-1] == 0.7
401
- assert env.trace.infrastructure_metrics[-1] == 1
402
- assert (env.trace.attacker_observations[-1] == np.array([1, 2, 3])).all()
403
- assert (env.trace.defender_observations[-1] == np.array([4, 5, 6])).all()
404
-
373
+
405
374
  def test_info(self) -> None:
406
375
  """
407
376
  Tests the function of adding the cumulative reward and episode length to the info dict
@@ -5,8 +5,12 @@ from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
5
5
  from gym_csle_stopping_game.dao.stopping_game_attacker_mdp_config import (
6
6
  StoppingGameAttackerMdpConfig,
7
7
  )
8
+ from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
8
9
  from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
9
10
  from csle_common.dao.training.policy import Policy
11
+ from csle_common.dao.training.random_policy import RandomPolicy
12
+ from csle_common.dao.training.player_type import PlayerType
13
+ from csle_common.dao.simulation_config.action import Action
10
14
  import pytest
11
15
  from unittest.mock import MagicMock
12
16
  import numpy as np
@@ -25,19 +29,19 @@ class TestStoppingGameMdpAttackerEnvSuite:
25
29
  :return: None
26
30
  """
27
31
  env_name = "test_env"
28
- T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
29
- O = np.array([0, 1])
30
- Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]])
32
+ T = StoppingGameUtil.transition_tensor(L=3, p=0)
33
+ O = StoppingGameUtil.observation_space(n=100)
34
+ Z = StoppingGameUtil.observation_tensor(n=100)
31
35
  R = np.zeros((2, 3, 3, 3))
32
- S = np.array([0, 1, 2])
33
- A1 = np.array([0, 1, 2])
34
- A2 = np.array([0, 1, 2])
36
+ S = StoppingGameUtil.state_space()
37
+ A1 = StoppingGameUtil.defender_actions()
38
+ A2 = StoppingGameUtil.attacker_actions()
35
39
  L = 2
36
40
  R_INT = 1
37
41
  R_COST = 2
38
42
  R_SLA = 3
39
43
  R_ST = 4
40
- b1 = np.array([0.6, 0.4])
44
+ b1 = StoppingGameUtil.b1()
41
45
  save_dir = "save_directory"
42
46
  checkpoint_traces_freq = 100
43
47
  gamma = 0.9
@@ -107,9 +111,8 @@ class TestStoppingGameMdpAttackerEnvSuite:
107
111
  )
108
112
 
109
113
  env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
110
- attacker_obs, info = env.reset()
111
- assert env.latest_defender_obs.all() == np.array([2, 0.4]).all() # type: ignore
112
- assert info == {}
114
+ info = env.reset()
115
+ assert info[-1] == {}
113
116
 
114
117
  def test_set_model(self) -> None:
115
118
  """
@@ -144,7 +147,7 @@ class TestStoppingGameMdpAttackerEnvSuite:
144
147
  )
145
148
 
146
149
  env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
147
- assert not env.set_state(1) # type: ignore
150
+ assert not env.set_state(1) # type: ignore
148
151
 
149
152
  def test_calculate_stage_policy(self) -> None:
150
153
  """
@@ -190,7 +193,7 @@ class TestStoppingGameMdpAttackerEnvSuite:
190
193
  def test_render(self) -> None:
191
194
  """
192
195
  Tests the function for rendering the environment
193
-
196
+
194
197
  :return: None
195
198
  """
196
199
  defender_strategy = MagicMock(spec=Policy)
@@ -317,7 +320,7 @@ class TestStoppingGameMdpAttackerEnvSuite:
317
320
  particles = [1, 2, 3]
318
321
  t = 0
319
322
  observation = 0
320
- expected_actions = [0, 1, 2]
323
+ expected_actions = [0, 1]
321
324
  assert (
322
325
  env.get_actions_from_particles(particles, t, observation)
323
326
  == expected_actions
@@ -326,18 +329,32 @@ class TestStoppingGameMdpAttackerEnvSuite:
326
329
  def test_step(self) -> None:
327
330
  """
328
331
  Tests the function for taking a step in the environment by executing the given action
329
-
332
+
330
333
  :return: None
331
334
  """
332
- defender_strategy = MagicMock(spec=Policy)
335
+ defender_stage_strategy = np.zeros((3, 2))
336
+ defender_stage_strategy[0][0] = 0.9
337
+ defender_stage_strategy[0][1] = 0.1
338
+ defender_stage_strategy[1][0] = 0.9
339
+ defender_stage_strategy[1][1] = 0.1
340
+ defender_actions = list(map(lambda x: Action(id=x, descr=""), self.config.A1))
341
+ defender_strategy = RandomPolicy(
342
+ actions=defender_actions,
343
+ player_type=PlayerType.DEFENDER,
344
+ stage_policy_tensor=list(defender_stage_strategy),
345
+ )
333
346
  attacker_mdp_config = StoppingGameAttackerMdpConfig(
334
347
  env_name="test_env",
335
348
  stopping_game_config=self.config,
336
349
  defender_strategy=defender_strategy,
337
350
  stopping_game_name="csle-stopping-game-v1",
338
351
  )
339
-
340
352
  env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
341
- pi2 = np.array([[0.5, 0.5]])
342
- with pytest.raises(AssertionError):
343
- env.step(pi2)
353
+ env.reset()
354
+ pi2 = env.calculate_stage_policy(o=list(env.latest_attacker_obs), a2=0) # type: ignore
355
+ attacker_obs, reward, terminated, truncated, info = env.step(pi2)
356
+ assert isinstance(attacker_obs[0], float) # type: ignore
357
+ assert isinstance(terminated, bool) # type: ignore
358
+ assert isinstance(truncated, bool) # type: ignore
359
+ assert isinstance(reward, float) # type: ignore
360
+ assert isinstance(info, dict) # type: ignore
@@ -1,9 +1,14 @@
1
- from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import StoppingGamePomdpDefenderEnv
1
+ from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import (
2
+ StoppingGamePomdpDefenderEnv,
3
+ )
2
4
  from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
3
- from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import StoppingGameDefenderPomdpConfig
5
+ from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import (
6
+ StoppingGameDefenderPomdpConfig,
7
+ )
4
8
  from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
5
9
  from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
6
10
  from csle_common.dao.training.policy import Policy
11
+ from csle_common.dao.simulation_config.action import Action
7
12
  from csle_common.dao.training.random_policy import RandomPolicy
8
13
  from csle_common.dao.training.player_type import PlayerType
9
14
  import pytest
@@ -219,7 +224,7 @@ class TestStoppingGamePomdpDefenderEnvSuite:
219
224
  stopping_game_name="csle-stopping-game-v1",
220
225
  )
221
226
  env = StoppingGamePomdpDefenderEnv(config=defender_pomdp_config)
222
- assert env.set_state(1) is None # type: ignore
227
+ assert env.set_state(1) is None # type: ignore
223
228
 
224
229
  def test_get_observation_from_history(self) -> None:
225
230
  """
@@ -301,7 +306,10 @@ class TestStoppingGamePomdpDefenderEnvSuite:
301
306
  t = 0
302
307
  observation = 0
303
308
  expected_actions = [0, 1]
304
- assert env.get_actions_from_particles(particles, t, observation) == expected_actions
309
+ assert (
310
+ env.get_actions_from_particles(particles, t, observation)
311
+ == expected_actions
312
+ )
305
313
 
306
314
  def test_step(self) -> None:
307
315
  """
@@ -315,8 +323,12 @@ class TestStoppingGamePomdpDefenderEnvSuite:
315
323
  attacker_stage_strategy[1][0] = 0.9
316
324
  attacker_stage_strategy[1][1] = 0.1
317
325
  attacker_stage_strategy[2] = attacker_stage_strategy[1]
318
- attacker_strategy = RandomPolicy(actions=list(self.config.A2), player_type=PlayerType.ATTACKER,
319
- stage_policy_tensor=list(attacker_stage_strategy))
326
+ attacker_actions = list(map(lambda x: Action(id=x, descr=""), self.config.A2))
327
+ attacker_strategy = RandomPolicy(
328
+ actions=attacker_actions,
329
+ player_type=PlayerType.ATTACKER,
330
+ stage_policy_tensor=list(attacker_stage_strategy),
331
+ )
320
332
  defender_pomdp_config = StoppingGameDefenderPomdpConfig(
321
333
  env_name="test_env",
322
334
  stopping_game_config=self.config,
@@ -328,9 +340,9 @@ class TestStoppingGamePomdpDefenderEnvSuite:
328
340
  env.reset()
329
341
  defender_obs, reward, terminated, truncated, info = env.step(a1)
330
342
  assert len(defender_obs) == 2
331
- assert isinstance(defender_obs[0], float) # type: ignore
332
- assert isinstance(defender_obs[1], float) # type: ignore
333
- assert isinstance(reward, float) # type: ignore
334
- assert isinstance(terminated, bool) # type: ignore
335
- assert isinstance(truncated, bool) # type: ignore
336
- assert isinstance(info, dict) # type: ignore
343
+ assert isinstance(defender_obs[0], float) # type: ignore
344
+ assert isinstance(defender_obs[1], float) # type: ignore
345
+ assert isinstance(reward, float) # type: ignore
346
+ assert isinstance(terminated, bool) # type: ignore
347
+ assert isinstance(truncated, bool) # type: ignore
348
+ assert isinstance(info, dict) # type: ignore
@@ -1 +0,0 @@
1
- __version__ = '0.6.3'