ai-snake-lab 0.1.0__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,10 +10,14 @@ ai/ReplayMemory.py
10
10
  This file contains the ReplayMemory class.
11
11
  """
12
12
 
13
+ import os
13
14
  from collections import deque
14
15
  import random, sys
16
+ import sqlite3, pickle
15
17
 
16
18
  from constants.DReplayMemory import MEM_TYPE
19
+ from constants.DFile import DFile
20
+ from constants.DDir import DDir
17
21
 
18
22
 
19
23
  class ReplayMemory:
@@ -27,6 +31,11 @@ class ReplayMemory:
27
31
  self.max_states = 15000
28
32
  self.max_shuffle_games = 40
29
33
  self.max_games = 500
34
+ self.db_file = os.path.join(DDir.AI_SNAKE_LAB, DDir.DB, DFile.REPLAY_DB)
35
+
36
+ # Delete the replay memory file, if it exists
37
+ if os.path.exists(self.db_file):
38
+ os.remove(self.db_file)
30
39
 
31
40
  if self._mem_type == MEM_TYPE.SHUFFLE:
32
41
  # States are stored in a deque and a random sample will be returned
@@ -35,35 +44,50 @@ class ReplayMemory:
35
44
  elif self._mem_type == MEM_TYPE.RANDOM_GAME:
36
45
  # All of the states for a game are stored, in order, in a deque.
37
46
  # A complete game will be returned
38
- self.memories = deque(maxlen=self.max_shuffle_games)
39
47
  self.cur_memory = []
40
48
 
41
- else:
42
- print(f"ERROR: Unrecognized replay memory type ({self._mem_type}), exiting")
43
- sys.exit(1)
49
+ # Connect to SQLite
50
+ self.conn = sqlite3.connect(self.db_file, check_same_thread=False)
51
+ self.cursor = self.conn.cursor()
52
+ self.init_db()
44
53
 
45
- def append(self, transition):
46
- ## Add memories
def __len__(self):
    """Return how many memories are currently stored.

    For MEM_TYPE.RANDOM_GAME the games live in the SQLite ``games`` table and
    no in-memory ``self.memories`` deque is created for that type, so the
    count is read from the database instead of raising AttributeError.
    Other memory types still count their in-memory deque.
    """
    if self._mem_type == MEM_TYPE.RANDOM_GAME:
        return self.get_num_games()
    return len(self.memories)
47
56
 
48
- # States are stored in a deque and a random sample will be returned
49
- if self._mem_type == MEM_TYPE.SHUFFLE:
50
- self.memories.append(transition)
def append(self, transition):
    """Add a transition to the current game.

    Transitions are buffered in ``self.cur_memory``. When a transition's
    ``done`` flag is set, the whole buffered game is pickled, inserted as one
    row of the ``games`` table and committed, and the buffer is cleared.

    Args:
        transition: A 5-tuple ``(state, action, reward, next_state, done)``;
            only the final ``done`` element is inspected here.

    Raises:
        NotImplementedError: If the memory type is not MEM_TYPE.RANDOM_GAME,
            the only type the SQLite backend supports.
    """
    if self._mem_type != MEM_TYPE.RANDOM_GAME:
        raise NotImplementedError(
            "Only RANDOM_GAME memory type is implemented for SQLite backend"
        )

    self.cur_memory.append(transition)
    _, _, _, _, done = transition

    if done:
        # Serialize the full game with pickle (binary bytes, not JSON)
        serialized = pickle.dumps(self.cur_memory)
        self.cursor.execute(
            "INSERT INTO games (transitions) VALUES (?)", (serialized,)
        )
        self.conn.commit()
        self.cur_memory = []
51
75
 
52
- # All of the states for a game are stored, in order, in a deque.
53
- # A set of ordered states representing a complete game will be returned
54
- elif self._mem_type == MEM_TYPE.RANDOM_GAME:
55
- self.cur_memory.append(transition)
56
- state, action, reward, next_state, done = transition
57
- if done:
58
- self.memories.append(self.cur_memory)
59
- self.cur_memory = []
def close(self):
    """Shut down the SQLite connection backing this replay memory."""
    connection = self.conn
    connection.close()
60
79
 
def get_random_game(self):
    """Return a random full game from the database.

    Returns:
        The list of transitions for one stored game, chosen uniformly at
        random among the stored game ids, or ``False`` when no games are
        stored or fewer than ``self.min_games`` games are stored.
    """
    self.cursor.execute("SELECT id FROM games")
    all_ids = [row[0] for row in self.cursor.fetchall()]
    # Guard the empty case explicitly: random.choice([]) raises IndexError,
    # which would otherwise happen if min_games were configured as 0.
    if not all_ids or len(all_ids) < self.min_games:
        return False
    rand_id = random.choice(all_ids)
    self.cursor.execute("SELECT transitions FROM games WHERE id=?", (rand_id,))
    row = self.cursor.fetchone()
    if row:
        # NOTE: pickle.loads can execute arbitrary code for crafted input.
        # Rows are expected to come only from append(); never point this
        # class at an untrusted database file.
        return pickle.loads(row[0])
    return False
67
91
 
68
92
  def get_random_states(self):
69
93
  mem_size = len(self.memories)
@@ -78,8 +102,21 @@ class ReplayMemory:
78
102
  elif self._mem_type == MEM_TYPE.RANDOM_GAME:
79
103
  return self.get_random_game()
80
104
 
81
- def get_num_memories(self):
82
- return len(self.memories)
def get_num_games(self):
    """Count the rows (one per finished game) in the ``games`` table."""
    (count,) = self.cursor.execute("SELECT COUNT(*) FROM games").fetchone()
    return count
109
+
def init_db(self):
    """Create the ``games`` table if it does not already exist, then commit.

    Each row holds one complete game. ``transitions`` stores the pickled
    list of transitions (bytes), so the column is declared BLOB. SQLite would
    still keep bytes as a BLOB under a TEXT declaration, but BLOB documents
    the real content and avoids text-affinity surprises.
    """
    self.cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS games (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            transitions BLOB NOT NULL
        )
        """
    )
    self.conn.commit()
83
120
 
84
121
  def mem_type(self, mem_type=None):
85
122
  if mem_type is not None:
@@ -12,15 +12,18 @@ import torch
12
12
  import torch.nn as nn
13
13
  import torch.nn.functional as F
14
14
 
15
+ from ai_snake_lab.constants.DSim import DSim
16
+ from ai_snake_lab.constants.DModelL import DModelL
17
+
15
18
 
16
19
  class ModelL(nn.Module):
17
20
  def __init__(self, seed: int):
18
21
  super(ModelL, self).__init__()
19
22
  torch.manual_seed(seed)
20
- input_size = 27 # Size of the "state" as tracked by the GameBoard
21
- hidden_size = 170
22
- output_size = 3
23
- p_value = 0.1
23
+ input_size = DSim.STATE_SIZE # Size of the "state" as tracked by the GameBoard
24
+ hidden_size = DModelL.HIDDEN_SIZE
25
+ output_size = DSim.OUTPUT_SIZE
26
+ p_value = DModelL.P_VALUE
24
27
  self.input_block = nn.Sequential(
25
28
  nn.Linear(input_size, hidden_size),
26
29
  nn.ReLU(),
@@ -17,7 +17,7 @@ class ModelRNN(nn.Module):
17
17
  def __init__(self, seed: int):
18
18
  super(ModelRNN, self).__init__()
19
19
  torch.manual_seed(seed)
20
- input_size = 27
20
+ input_size = 30
21
21
  hidden_size = 200
22
22
  output_size = 3
23
23
  rnn_layers = 4
@@ -0,0 +1,20 @@
1
+ """
2
+ constants/DDb4EPlot.py
3
+
4
+ AI Snake Game Simulator
5
+ Author: Nadim-Daniel Ghaznavi
6
+ Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
7
+ GitHub: https://github.com/NadimGhaznavi/ai
8
+ License: GPL 3.0
9
+ """
10
+
11
+ from ai_snake_lab.utils.ConstGroup import ConstGroup
12
+
13
+
class Plot(ConstGroup):
    """Db4EPlot constants."""

    # Plot modes. NOTE(review): presumably an averaged view vs. a
    # sliding-window view of the data — confirm against the Db4EPlot widget.
    AVERAGE: str = "average"
    SLIDING: str = "sliding"
    # Cap on the number of data points (presumably the most the plot keeps
    # or draws at once — confirm).
    MAX_DATA_POINTS: int = 200
@@ -8,11 +8,11 @@ constants/DDef.py
8
8
  License: GPL 3.0
9
9
  """
10
10
 
11
- from utils.ConstGroup import ConstGroup
11
+ from ai_snake_lab.utils.ConstGroup import ConstGroup
12
12
 
13
13
 
class DDef(ConstGroup):
    """Default settings for the application."""

    # Application title
    APP_TITLE: str = "AI Snake Game Lab"
    # Default move delay
    MOVE_DELAY: float = 0.0
@@ -8,9 +8,12 @@ constants/DDir.py
8
8
  License: GPL 3.0
9
9
  """
10
10
 
11
- from utils.ConstGroup import ConstGroup
11
+ from ai_snake_lab.utils.ConstGroup import ConstGroup
12
12
 
13
13
 
class DDir(ConstGroup):
    """Directory names used when building file paths."""

    # Top-level package directory (first component of the replay DB path)
    AI_SNAKE_LAB: str = "ai_snake_lab"
    # Sub-directory holding the SQLite database file(s)
    DB: str = "db"
    UTILS: str = "utils"
@@ -8,7 +8,7 @@ constants/DEpsilon.py
8
8
  License: GPL 3.0
9
9
  """
10
10
 
11
- from utils.ConstGroup import ConstGroup
11
+ from ai_snake_lab.utils.ConstGroup import ConstGroup
12
12
 
13
13
 
14
14
  class DEpsilon(ConstGroup):
@@ -8,11 +8,17 @@ constants/DFields.py
8
8
  License: GPL 3.0
9
9
  """
10
10
 
11
- from utils.ConstGroup import ConstGroup
11
+ from ai_snake_lab.utils.ConstGroup import ConstGroup
12
12
 
13
13
 
class DField(ConstGroup):
    """Field values and dictionary keys shared across the app."""

    # --- Simulation loop states ---
    PAUSED: str = "paused"
    RUNNING: str = "running"
    STOPPED: str = "stopped"

    # --- Keys used in the stats dictionary ---
    GAME_SCORE: str = "game_score"
    GAME_NUM: str = "game_num"
@@ -8,10 +8,11 @@ constants/DFile.py
8
8
  License: GPL 3.0
9
9
  """
10
10
 
11
- from utils.ConstGroup import ConstGroup
11
+ from ai_snake_lab.utils.ConstGroup import ConstGroup
12
12
 
13
13
 
class DFile(ConstGroup):
    """File names used by the application."""

    # Stylesheet loaded by the UI (.tcss)
    CSS_FILE: str = "AISim.tcss"
    # SQLite database file backing the replay memory
    REPLAY_DB: str = "replay_mem.db"
@@ -8,27 +8,47 @@ constants/DLabels.py
8
8
  License: GPL 3.0
9
9
  """
10
10
 
11
- from utils.ConstGroup import ConstGroup
11
+ from ai_snake_lab.ai.models.ModelL import ModelL
12
+ from ai_snake_lab.ai.models.ModelRNN import ModelRNN
13
+
14
+ from ai_snake_lab.utils.ConstGroup import ConstGroup
12
15
 
13
16
 
class DLabel(ConstGroup):
    """Label text constants."""

    AVERAGE: str = "Average"
    CURRENT: str = "Current"
    CURRENT_EPSILON: str = "Current Epsilon"
    DEFAULTS: str = "Defaults"
    EPSILON: str = "Epsilon"
    EPSILON_DECAY: str = "Epsilon Decay"
    EPSILON_INITIAL: str = "Initial Epsilon"
    EPSILON_MIN: str = "Minimum Epsilon"
    GAME: str = "Game"
    GAMES: str = "Games"
    GAME_SCORE: str = "Game Score"
    GAME_NUM: str = "Game Number"
    HIGHSCORE: str = "Highscore"
    MEM_TYPE: str = "Memory Type"
    MIN_EPSILON: str = "Minimum Epsilon"
    MODEL_LINEAR: str = "Linear"
    MODEL_RNN: str = "RNN"
    MODEL_TYPE: str = "Model Type"
    MOVE_DELAY: str = "Move Delay"
    PAUSE: str = "Pause"
    QUIT: str = "Quit"
    # Defined once; a second identical RESTART assignment was redundant.
    RESTART: str = "Restart"
    RUNTIME: str = "Runtime"
    RUNTIME_VALUES: str = "Runtime Values"
    SCORE: str = "Score"
    SETTINGS: str = "Configuration Settings"
    START: str = "Start"
    STORED_GAMES: str = "Stored Games"
    UPDATE: str = "Update"

    # Display label for each model type.
    # NOTE(review): the keys are inconsistent — ModelL is keyed by
    # str(ModelL) (the class's string form) while ModelRNN is keyed by the
    # class object itself, so a lookup using either convention misses one of
    # the two models. Confirm how callers index this table and align the keys.
    MODEL_TYPE_TABLE: dict = {
        str(ModelL): MODEL_LINEAR,
        ModelRNN: MODEL_RNN,
    }
@@ -8,22 +8,35 @@ constants/DLayout.py
8
8
  License: GPL 3.0
9
9
  """
10
10
 
11
- from utils.ConstGroup import ConstGroup
11
+ from ai_snake_lab.utils.ConstGroup import ConstGroup
12
12
 
13
13
 
14
14
  class DLayout(ConstGroup):
15
15
  """Layout"""
16
16
 
17
+ BUTTON_BOX: str = "button_box"
17
18
  BUTTON_PAUSE: str = "button_pause"
18
19
  BUTTON_QUIT: str = "button_quit"
20
+ BUTTON_RESTART: str = "button_restart"
19
21
  BUTTON_ROW: str = "button_row"
20
22
  BUTTON_START: str = "button_start"
21
- BUTTON_RESET: str = "button_reset"
23
+ BUTTON_DEFAULTS: str = "button_defaults"
22
24
  BUTTON_UPDATE: str = "button_update"
23
25
  CUR_EPSILON: str = "cur_epsilon"
24
26
  CUR_MEM_TYPE: str = "cur_mem_type"
27
+ CUR_MODEL_TYPE: str = "cur_model_type"
28
+ FILLER_1: str = "filler_1"
29
+ FILLER_2: str = "filler_2"
30
+ FILLER_3: str = "filler_3"
31
+ FILLER_4: str = "filler_4"
32
+ FILLER_5: str = "filler_5"
33
+ FILLER_6: str = "filler_6"
34
+ FILLER_7: str = "filler_7"
35
+ FILLER_8: str = "filler_8"
25
36
  GAME_BOARD: str = "game_board"
26
37
  GAME_BOX: str = "game_box"
38
+ GAME_SCORE: str = "game_score"
39
+ GAME_SCORE_PLOT: str = "game_score_plot"
27
40
  EPSILON_DECAY: str = "epsilon_decay"
28
41
  EPSILON_INITIAL: str = "initial_epsilon"
29
42
  EPSILON_MIN: str = "epsilon_min"
@@ -31,8 +44,9 @@ class DLayout(ConstGroup):
31
44
  LABEL: str = "label"
32
45
  LABEL_SETTINGS: str = "label_settings"
33
46
  MOVE_DELAY: str = "move_delay"
34
- NUM_MEMORIES: str = "num_memories"
47
+ NUM_GAMES: str = "num_games"
35
48
  RUNTIME_BOX: str = "runtime_box"
49
+ RUNTIME: str = "runtime"
36
50
  SCORE: str = "score"
37
51
  SETTINGS_BOX: str = "settings_box"
38
52
  TITLE: str = "title"
@@ -8,10 +8,14 @@ constants/DModelL.py
8
8
  License: GPL 3.0
9
9
  """
10
10
 
11
- from utils.ConstGroup import ConstGroup
11
+ from ai_snake_lab.utils.ConstGroup import ConstGroup
12
12
 
13
13
 
class DModelL(ConstGroup):
    """Default hyper-parameters for the linear model (ModelL)."""

    # Learning rate
    LEARNING_RATE: float = 9e-6
    # Width (node count) of the hidden layer
    HIDDEN_SIZE: int = 170
    # Dropout probability; 0.2 means 20%
    P_VALUE: float = 0.2
@@ -8,7 +8,7 @@ constants/DModelRNN.py
8
8
  License: GPL 3.0
9
9
  """
10
10
 
11
- from utils.ConstGroup import ConstGroup
11
+ from ai_snake_lab.utils.ConstGroup import ConstGroup
12
12
 
13
13
 
14
14
  class DModelRNN(ConstGroup):
@@ -8,7 +8,7 @@ constants/DReplayMemory.py
8
8
  License: GPL 3.0
9
9
  """
10
10
 
11
- from utils.ConstGroup import ConstGroup
11
+ from ai_snake_lab.utils.ConstGroup import ConstGroup
12
12
 
13
13
 
14
14
  class MEM_TYPE(ConstGroup):
@@ -0,0 +1,20 @@
1
+ """
2
+ constants/DSim.py
3
+
4
+ AI Snake Game Simulator
5
+ Author: Nadim-Daniel Ghaznavi
6
+ Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
7
+ GitHub: https://github.com/NadimGhaznavi/ai
8
+ License: GPL 3.0
9
+ """
10
+
11
+ from ai_snake_lab.utils.ConstGroup import ConstGroup
12
+
13
+
class DSim(ConstGroup):
    """Sizes shared between the game board's state map and the models."""

    # Length of the state vector built by the GameBoard class
    STATE_SIZE: int = 30
    # Number of moves the snake can pick from: straight ahead, left or right
    OUTPUT_SIZE: int = 3
@@ -68,6 +68,19 @@ class GameBoard(ScrollView):
68
68
  def board_size(self) -> int:
69
69
  return self._board_size
70
70
 
def get_binary(self, bits_needed, some_int):
    """Return ``some_int`` as a fixed-width list of bits, most significant first.

    Used by ``get_state()`` to encode a value (the snake length) into the
    state map, where the caller reads a fixed number of positions by index.

    Args:
        bits_needed: Width of the returned list.
        some_int: Non-negative value to encode; anything ``int()`` accepts.

    Returns:
        A list of exactly ``bits_needed`` ints (0 or 1), or ``[]`` when
        ``bits_needed`` is not positive. Values too large for
        ``bits_needed`` bits saturate to all ones, so the list never grows
        and the state vector size stays constant.

    Raises:
        ValueError: If ``some_int`` is negative.
    """
    value = int(some_int)
    if value < 0:
        raise ValueError(f"get_binary() requires a non-negative value, got {value}")
    if bits_needed <= 0:
        return []
    # Clamp rather than overflow: an extra leading bit would shift every
    # position the caller reads by index.
    value = min(value, (1 << bits_needed) - 1)
    return [int(bit) for bit in format(value, f"0{bits_needed}b")]
83
+
71
84
  def get_state(self):
72
85
 
73
86
  head = self.snake_head
@@ -80,7 +93,7 @@ class GameBoard(ScrollView):
80
93
  dir_r = direction == Direction.RIGHT
81
94
  dir_u = direction == Direction.UP
82
95
  dir_d = direction == Direction.DOWN
83
-
96
+ slb = self.get_binary(7, len(self.snake_body))
84
97
  state = [
85
98
  # 1. Snake collision straight ahead
86
99
  (dir_r and self.is_snake_collision(point_r))
@@ -97,47 +110,48 @@ class GameBoard(ScrollView):
97
110
  or (dir_u and self.is_snake_collision(point_l))
98
111
  or (dir_r and self.is_snake_collision(point_u))
99
112
  or (dir_l and self.is_snake_collision(point_d)),
100
- # 4. divider
101
- 0,
102
- # 5. Wall collision straight ahead
113
+ # 4. Wall collision straight ahead
103
114
  (dir_r and self.is_wall_collision(point_r))
104
115
  or (dir_l and self.is_wall_collision(point_l))
105
116
  or (dir_u and self.is_wall_collision(point_u))
106
117
  or (dir_d and self.is_wall_collision(point_d)),
107
- # 6. Wall collision to the right
118
+ # 5. Wall collision to the right
108
119
  (dir_u and self.is_wall_collision(point_r))
109
120
  or (dir_d and self.is_wall_collision(point_l))
110
121
  or (dir_l and self.is_wall_collision(point_u))
111
122
  or (dir_r and self.is_wall_collision(point_d)),
112
- # 7. Wall collision to the left
123
+ # 6. Wall collision to the left
113
124
  (dir_d and self.is_wall_collision(point_r))
114
125
  or (dir_u and self.is_wall_collision(point_l))
115
126
  or (dir_r and self.is_wall_collision(point_u))
116
127
  or (dir_l and self.is_wall_collision(point_d)),
117
- # 8. divider
118
- 0,
119
- # 9, 10, 11, 12. Last move direction
128
+ # 7 - 10. Last move direction
120
129
  dir_l,
121
130
  dir_r,
122
131
  dir_u,
123
132
  dir_d,
124
- # 13. divider
125
- 0,
126
- # 14 - 23. Food location
127
- self.food.x < self.snake_head.x, # Food left
128
- self.food.x > self.snake_head.x, # Food right
129
- self.food.y < self.snake_head.y, # Food up
130
- self.food.y > self.snake_head.y, # Food down
131
- self.food.x == self.snake_head.x,
133
+ # 11 - 19. Food location
134
+ self.food.x < self.snake_head.x, # 11. Food left
135
+ self.food.x > self.snake_head.x, # 12. Food right
136
+ self.food.y < self.snake_head.y, # 13. Food up
137
+ self.food.y > self.snake_head.y, # 14. Food down
138
+ self.food.x == self.snake_head.x, # 15.
132
139
  self.food.x == self.snake_head.x
133
- and self.food.y > self.snake_head.y, # Food ahead
140
+ and self.food.y > self.snake_head.y, # 16. Food ahead
134
141
  self.food.x == self.snake_head.x
135
- and self.food.y < self.snake_head.y, # Food behind
136
- self.food.y == self.snake_head.y,
142
+ and self.food.y < self.snake_head.y, # 17. Food behind
137
143
  self.food.y == self.snake_head.y
138
- and self.food.x > self.snake_head.x, # Food above
144
+ and self.food.x > self.snake_head.x, # 18. Food above
139
145
  self.food.y == self.snake_head.y
140
- and self.food.x < self.snake_head.x, # Food below
146
+ and self.food.x < self.snake_head.x, # 19. Food below
147
+ # 20 - 26. Snake length in binary
148
+ slb[0],
149
+ slb[1],
150
+ slb[2],
151
+ slb[3],
152
+ slb[4],
153
+ slb[5],
154
+ slb[6],
141
155
  ]
142
156
 
143
157
  # 27, 28, 29 and 30. Previous direction of the snake
@@ -54,6 +54,9 @@ class SnakeGame:
54
54
  self.game_board.update_snake(snake=self.snake, direction=self.direction)
55
55
  self.game_board.update_food(food=self.food)
56
56
 
57
+ # Track the distance from the snake head to the food to feed the reward system
58
+ self.distance_to_food = self.game_board.board_size() // 2
59
+
57
60
  # The current game score
58
61
  self.game_score = 0
59
62
 
@@ -126,6 +129,12 @@ class SnakeGame:
126
129
  game_over = True
127
130
  reward = -10
128
131
 
# Penalize the move for every body segment touching the snake head
# (orthogonally or diagonally) to discourage self-collisions: each such
# segment lowers the reward by 1.
# NOTE(review): if self.snake[0] is the head, self.snake[1] is the neck and
# is always adjacent, so every move pays at least -1 — confirm whether
# self.snake[2:] was intended.
for segment in self.snake[1:]:
    if abs(self.head.x - segment.x) < 2 and abs(self.head.y - segment.y) < 2:
        reward -= 1
137
+
129
138
  if game_over == True:
130
139
  # Game is over: Snake or wall collision or exceeded max moves
131
140
  self.game_reward += reward
@@ -142,6 +151,14 @@ class SnakeGame:
142
151
  else:
143
152
  self.snake.pop()
144
153
 
154
+ ## 5. See if we're closer to the food than the last move, or further away
155
+ cur_distance = abs(self.head.x - self.food.x) + abs(self.head.y - self.food.y)
156
+ if cur_distance < self.distance_to_food:
157
+ reward += 2
158
+ elif cur_distance > self.distance_to_food:
159
+ reward -= 2
160
+ self.distance_to_food = cur_distance
161
+
145
162
  self.game_reward += reward
146
163
  self.game_board.update_snake(snake=self.snake, direction=self.direction)
147
164
  self.game_board.update_food(food=self.food)