ai-snake-lab 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_snake_lab/AISim.py ADDED
@@ -0,0 +1,274 @@
+ """
+ AISim.py
+
+ AI Snake Game Simulator
+ Author: Nadim-Daniel Ghaznavi
+ Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
+ GitHub: https://github.com/NadimGhaznavi/ai
+ License: GPL 3.0
+ """
+
+ import threading
+ import time
+
+ from textual.app import App, ComposeResult
+ from textual.widgets import Label, Input, Button
+ from textual.containers import Vertical, Horizontal
+
+ from constants.DDef import DDef
+ from constants.DEpsilon import DEpsilon
+ from constants.DFields import DField
+ from constants.DFile import DFile
+ from constants.DLayout import DLayout
+ from constants.DLabels import DLabel
+ from constants.DReplayMemory import MEM_TYPE
+
+ from ai.AIAgent import AIAgent
+ from ai.EpsilonAlgo import EpsilonAlgo
+ from game.GameBoard import GameBoard
+ from game.SnakeGame import SnakeGame
+
+ RANDOM_SEED = 1970
+
+
+ class AISim(App):
+     """A Textual app where an AI agent plays the Snake Game."""
+
+     TITLE = DDef.APP_TITLE
+     CSS_PATH = DFile.CSS_PATH
+
+     ## Runtime values
+     # Current epsilon value (decays in real time)
+     cur_epsilon_widget = Label("N/A", id=DLayout.CUR_EPSILON)
+     # Current memory type
+     cur_mem_type_widget = Label("N/A", id=DLayout.CUR_MEM_TYPE)
+     # Number of stored memories
+     cur_num_memories_widget = Label("N/A", id=DLayout.NUM_MEMORIES)
+     # Runtime move delay value
+     cur_move_delay = DDef.MOVE_DELAY
+
+     # Initial settings for epsilon
+     initial_epsilon_input = Input(
+         restrict=r"0\.[0-9]*",
+         compact=True,
+         id=DLayout.EPSILON_INITIAL,
+         classes=DLayout.INPUT_10,
+     )
+     epsilon_min_input = Input(
+         restrict=r"0\.[0-9]*",
+         compact=True,
+         id=DLayout.EPSILON_MIN,
+         classes=DLayout.INPUT_10,
+     )
+     epsilon_decay_input = Input(
+         restrict=r"0\.[0-9]*",
+         compact=True,
+         id=DLayout.EPSILON_DECAY,
+         classes=DLayout.INPUT_10,
+     )
+     move_delay_input = Input(
+         restrict=r"[0-9]*\.[0-9]*",
+         compact=True,
+         id=DLayout.MOVE_DELAY,
+         classes=DLayout.INPUT_10,
+     )
+
+     # Buttons
+     pause_button = Button(label=DLabel.PAUSE, id=DLayout.BUTTON_PAUSE, compact=True)
+     start_button = Button(label=DLabel.START, id=DLayout.BUTTON_START, compact=True)
+     quit_button = Button(label=DLabel.QUIT, id=DLayout.BUTTON_QUIT, compact=True)
+     reset_button = Button(label=DLabel.RESET, id=DLayout.BUTTON_RESET, compact=True)
+     update_button = Button(label=DLabel.UPDATE, id=DLayout.BUTTON_UPDATE, compact=True)
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.game_board = GameBoard(20, id=DLayout.GAME_BOARD)
+         self.snake_game = SnakeGame(game_board=self.game_board, id=DLayout.GAME_BOARD)
+         self.epsilon_algo = EpsilonAlgo(seed=RANDOM_SEED)
+         self.agent = AIAgent(self.epsilon_algo, seed=RANDOM_SEED)
+         self.running = False
+
+         self.score = Label("Game: 0, Highscore: 0, Score: 0")
+
+         # Set up the simulator in a background thread
+         self.stop_event = threading.Event()
+         self.simulator_thread = threading.Thread(target=self.start_sim, daemon=True)
+
+     async def action_quit(self) -> None:
+         """Quit the application."""
+         self.stop_event.set()
+         if self.simulator_thread.is_alive():
+             self.simulator_thread.join(timeout=2)
+         await super().action_quit()
+
+     def compose(self) -> ComposeResult:
+         """Create child widgets for the app."""
+         yield Label(DDef.APP_TITLE, id=DLayout.TITLE)
+         yield Horizontal(
+             Vertical(
+                 Vertical(
+                     Horizontal(
+                         Label(
+                             f"{DLabel.EPSILON_INITIAL} : ",
+                             classes=DLayout.LABEL_SETTINGS,
+                         ),
+                         self.initial_epsilon_input,
+                     ),
+                     Horizontal(
+                         Label(
+                             f"{DLabel.EPSILON_DECAY} : ",
+                             classes=DLayout.LABEL_SETTINGS,
+                         ),
+                         self.epsilon_decay_input,
+                     ),
+                     Horizontal(
+                         Label(
+                             f"{DLabel.EPSILON_MIN} : ", classes=DLayout.LABEL_SETTINGS
+                         ),
+                         self.epsilon_min_input,
+                     ),
+                     Horizontal(
+                         Label(
+                             f"{DLabel.MOVE_DELAY} : ",
+                             classes=DLayout.LABEL_SETTINGS,
+                         ),
+                         self.move_delay_input,
+                     ),
+                     id=DLayout.SETTINGS_BOX,
+                 ),
+                 Vertical(
+                     Horizontal(
+                         self.start_button,
+                         self.reset_button,
+                         self.update_button,
+                         self.quit_button,
+                     ),
+                     id=DLayout.BUTTON_ROW,
+                 ),
+             ),
+             Vertical(
+                 self.game_board,
+                 id=DLayout.GAME_BOX,
+             ),
+             Vertical(
+                 Horizontal(
+                     Label(f"{DLabel.EPSILON} : ", classes=DLayout.LABEL),
+                     self.cur_epsilon_widget,
+                 ),
+                 Horizontal(
+                     Label(f"{DLabel.MEM_TYPE} : ", classes=DLayout.LABEL),
+                     self.cur_mem_type_widget,
+                 ),
+                 Horizontal(
+                     Label(f"{DLabel.MEMORIES} : ", classes=DLayout.LABEL),
+                     self.cur_num_memories_widget,
+                 ),
+                 id=DLayout.RUNTIME_BOX,
+             ),
+         )
+
+     def on_mount(self):
+         self.initial_epsilon_input.value = str(DEpsilon.EPSILON_INITIAL)
+         self.epsilon_decay_input.value = str(DEpsilon.EPSILON_DECAY)
+         self.epsilon_min_input.value = str(DEpsilon.EPSILON_MIN)
+         self.move_delay_input.value = str(DDef.MOVE_DELAY)
+         settings_box = self.query_one(f"#{DLayout.SETTINGS_BOX}", Vertical)
+         settings_box.border_title = DLabel.SETTINGS
+         runtime_box = self.query_one(f"#{DLayout.RUNTIME_BOX}", Vertical)
+         runtime_box.border_title = DLabel.RUNTIME
+         self.cur_mem_type_widget.update(
+             MEM_TYPE.MEM_TYPE_TABLE[self.agent.memory.mem_type()]
+         )
+         self.cur_num_memories_widget.update(str(self.agent.memory.get_num_memories()))
+         # The initial state of the app is stopped
+         self.add_class(DField.STOPPED)
+
+     def on_quit(self):
+         if self.running:
+             self.stop_event.set()
+         if self.simulator_thread.is_alive():
+             self.simulator_thread.join()
+         # Shut down the Textual app cleanly
+         self.exit()
+
+     def on_button_pressed(self, event: Button.Pressed) -> None:
+         button_id = event.button.id
+         # Start button was pressed
+         if button_id == DLayout.BUTTON_START:
+             self.start_thread()
+             self.running = True
+             self.add_class(DField.RUNNING)
+             self.remove_class(DField.STOPPED)
+             self.cur_move_delay = float(self.move_delay_input.value)
+         # Reset button was pressed
+         elif button_id == DLayout.BUTTON_RESET:
+             self.initial_epsilon_input.value = str(DEpsilon.EPSILON_INITIAL)
+             self.epsilon_decay_input.value = str(DEpsilon.EPSILON_DECAY)
+             self.epsilon_min_input.value = str(DEpsilon.EPSILON_MIN)
+             self.move_delay_input.value = str(DDef.MOVE_DELAY)
+         # Quit button was pressed
+         elif button_id == DLayout.BUTTON_QUIT:
+             self.on_quit()
+         # Update button was pressed
+         elif button_id == DLayout.BUTTON_UPDATE:
+             self.cur_move_delay = float(self.move_delay_input.value)
+
+     def start_sim(self):
+         self.snake_game.reset()
+         game_board = self.game_board
+         agent = self.agent
+         snake_game = self.snake_game
+         score = 0
+         highscore = 0
+         self.epoch = 1
+         game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
+         game_box.border_title = f"{DLabel.GAME} #{self.epoch}"
+
+         while not self.stop_event.is_set():
+             # The actual training loop...
+             old_state = game_board.get_state()
+             move = agent.get_move(old_state)
+             reward, game_over, score = snake_game.play_step(move)
+             # Capture the resulting state here so the game-over branch
+             # below never sees a stale (or undefined) new_state
+             new_state = game_board.get_state()
+             if score > highscore:
+                 highscore = score
+             game_box.border_subtitle = (
+                 f"{DLabel.HIGHSCORE}: {highscore}, {DLabel.SCORE}: {score}"
+             )
+             if not game_over:
+                 ## Keep playing
+                 time.sleep(self.cur_move_delay)
+                 agent.train_short_memory(old_state, move, reward, new_state, game_over)
+                 agent.remember(old_state, move, reward, new_state, game_over)
+             else:
+                 ## Game over
+                 self.epoch += 1
+                 game_box.border_title = f"{DLabel.GAME} #{self.epoch}"
+                 # Remember the last move
+                 agent.remember(old_state, move, reward, new_state, game_over)
+                 # Train long memory
+                 agent.train_long_memory()
+                 # Reset the game
+                 snake_game.reset()
+                 # Let the agent know we've finished a game
+                 agent.played_game(score)
+                 # Get the current epsilon value
+                 cur_epsilon = self.epsilon_algo.epsilon()
+                 if cur_epsilon < 0.0001:
+                     self.cur_epsilon_widget.update("0.0000")
+                 else:
+                     self.cur_epsilon_widget.update(str(round(cur_epsilon, 4)))
+                 # Update the number of stored memories
+                 self.cur_num_memories_widget.update(
+                     str(self.agent.memory.get_num_memories())
+                 )
+
+     def start_thread(self):
+         # A Thread can only be started once, so guard against a second Start press
+         if not self.simulator_thread.is_alive():
+             self.simulator_thread.start()
+
+
+ if __name__ == "__main__":
+     app = AISim()
+     app.run()
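
One caveat worth noting in AISim: start_sim updates Label widgets directly from the background simulator thread. Textual widgets are generally not thread-safe, and the usual pattern is to hand UI updates back to the event loop with App.call_from_thread. A minimal sketch of that pattern (the widget id, worker name, and counter are illustrative, not part of this package):

import threading
import time

from textual.app import App, ComposeResult
from textual.widgets import Label

class SafeUpdateDemo(App):
    def compose(self) -> ComposeResult:
        yield Label("0", id="counter")

    def on_mount(self) -> None:
        self.counter = self.query_one("#counter", Label)
        threading.Thread(target=self.worker, daemon=True).start()

    def worker(self) -> None:
        # Runs off the event loop: marshal each update onto the app thread
        for i in range(1, 6):
            time.sleep(0.5)
            self.call_from_thread(self.counter.update, str(i))

if __name__ == "__main__":
    SafeUpdateDemo().run()
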
ai/AIAgent.py ADDED
@@ -0,0 +1,84 @@
+ """
+ ai/AIAgent.py
+
+ AI Snake Game Simulator
+ Author: Nadim-Daniel Ghaznavi
+ Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
+ GitHub: https://github.com/NadimGhaznavi/ai
+ License: GPL 3.0
+ """
+
+ import torch
+ from ai.EpsilonAlgo import EpsilonAlgo
+ from ai.ReplayMemory import ReplayMemory
+ from ai.AITrainer import AITrainer
+ from ai.models.ModelL import ModelL
+ from ai.models.ModelRNN import ModelRNN
+
+ from constants.DReplayMemory import MEM_TYPE
+
+
+ class AIAgent:
+
+     def __init__(self, epsilon_algo: EpsilonAlgo, seed: int):
+         self.epsilon_algo = epsilon_algo
+         self.memory = ReplayMemory(seed=seed)
+         self.model = ModelL(seed=seed)
+         # self.model = ModelRNN(seed=seed)
+         self.trainer = AITrainer(self.model)
+
+         # The RNN model trains on whole games, so use game-based replay memory
+         if isinstance(self.model, ModelRNN):
+             self.memory.mem_type(MEM_TYPE.RANDOM_GAME)
+
+     def get_model(self):
+         return self.model
+
+     def get_move(self, state):
+         random_move = self.epsilon_algo.get_move()  # Explore with epsilon
+         if random_move is not False:
+             return random_move  # A random move was returned
+
+         # Exploit with an AI agent based action
+         final_move = [0, 0, 0]
+         if not isinstance(state, torch.Tensor):
+             state = torch.tensor(state, dtype=torch.float)  # Convert to a tensor
+         prediction = self.model(state)  # Get the prediction
+         move = torch.argmax(prediction).item()  # Select the highest-value move
+         final_move[move] = 1  # One-hot encode the chosen move
+         return final_move
+
+     def get_optimizer(self):
+         return self.trainer.get_optimizer()
+
+     def played_game(self, score):
+         self.epsilon_algo.played_game()
+
+     def remember(self, state, action, reward, next_state, done):
+         # Store the state, action, reward, next_state, and done in memory
+         self.memory.append((state, action, reward, next_state, done))
+
+     def set_model(self, model):
+         self.model = model
+
+     def set_optimizer(self, optimizer):
+         self.trainer.set_optimizer(optimizer)
+
+     def train_long_memory(self):
+         # Get the states, actions, rewards, next_states, and dones from the sample
+         memory = self.memory.get_memory()
+         memory_type = self.memory.mem_type()
+
+         # Replay memory may not contain a complete game yet
+         if not memory:
+             return
+
+         if isinstance(self.model, ModelRNN):
+             # memory is a one-element sample containing a full game
+             for state, action, reward, next_state, done in memory[0]:
+                 self.trainer.train_step(state, action, reward, next_state, [done])
+
+         elif memory_type == MEM_TYPE.SHUFFLE:
+             for state, action, reward, next_state, done in memory:
+                 self.trainer.train_step(state, action, reward, next_state, [done])
+
+         else:
+             for state, action, reward, next_state, done in memory[0]:
+                 self.trainer.train_step(state, action, reward, next_state, [done])
+
+     def train_short_memory(self, state, action, reward, next_state, done):
+         self.trainer.train_step(state, action, reward, next_state, [done])
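
To make the exploit path in get_move concrete: the model maps a 27-feature state to three Q values, and the move is the argmax, one-hot encoded. A toy sketch with a stand-in linear model (not the package's ModelL):

import torch

torch.manual_seed(0)
model = torch.nn.Linear(27, 3)   # stand-in for ModelL: 27 features in, 3 Q values out
state = torch.rand(27)           # a fake game state

prediction = model(state)        # three Q values, one per action
final_move = [0, 0, 0]
final_move[torch.argmax(prediction).item()] = 1
print(final_move)                # a one-hot move, e.g. [0, 1, 0]
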
ai/AITrainer.py ADDED
@@ -0,0 +1,90 @@
+ """
+ ai/AITrainer.py
+
+ AI Snake Game Simulator
+ Author: Nadim-Daniel Ghaznavi
+ Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
+ GitHub: https://github.com/NadimGhaznavi/ai
+ License: GPL 3.0
+ """
+
+ import torch.optim as optim
+ import torch.nn as nn
+ import torch
+ import numpy as np
+
+ from ai.models.ModelL import ModelL
+ from ai.models.ModelRNN import ModelRNN
+
+ from constants.DModelL import DModelL
+ from constants.DModelLRNN import DModelRNN
+
+
+ class AITrainer:
+
+     def __init__(self, model):
+         torch.manual_seed(1970)
+         self.model = model
+         # The learning rate needs to be adjusted for the model type
+         if isinstance(model, ModelL):
+             learning_rate = DModelL.LEARNING_RATE
+         elif isinstance(model, ModelRNN):
+             learning_rate = DModelRNN.LEARNING_RATE
+         else:
+             raise TypeError(f"Unsupported model type: {type(model).__name__}")
+         self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
+         self.criterion = nn.MSELoss()
+         self.gamma = 0.9  # Discount factor for future rewards
+
+     def get_optimizer(self):
+         return self.optimizer
+
+     def set_optimizer(self, optimizer):
+         self.optimizer = optimizer
+
+     def train_step_cnn(self, state, action, reward, next_state, game_over):
+         state = torch.tensor(np.array(state), dtype=torch.float)
+         next_state = torch.tensor(np.array(next_state), dtype=torch.float)
+         action = torch.tensor(action, dtype=torch.long)
+         reward = torch.tensor(reward, dtype=torch.float)
+         pred = self.model(state)
+         target = pred.clone()
+         if game_over:
+             Q_new = reward  # No future rewards, the game is over
+         else:
+             Q_new = reward + self.gamma * torch.max(self.model(next_state).detach())
+         target[0][action.argmax().item()] = Q_new  # Update the Q value
+         self.optimizer.zero_grad()  # Reset gradients
+         loss = self.criterion(target, pred)  # Calculate the loss
+         loss.backward()
+         self.optimizer.step()  # Adjust the weights
+
+     def train_step(self, state, action, reward, next_state, game_over):
+         state = torch.tensor(np.array(state), dtype=torch.float)
+         next_state = torch.tensor(np.array(next_state), dtype=torch.float)
+         action = torch.tensor(action, dtype=torch.long)
+         reward = torch.tensor(reward, dtype=torch.float)
+         if len(state.shape) == 1:
+             # Add a batch dimension
+             state = torch.unsqueeze(state, 0)
+             next_state = torch.unsqueeze(next_state, 0)
+             action = torch.unsqueeze(action, 0)
+             reward = torch.unsqueeze(reward, 0)
+             game_over = (game_over,)
+
+         pred = self.model(state)
+         target = pred.clone().detach()
+
+         for idx in range(len(game_over)):
+             Q_new = reward[idx]
+             if not game_over[idx][0]:
+                 # Detach so no gradient flows through the next-state prediction
+                 Q_new = reward[idx] + self.gamma * torch.max(
+                     self.model(next_state[idx]).detach()
+                 )
+             target[idx][action[idx].argmax().item()] = Q_new  # Update the Q value
+
+         self.optimizer.zero_grad()  # Reset gradients
+
+         loss = self.criterion(target, pred)  # Calculate the loss
+         loss.backward()
+         self.optimizer.step()  # Adjust the weights
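
The loop in train_step implements the standard one-step Q-learning target: Q_new = reward for a terminal step, and Q_new = reward + gamma * max(Q(next_state)) otherwise. A quick numeric check of that formula with the trainer's gamma of 0.9:

import torch

gamma = 0.9
reward = torch.tensor(10.0)                 # say the snake just ate the food
next_q = torch.tensor([0.2, 1.5, -0.3])     # Q values predicted for the next state
q_new = reward + gamma * torch.max(next_q)  # 10 + 0.9 * 1.5
print(q_new.item())                         # approximately 11.35
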
ai/EpsilonAlgo.py ADDED
@@ -0,0 +1,73 @@
+ """
+ ai/EpsilonAlgo.py
+
+ AI Snake Game Simulator
+ Author: Nadim-Daniel Ghaznavi
+ Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
+ GitHub: https://github.com/NadimGhaznavi/ai
+ License: GPL 3.0
+
+
+ A class that encapsulates the epsilon-greedy algorithm. The algorithm injects
+ random moves at the beginning of the simulation. The number of random moves is
+ controlled by the epsilon_value parameter, which lives in AISnakeGame.ini and
+ can also be passed in when invoking the main asg.py front end.
+ """
+
+ import random
+ from random import randint
+
+ from constants.DEpsilon import DEpsilon
+
+
+ class EpsilonAlgo:
+
+     def __init__(self, seed):
+         # Seed the RNG so runs are repeatable
+         random.seed(seed)
+         self._initial_epsilon = DEpsilon.EPSILON_INITIAL
+         self._epsilon_min = DEpsilon.EPSILON_MIN
+         self._epsilon_decay = DEpsilon.EPSILON_DECAY
+         self._epsilon = self._initial_epsilon
+         self._num_games = 0
+         self._injected = 0
+         self._depleted = False
+
+     def get_move(self):
+         # With probability epsilon, return a random one-hot move
+         if random.random() < self._epsilon:
+             rand_move = [0, 0, 0]
+             rand_idx = randint(0, 2)
+             rand_move[rand_idx] = 1
+             self._injected += 1
+             return rand_move
+         # Otherwise signal the caller to exploit the model
+         return False
+
+     def epsilon(self):
+         return self._epsilon
+
+     def epsilon_decay(self, epsilon_decay=None):
+         if epsilon_decay is not None:
+             self._epsilon_decay = epsilon_decay
+         return self._epsilon_decay
+
+     def epsilon_min(self, epsilon_min=None):
+         if epsilon_min is not None:
+             self._epsilon_min = epsilon_min
+         return self._epsilon_min
+
+     def initial_epsilon(self, initial_epsilon=None):
+         if initial_epsilon is not None:
+             self._initial_epsilon = initial_epsilon
+         return self._initial_epsilon
+
+     def injected(self):
+         return self._injected
+
+     def played_game(self):
+         # Decay epsilon after every game, but never below the minimum
+         self._num_games += 1
+         self._epsilon = max(self._epsilon_min, self._epsilon * self._epsilon_decay)
+         self.reset_injected()
+
+     def reset_injected(self):
+         self._injected = 0
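
Since played_game multiplies epsilon by the decay factor once per game and clamps at the minimum, epsilon after n games is max(epsilon_min, epsilon_initial * decay**n). A sketch of the schedule with assumed values (the real defaults live in constants.DEpsilon):

epsilon_initial, epsilon_decay, epsilon_min = 1.0, 0.95, 0.01

epsilon = epsilon_initial
for game in range(1, 101):
    epsilon = max(epsilon_min, epsilon * epsilon_decay)
    if game in (1, 10, 50, 100):
        print(f"game {game:3d}: epsilon = {epsilon:.4f}")

# game   1: epsilon = 0.9500
# game  10: epsilon = 0.5987
# game  50: epsilon = 0.0769
# game 100: epsilon = 0.0100  (clamped at epsilon_min)
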
ai/ReplayMemory.py ADDED
@@ -0,0 +1,90 @@
+ """
+ ai/ReplayMemory.py
+
+ AI Snake Game Simulator
+ Author: Nadim-Daniel Ghaznavi
+ Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
+ GitHub: https://github.com/NadimGhaznavi/ai
+ License: GPL 3.0
+
+ This file contains the ReplayMemory class.
+ """
+
+ from collections import deque
+ import random
+ import sys
+
+ from constants.DReplayMemory import MEM_TYPE
+
+
+ class ReplayMemory:
+
+     def __init__(self, seed: int):
+         random.seed(seed)
+         self.batch_size = 250
+         # Supported options: MEM_TYPE.SHUFFLE and MEM_TYPE.RANDOM_GAME
+         self._mem_type = MEM_TYPE.RANDOM_GAME
+         self.min_games = 1
+         self.max_states = 15000
+         self.max_shuffle_games = 40
+         self.max_games = 500
+
+         if self._mem_type == MEM_TYPE.SHUFFLE:
+             # States are stored in a deque and a random sample is returned
+             self.memories = deque(maxlen=self.max_states)
+
+         elif self._mem_type == MEM_TYPE.RANDOM_GAME:
+             # All of the states for a game are stored, in order, in a deque.
+             # A complete game is returned.
+             self.memories = deque(maxlen=self.max_shuffle_games)
+             self.cur_memory = []
+
+         else:
+             print(f"ERROR: Unrecognized replay memory type ({self._mem_type}), exiting")
+             sys.exit(1)
+
+     def append(self, transition):
+         ## Add memories
+
+         # States are stored in a deque and a random sample is returned
+         if self._mem_type == MEM_TYPE.SHUFFLE:
+             self.memories.append(transition)
+
+         # All of the states for a game are stored, in order, in a deque.
+         # A set of ordered states representing a complete game is returned.
+         elif self._mem_type == MEM_TYPE.RANDOM_GAME:
+             self.cur_memory.append(transition)
+             state, action, reward, next_state, done = transition
+             if done:
+                 self.memories.append(self.cur_memory)
+                 self.cur_memory = []
+
+     def get_random_game(self):
+         if len(self.memories) >= self.min_games:
+             # A one-element list containing one complete game
+             return random.sample(self.memories, 1)
+         return False
+
+     def get_random_states(self):
+         if len(self.memories) < self.batch_size:
+             return self.memories
+         return random.sample(self.memories, self.batch_size)
+
+     def get_memory(self):
+         if self._mem_type == MEM_TYPE.SHUFFLE:
+             return self.get_random_states()
+
+         elif self._mem_type == MEM_TYPE.RANDOM_GAME:
+             return self.get_random_game()
+
+     def get_num_memories(self):
+         return len(self.memories)
+
+     def mem_type(self, mem_type=None):
+         if mem_type is not None:
+             self._mem_type = mem_type
+         return self._mem_type
+
+     def set_memory(self, memory):
+         # Replace the stored memories (the attribute is self.memories)
+         self.memories = memory
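
In the default RANDOM_GAME mode, cur_memory accumulates transitions until one arrives with done=True; the whole game is then stored as a single unit, and get_memory returns a one-element sample holding one complete game. A usage sketch with dummy transitions, assuming the package is on the import path:

from ai.ReplayMemory import ReplayMemory

memory = ReplayMemory(seed=1970)     # defaults to MEM_TYPE.RANDOM_GAME
for step in range(5):
    done = step == 4                 # the fifth transition ends the game
    memory.append(([0.0] * 27, [1, 0, 0], 0, [0.0] * 27, done))

print(memory.get_num_memories())     # 1 -- one complete game stored
game = memory.get_memory()           # a one-element sample: [[t0, t1, t2, t3, t4]]
print(len(game[0]))                  # 5 transitions, in order
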
ai/models/ModelL.py ADDED
@@ -0,0 +1,40 @@
+ """
+ ai/models/ModelL.py
+
+ AI Snake Game Simulator
+ Author: Nadim-Daniel Ghaznavi
+ Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
+ GitHub: https://github.com/NadimGhaznavi/ai
+ License: GPL 3.0
+ """
+
+ import torch
+ import torch.nn as nn
+
+
+ class ModelL(nn.Module):
+     def __init__(self, seed: int):
+         super().__init__()
+         torch.manual_seed(seed)
+         input_size = 27  # Size of the "state" as tracked by the GameBoard
+         hidden_size = 170
+         output_size = 3
+         p_value = 0.1  # Dropout probability
+         self.input_block = nn.Sequential(
+             nn.Linear(input_size, hidden_size),
+             nn.ReLU(),
+         )
+         self.hidden_block = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size),
+             nn.ReLU(),
+         )
+         self.dropout_block = nn.Dropout(p=p_value)
+         self.output_block = nn.Linear(hidden_size, output_size)
+
+     def forward(self, x):
+         x = self.input_block(x)
+         x = self.hidden_block(x)
+         x = self.dropout_block(x)
+         x = self.output_block(x)
+         return x
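
ModelL is a plain feed-forward network: 27 state features in, 3 Q values out, with dropout before the output layer. A quick shape check, again assuming the package is importable:

import torch
from ai.models.ModelL import ModelL

model = ModelL(seed=1970)
model.eval()                  # disable dropout for a deterministic forward pass

state = torch.rand(27)        # one unbatched state
print(model(state).shape)     # torch.Size([3]) -- one Q value per action

batch = torch.rand(8, 27)     # a batch of states works too
print(model(batch).shape)     # torch.Size([8, 3])
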
ai/models/ModelRNN.py ADDED
@@ -0,0 +1,43 @@
+ """
+ ai/models/ModelRNN.py
+
+ AI Snake Game Simulator
+ Author: Nadim-Daniel Ghaznavi
+ Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
+ GitHub: https://github.com/NadimGhaznavi/ai
+ License: GPL 3.0
+ """
+
+ import torch
+ import torch.nn as nn
+
+
+ class ModelRNN(nn.Module):
+     def __init__(self, seed: int):
+         super().__init__()
+         torch.manual_seed(seed)
+         input_size = 27
+         self.hidden_size = 200
+         output_size = 3
+         rnn_layers = 4
+         rnn_dropout = 0.2
+         self.m_in = nn.Sequential(
+             nn.Linear(input_size, self.hidden_size),
+             nn.ReLU(),
+         )
+         self.m_rnn = nn.RNN(
+             input_size=self.hidden_size,
+             hidden_size=self.hidden_size,
+             nonlinearity="tanh",
+             num_layers=rnn_layers,
+             dropout=rnn_dropout,
+         )
+         self.m_out = nn.Linear(self.hidden_size, output_size)
+
+     def forward(self, x):
+         x = self.m_in(x)
+         # nn.RNN defaults to (seq_len, batch, features), so present the input
+         # as a batch of states at a single time step
+         inputs = x.view(1, -1, self.hidden_size)
+         x, h_n = self.m_rnn(inputs)
+         x = self.m_out(x)
+         # Drop the sequence dimension: return the last (and only) time step
+         return x[-1]
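
ModelRNN reshapes its input to nn.RNN's default (seq_len, batch, features) layout: x.view(1, -1, hidden_size) presents the incoming rows as a batch at a single time step, and the final indexing drops the sequence dimension. A shape check under the same importability assumption:

import torch
from ai.models.ModelRNN import ModelRNN

model = ModelRNN(seed=1970)
model.eval()                  # disable the inter-layer RNN dropout

state = torch.rand(27)
print(model(state).shape)     # torch.Size([1, 3]) -- a single-row batch

batch = torch.rand(8, 27)
print(model(batch).shape)     # torch.Size([8, 3])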