ai-snake-lab 0.1.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_snake_lab/AISim.py +274 -0
- ai_snake_lab/ai/AIAgent.py +84 -0
- ai_snake_lab/ai/AITrainer.py +90 -0
- ai_snake_lab/ai/EpsilonAlgo.py +73 -0
- ai_snake_lab/ai/ReplayMemory.py +90 -0
- ai_snake_lab/ai/models/ModelL.py +40 -0
- ai_snake_lab/ai/models/ModelRNN.py +43 -0
- ai_snake_lab/constants/DDef.py +18 -0
- ai_snake_lab/constants/DDir.py +16 -0
- ai_snake_lab/constants/DEpsilon.py +19 -0
- ai_snake_lab/constants/DFields.py +18 -0
- ai_snake_lab/constants/DFile.py +17 -0
- ai_snake_lab/constants/DLabels.py +34 -0
- ai_snake_lab/constants/DLayout.py +39 -0
- ai_snake_lab/constants/DModelL.py +17 -0
- ai_snake_lab/constants/DModelLRNN.py +20 -0
- ai_snake_lab/constants/DReplayMemory.py +25 -0
- ai_snake_lab/constants/__init__.py +0 -0
- ai_snake_lab/game/GameBoard.py +221 -0
- ai_snake_lab/game/GameElements.py +27 -0
- ai_snake_lab/game/SnakeGame.py +178 -0
- ai_snake_lab/utils/AISim.tcss +115 -0
- ai_snake_lab/utils/ConstGroup.py +49 -0
- ai_snake_lab-0.1.0.dist-info/LICENSE +674 -0
- ai_snake_lab-0.1.0.dist-info/METADATA +70 -0
- ai_snake_lab-0.1.0.dist-info/RECORD +28 -0
- ai_snake_lab-0.1.0.dist-info/WHEEL +4 -0
- ai_snake_lab-0.1.0.dist-info/entry_points.txt +3 -0
ai_snake_lab/AISim.py
ADDED
@@ -0,0 +1,274 @@
"""
AISim.py

AI Snake Game Simulator
Author: Nadim-Daniel Ghaznavi
Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
GitHub: https://github.com/NadimGhaznavi/ai
License: GPL 3.0
"""

import threading
import time
import sys

from textual.app import App, ComposeResult
from textual.widgets import Label, Input, Button
from textual.containers import Vertical, Horizontal
from textual.reactive import var

from constants.DDef import DDef
from constants.DEpsilon import DEpsilon
from constants.DFields import DField
from constants.DFile import DFile
from constants.DLayout import DLayout
from constants.DLabels import DLabel
from constants.DReplayMemory import MEM_TYPE

from ai.AIAgent import AIAgent
from ai.EpsilonAlgo import EpsilonAlgo
from game.GameBoard import GameBoard
from game.SnakeGame import SnakeGame

RANDOM_SEED = 1970


class AISim(App):
    """A Textual app that has an AI Agent playing the Snake Game."""

    TITLE = DDef.APP_TITLE
    CSS_PATH = DFile.CSS_PATH

    ## Runtime values
    # Current epsilon value (degrades in real-time)
    cur_epsilon_widget = Label("N/A", id=DLayout.CUR_EPSILON)
    # Current memory type
    cur_mem_type_widget = Label("N/A", id=DLayout.CUR_MEM_TYPE)
    # Number of stored memories
    cur_num_memories_widget = Label("N/A", id=DLayout.NUM_MEMORIES)
    # Runtime move delay value
    cur_move_delay = DDef.MOVE_DELAY

    # Initial Settings for Epsilon
    initial_epsilon_input = Input(
        restrict=f"0.[0-9]*",
        compact=True,
        id=DLayout.EPSILON_INITIAL,
        classes=DLayout.INPUT_10,
    )
    epsilon_min_input = Input(
        restrict=f"0.[0-9]*",
        compact=True,
        id=DLayout.EPSILON_MIN,
        classes=DLayout.INPUT_10,
    )
    epsilon_decay_input = Input(
        restrict=f"0.[0-9]*",
        compact=True,
        id=DLayout.EPSILON_DECAY,
        classes=DLayout.INPUT_10,
    )
    move_delay_input = Input(
        restrict=f"[0-9]*.[0-9]*",
        compact=True,
        id=DLayout.MOVE_DELAY,
        classes=DLayout.INPUT_10,
    )

    # Buttons
    pause_button = Button(label=DLabel.PAUSE, id=DLayout.BUTTON_PAUSE, compact=True)
    start_button = Button(label=DLabel.START, id=DLayout.BUTTON_START, compact=True)
    quit_button = Button(label=DLabel.QUIT, id=DLayout.BUTTON_QUIT, compact=True)
    reset_button = Button(label=DLabel.RESET, id=DLayout.BUTTON_RESET, compact=True)
    update_button = Button(label=DLabel.UPDATE, id=DLayout.BUTTON_UPDATE, compact=True)

    def __init__(self) -> None:
        super().__init__()
        self.game_board = GameBoard(20, id=DLayout.GAME_BOARD)
        self.snake_game = SnakeGame(game_board=self.game_board, id=DLayout.GAME_BOARD)
        self.epsilon_algo = EpsilonAlgo(seed=RANDOM_SEED)
        self.agent = AIAgent(self.epsilon_algo, seed=RANDOM_SEED)
        self.running = False

        self.score = Label("Game: 0, Highscore: 0, Score: 0")

        # Setup the simulator in a background thread
        self.stop_event = threading.Event()
        self.simulator_thread = threading.Thread(target=self.start_sim, daemon=True)

    async def action_quit(self) -> None:
        """Quit the application."""
        self.stop_event.set()
        if self.simulator_thread.is_alive():
            self.simulator_thread.join(timeout=2)
        await super().action_quit()

    def compose(self) -> ComposeResult:
        """Create child widgets for the app."""
        yield Label(DDef.APP_TITLE, id=DLayout.TITLE)
        yield Horizontal(
            Vertical(
                Vertical(
                    Horizontal(
                        Label(
                            f"{DLabel.EPSILON_INITIAL} : ",
                            classes=DLayout.LABEL_SETTINGS,
                        ),
                        self.initial_epsilon_input,
                    ),
                    Horizontal(
                        Label(
                            f"{DLabel.EPSILON_DECAY} : ",
                            classes=DLayout.LABEL_SETTINGS,
                        ),
                        self.epsilon_decay_input,
                    ),
                    Horizontal(
                        Label(
                            f"{DLabel.EPSILON_MIN} : ", classes=DLayout.LABEL_SETTINGS
                        ),
                        self.epsilon_min_input,
                    ),
                    Horizontal(
                        Label(
                            f"{DLabel.MOVE_DELAY} : ",
                            classes=DLayout.LABEL_SETTINGS,
                        ),
                        self.move_delay_input,
                    ),
                    id=DLayout.SETTINGS_BOX,
                ),
                Vertical(
                    Horizontal(
                        self.start_button,
                        self.reset_button,
                        self.update_button,
                        self.quit_button,
                    ),
                    id=DLayout.BUTTON_ROW,
                ),
            ),
            Vertical(
                self.game_board,
                id=DLayout.GAME_BOX,
            ),
            Vertical(
                Horizontal(
                    Label(f"{DLabel.EPSILON} : ", classes=DLayout.LABEL),
                    self.cur_epsilon_widget,
                ),
                Horizontal(
                    Label(f"{DLabel.MEM_TYPE} : ", classes=DLayout.LABEL),
                    self.cur_mem_type_widget,
                ),
                Horizontal(
                    Label(f"{DLabel.MEMORIES} : ", classes=DLayout.LABEL),
                    self.cur_num_memories_widget,
                ),
                id=DLayout.RUNTIME_BOX,
            ),
        )

    def on_mount(self):
        self.initial_epsilon_input.value = str(DEpsilon.EPSILON_INITIAL)
        self.epsilon_decay_input.value = str(DEpsilon.EPSILON_DECAY)
        self.epsilon_min_input.value = str(DEpsilon.EPSILON_MIN)
        self.move_delay_input.value = str(DDef.MOVE_DELAY)
        settings_box = self.query_one(f"#{DLayout.SETTINGS_BOX}", Vertical)
        settings_box.border_title = DLabel.SETTINGS
        runtime_box = self.query_one(f"#{DLayout.RUNTIME_BOX}", Vertical)
        runtime_box.border_title = DLabel.RUNTIME
        self.cur_mem_type_widget.update(
            MEM_TYPE.MEM_TYPE_TABLE[self.agent.memory.mem_type()]
        )
        self.cur_num_memories_widget.update(str(self.agent.memory.get_num_memories()))
        # Initial state is that the app is stopped
        self.add_class(DField.STOPPED)

    def on_quit(self):
        if self.running == True:
            self.stop_event.set()
            if self.simulator_thread.is_alive():
                self.simulator_thread.join()
        sys.exit(0)

    def on_button_pressed(self, event: Button.Pressed) -> None:
        button_id = event.button.id
        # Start button was pressed
        if button_id == DLayout.BUTTON_START:
            self.start_thread()
            self.running = True
            self.add_class(DField.RUNNING)
            self.remove_class(DField.STOPPED)
            self.cur_move_delay = float(self.move_delay_input.value)
        # Reset button was pressed
        elif button_id == DLayout.BUTTON_RESET:
            self.initial_epsilon_input.value = str(DEpsilon.EPSILON_INITIAL)
            self.epsilon_decay_input.value = str(DEpsilon.EPSILON_DECAY)
            self.epsilon_min_input.value = str(DEpsilon.EPSILON_MIN)
            self.move_delay_input.value = str(DDef.MOVE_DELAY)
        # Quit button was pressed
        elif button_id == DLayout.BUTTON_QUIT:
            self.on_quit()
        # Update button was pressed
        elif button_id == DLayout.BUTTON_UPDATE:
            self.cur_move_delay = float(self.move_delay_input.value)

    def start_sim(self):
        self.snake_game.reset()
        game_board = self.game_board
        agent = self.agent
        snake_game = self.snake_game
        score = 0
        highscore = 0
        self.epoch = 1
        game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
        game_box.border_title = f"{DLabel.GAME} #{self.epoch}"

        while not self.stop_event.is_set():
            # The actual training loop...
            old_state = game_board.get_state()
            move = agent.get_move(old_state)
            reward, game_over, score = snake_game.play_step(move)
            if score > highscore:
                highscore = score
            game_box.border_subtitle = (
                f"{DLabel.HIGHSCORE}: {highscore}, {DLabel.SCORE}: {score}"
            )
            if not game_over:
                ## Keep playing
                time.sleep(self.cur_move_delay)
                new_state = game_board.get_state()
                agent.train_short_memory(old_state, move, reward, new_state, game_over)
                agent.remember(old_state, move, reward, new_state, game_over)
            else:
                ## Game over
                self.epoch += 1
                game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
                game_box.border_title = f"{DLabel.GAME} #{self.epoch}"
                # Remember the last move
                agent.remember(old_state, move, reward, new_state, game_over)
                # Train long memory
                agent.train_long_memory()
                # Reset the game
                snake_game.reset()
                # Let the agent know we've finished a game
                agent.played_game(score)
                # Get the current epsilon value
                cur_epsilon = self.epsilon_algo.epsilon()
                if cur_epsilon < 0.0001:
                    self.cur_epsilon_widget.update("0.0000")
                else:
                    self.cur_epsilon_widget.update(str(round(cur_epsilon, 4)))
                # Update the number of stored memories
                self.cur_num_memories_widget.update(
                    str(self.agent.memory.get_num_memories())
                )

    def start_thread(self):
        self.simulator_thread.start()


if __name__ == "__main__":
    app = AISim()
    app.run()
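
Note on the training loop: start_sim() above drives a standard per-step / per-game reinforcement-learning cycle (observe the state, pick a move, step the game, train on the single transition, and on game over replay stored memories and decay epsilon). The stand-alone sketch below illustrates only that control flow; DummyEnv and DummyAgent are hypothetical stand-ins and are not part of ai_snake_lab.

# Minimal sketch of the per-step / per-game training pattern used in start_sim().
# DummyEnv and DummyAgent are hypothetical stand-ins, not package APIs.
import random

class DummyEnv:
    def reset(self): self.steps = 0
    def get_state(self): return [0.0] * 27
    def play_step(self, move):
        self.steps += 1
        game_over = self.steps >= 10            # end the episode after 10 moves
        return 1.0, game_over, self.steps       # (reward, game_over, score)

class DummyAgent:
    def get_move(self, state): return random.choice([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    def train_short_memory(self, *t): pass      # per-move (online) training
    def remember(self, *t): pass                # store the transition
    def train_long_memory(self): pass           # replay training at game end
    def played_game(self, score): pass          # decay epsilon, bookkeeping

env, agent = DummyEnv(), DummyAgent()
env.reset()
for _ in range(25):
    old_state = env.get_state()
    move = agent.get_move(old_state)
    reward, game_over, score = env.play_step(move)
    new_state = env.get_state()
    agent.train_short_memory(old_state, move, reward, new_state, game_over)
    agent.remember(old_state, move, reward, new_state, game_over)
    if game_over:
        agent.train_long_memory()
        agent.played_game(score)
        env.reset()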
ai_snake_lab/ai/AIAgent.py
ADDED
@@ -0,0 +1,84 @@
"""
ai/Agent.py

AI Snake Game Simulator
Author: Nadim-Daniel Ghaznavi
Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
GitHub: https://github.com/NadimGhaznavi/ai
License: GPL 3.0
"""

import torch
from ai.EpsilonAlgo import EpsilonAlgo
from ai.ReplayMemory import ReplayMemory
from ai.AITrainer import AITrainer
from ai.models.ModelL import ModelL
from ai.models.ModelRNN import ModelRNN

from constants.DReplayMemory import MEM_TYPE


class AIAgent:

    def __init__(self, epsilon_algo: EpsilonAlgo, seed: int):
        self.epsilon_algo = epsilon_algo
        self.memory = ReplayMemory(seed=seed)
        self.model = ModelL(seed=seed)
        # self.model = ModelRNN(seed=seed)
        self.trainer = AITrainer(self.model)

        if type(self.model) == ModelRNN:
            self.memory.mem_type(MEM_TYPE.RANDOM_GAME)

    def get_model(self):
        return self.model

    def get_move(self, state):
        random_move = self.epsilon_algo.get_move()  # Explore with epsilon
        if random_move != False:
            return random_move  # Random move was returned

        # Exploit with an AI agent based action
        final_move = [0, 0, 0]
        if type(state) != torch.Tensor:
            state = torch.tensor(state, dtype=torch.float)  # Convert to a tensor
        prediction = self.model(state)  # Get the prediction
        move = torch.argmax(prediction).item()  # Select the move with the highest value
        final_move[move] = 1  # Set the move
        return final_move  # Return

    def get_optimizer(self):
        return self.trainer.get_optimizer()

    def played_game(self, score):
        self.epsilon_algo.played_game()

    def remember(self, state, action, reward, next_state, done):
        # Store the state, action, reward, next_state, and done in memory
        self.memory.append((state, action, reward, next_state, done))

    def set_model(self, model):
        self.model = model

    def set_optimizer(self, optimizer):
        self.trainer.set_optimizer(optimizer)

    def train_long_memory(self):
        # Get the states, actions, rewards, next_states, and dones from the mini_sample
        memory = self.memory.get_memory()
        memory_type = self.memory.mem_type()

        if type(self.model) == ModelRNN:
            for state, action, reward, next_state, done in memory[0]:
                self.trainer.train_step(state, action, reward, next_state, [done])

        elif memory_type == MEM_TYPE.SHUFFLE:
            for state, action, reward, next_state, done in memory:
                self.trainer.train_step(state, action, reward, next_state, [done])

        else:
            for state, action, reward, next_state, done in memory[0]:
                self.trainer.train_step(state, action, reward, next_state, [done])

    def train_short_memory(self, state, action, reward, next_state, done):
        self.trainer.train_step(state, action, reward, next_state, [done])
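
Note on action selection: AIAgent.get_move() is epsilon-greedy. If EpsilonAlgo returns a random one-hot move it is used as-is; otherwise the index of the model's highest output score is converted to a one-hot move. A minimal stand-alone sketch of that selection step, using a plain nn.Linear layer as a placeholder for the package's model:

# Epsilon-greedy selection sketch. The Linear layer is a placeholder Q-network,
# not the package's ModelL; shapes match the 27-feature state and 3 moves.
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(27, 3)  # placeholder: 27 state features -> 3 move scores

def epsilon_greedy_move(state, epsilon):
    if torch.rand(1).item() < epsilon:                 # explore: random one-hot move
        move = torch.randint(0, 3, (1,)).item()
    else:                                              # exploit: argmax over Q-values
        with torch.no_grad():
            move = torch.argmax(model(state)).item()
    final_move = [0, 0, 0]
    final_move[move] = 1
    return final_move

state = torch.zeros(27)
print(epsilon_greedy_move(state, epsilon=0.5))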
ai_snake_lab/ai/AITrainer.py
ADDED
@@ -0,0 +1,90 @@
"""
ai/AITrainer.py

AI Snake Game Simulator
Author: Nadim-Daniel Ghaznavi
Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
GitHub: https://github.com/NadimGhaznavi/ai
License: GPL 3.0
"""

import torch.optim as optim
import torch.nn as nn
import torch
import numpy as np
import time
import sys

from ai.models.ModelL import ModelL
from ai.models.ModelRNN import ModelRNN

from constants.DModelL import DModelL
from constants.DModelLRNN import DModelRNN


class AITrainer:

    def __init__(self, model):
        torch.manual_seed(1970)
        self.model = model
        # The learning rate needs to be adjusted for the model type
        if type(model) == ModelL:
            learning_rate = DModelL.LEARNING_RATE
        elif type(model) == ModelRNN:
            learning_rate = DModelRNN.LEARNING_RATE
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.criterion = nn.MSELoss()
        self.gamma = 0.9

    def get_optimizer(self):
        return self.optimizer

    def set_optimizer(self, optimizer):
        self.optimizer = optimizer

    def train_step_cnn(self, state, action, reward, next_state, game_over):
        state = torch.tensor(np.array(state), dtype=torch.float)
        next_state = torch.tensor(np.array(next_state), dtype=torch.float)
        action = torch.tensor(action, dtype=torch.long)
        reward = torch.tensor(reward, dtype=torch.float)
        pred = self.model(state)
        target = pred.clone()
        if game_over:
            Q_new = reward  # No future rewards, the game is over.
        else:
            Q_new = reward + self.gamma * torch.max(self.model(next_state).detach())
        target[0][action.argmax().item()] = Q_new  # Update Q value
        self.optimizer.zero_grad()  # Reset gradients
        loss = self.criterion(target, pred)  # Calculate the loss
        loss.backward()
        self.optimizer.step()  # Adjust the weights

    def train_step(self, state, action, reward, next_state, game_over):
        state = torch.tensor(np.array(state), dtype=torch.float)
        next_state = torch.tensor(np.array(next_state), dtype=torch.float)
        action = torch.tensor(action, dtype=torch.long)
        reward = torch.tensor(reward, dtype=torch.float)
        if len(state.shape) == 1:
            # Add a batch dimension
            state = torch.unsqueeze(state, 0)
            next_state = torch.unsqueeze(next_state, 0)
            action = torch.unsqueeze(action, 0)
            reward = torch.unsqueeze(reward, 0)
            game_over = (game_over,)

        pred = self.model(state)
        target = pred.clone().detach()

        for idx in range(len(game_over)):
            Q_new = reward[idx]
            if not game_over[idx][0]:
                Q_new = reward[idx] + self.gamma * torch.max(
                    self.model(next_state[idx])
                )
            target[idx][action[idx].argmax().item()] = Q_new  # Update Q value

        self.optimizer.zero_grad()  # Reset gradients

        loss = self.criterion(target, pred)  # Calculate the loss
        loss.backward()
        self.optimizer.step()  # Adjust the weights
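
Note on the update rule: train_step() forms the one-step Q-learning target, Q_new = reward + gamma * max_a Q(next_state, a) (or just the reward when the game is over), writes it into the slot of the action that was taken, and minimizes the MSE between the predicted and target Q-vectors. A self-contained sketch of one such update with a stand-in linear Q-network (not the package's ModelL):

# One Q-learning update, sketched with a stand-in linear Q-network.
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(27, 3)                       # stand-in Q-network: 27 features -> 3 actions
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
gamma = 0.9

state = torch.rand(1, 27)
next_state = torch.rand(1, 27)
action = torch.tensor([[1, 0, 0]])             # one-hot action that was taken
reward = torch.tensor([1.0])
game_over = False

pred = model(state)                            # predicted Q(s, .)
target = pred.clone().detach()
q_new = reward[0]
if not game_over:
    q_new = reward[0] + gamma * torch.max(model(next_state).detach())
target[0][action[0].argmax().item()] = q_new   # only the taken action's Q-value changes

optimizer.zero_grad()
loss = criterion(target, pred)                 # MSE between predicted and target Q-values
loss.backward()
optimizer.step()
print(loss.item())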
ai_snake_lab/ai/EpsilonAlgo.py
ADDED
@@ -0,0 +1,73 @@
"""
ai/EpsilonAlgo.py

AI Snake Game Simulator
Author: Nadim-Daniel Ghaznavi
Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
GitHub: https://github.com/NadimGhaznavi/ai
License: GPL 3.0


A class to encapsulate the functionality of the epsilon algorithm. The algorithm
injects random moves at the beginning of the simulation. The amount of moves
is controlled by the epsilon_value parameter which is in the AISnakeGame.ini and
can also be passed in when invoking the main asg.py front end.
"""

import random
from random import randint
import os, sys

from constants.DEpsilon import DEpsilon


class EpsilonAlgo:

    def __init__(self, seed):
        # Set this random seed so things are repeatable
        random.seed(seed)
        self._initial_epsilon = DEpsilon.EPSILON_INITIAL
        self._epsilon_min = DEpsilon.EPSILON_MIN
        self._epsilon_decay = DEpsilon.EPSILON_DECAY
        self._epsilon = self._initial_epsilon
        self._num_games = 0
        self._injected = 0
        self._depleted = False

    def get_move(self):
        if random.random() < self._epsilon:
            rand_move = [0, 0, 0]
            rand_idx = randint(0, 2)
            rand_move[rand_idx] = 1
            self._injected += 1
            return rand_move
        return False

    def epsilon(self):
        return self._epsilon

    def epsilon_decay(self, epsilon_decay=None):
        if epsilon_decay is not None:
            self._epsilon_decay = epsilon_decay
        return self._epsilon_decay

    def epsilon_min(self, epsilon_min=None):
        if epsilon_min is not None:
            self._epsilon_min = epsilon_min
        return self._epsilon_min

    def initial_epsilon(self, initial_epsilon=None):
        if initial_epsilon is not None:
            self._initial_epsilon = initial_epsilon
        return self._initial_epsilon

    def injected(self):
        return self._injected

    def played_game(self):
        self._num_games += 1
        self._epsilon = max(self._epsilon_min, self._epsilon * self._epsilon_decay)
        self.reset_injected()

    def reset_injected(self):
        self._injected = 0
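
Note on the decay schedule: played_game() applies a multiplicative decay, epsilon = max(epsilon_min, epsilon * epsilon_decay), once per completed game, so exploration tapers off geometrically and is clamped at a floor. A tiny sketch of how epsilon falls over 100 games; the constants here are illustrative only, the package's defaults live in constants/DEpsilon.py:

# Multiplicative epsilon decay, clamped at a minimum (illustrative constants).
epsilon, epsilon_min, epsilon_decay = 1.0, 0.01, 0.95

for game in range(1, 101):
    epsilon = max(epsilon_min, epsilon * epsilon_decay)
    if game % 25 == 0:
        print(f"after game {game}: epsilon = {epsilon:.4f}")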
ai_snake_lab/ai/ReplayMemory.py
ADDED
@@ -0,0 +1,90 @@
"""
ai/ReplayMemory.py

AI Snake Game Simulator
Author: Nadim-Daniel Ghaznavi
Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
GitHub: https://github.com/NadimGhaznavi/ai
License: GPL 3.0

This file contains the ReplayMemory class.
"""

from collections import deque
import random, sys

from constants.DReplayMemory import MEM_TYPE


class ReplayMemory:

    def __init__(self, seed: int):
        random.seed(seed)
        self.batch_size = 250
        # Valid options: shuffle, random_game, targeted_score, random_targeted_score
        self._mem_type = MEM_TYPE.RANDOM_GAME
        self.min_games = 1
        self.max_states = 15000
        self.max_shuffle_games = 40
        self.max_games = 500

        if self._mem_type == MEM_TYPE.SHUFFLE:
            # States are stored in a deque and a random sample will be returned
            self.memories = deque(maxlen=self.max_states)

        elif self._mem_type == MEM_TYPE.RANDOM_GAME:
            # All of the states for a game are stored, in order, in a deque.
            # A complete game will be returned
            self.memories = deque(maxlen=self.max_shuffle_games)
            self.cur_memory = []

        else:
            print(f"ERROR: Unrecognized replay memory type ({self._mem_type}), exiting")
            sys.exit(1)

    def append(self, transition):
        ## Add memories

        # States are stored in a deque and a random sample will be returned
        if self._mem_type == MEM_TYPE.SHUFFLE:
            self.memories.append(transition)

        # All of the states for a game are stored, in order, in a deque.
        # A set of ordered states representing a complete game will be returned
        elif self._mem_type == MEM_TYPE.RANDOM_GAME:
            self.cur_memory.append(transition)
            state, action, reward, next_state, done = transition
            if done:
                self.memories.append(self.cur_memory)
                self.cur_memory = []

    def get_random_game(self):
        if len(self.memories) >= self.min_games:
            rand_game = random.sample(self.memories, 1)
            return rand_game
        else:
            return False

    def get_random_states(self):
        mem_size = len(self.memories)
        if mem_size < self.batch_size:
            return self.memories
        return random.sample(self.memories, self.batch_size)

    def get_memory(self):
        if self._mem_type == MEM_TYPE.SHUFFLE:
            return self.get_random_states()

        elif self._mem_type == MEM_TYPE.RANDOM_GAME:
            return self.get_random_game()

    def get_num_memories(self):
        return len(self.memories)

    def mem_type(self, mem_type=None):
        if mem_type is not None:
            self._mem_type = mem_type
        return self._mem_type

    def set_memory(self, memory):
        self.memory = memory
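
Note on the two memory modes: in SHUFFLE mode transitions are pooled in one flat deque and get_memory() returns a random mini-batch; in RANDOM_GAME mode transitions are grouped per game and get_memory() returns one complete game in order (the mode AIAgent selects for the recurrent model). A self-contained sketch of both storage strategies; the transition tuples here are simplified stand-ins:

# Sketch of the two replay storage strategies. Transitions are simplified stand-ins,
# not the package's (state, action, reward, next_state, done) tuples.
import random
from collections import deque

random.seed(1970)

shuffle_memory = deque(maxlen=100)   # SHUFFLE: flat pool of individual transitions
game_memory = deque(maxlen=10)       # RANDOM_GAME: one entry per complete game
current_game = []

for step in range(25):
    done = (step % 5 == 4)                     # end a "game" every 5 steps
    transition = (step, done)                  # stand-in transition
    shuffle_memory.append(transition)
    current_game.append(transition)
    if done:
        game_memory.append(current_game)
        current_game = []

print(random.sample(shuffle_memory, 3))        # SHUFFLE: random mini-batch
print(random.sample(game_memory, 1)[0])        # RANDOM_GAME: one whole game, in order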
ai_snake_lab/ai/models/ModelL.py
ADDED
@@ -0,0 +1,40 @@
"""
Modules/ModelL.py

AI Snake Game Simulator
Author: Nadim-Daniel Ghaznavi
Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
GitHub: https://github.com/NadimGhaznavi/ai
License: GPL 3.0
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class ModelL(nn.Module):
    def __init__(self, seed: int):
        super(ModelL, self).__init__()
        torch.manual_seed(seed)
        input_size = 27  # Size of the "state" as tracked by the GameBoard
        hidden_size = 170
        output_size = 3
        p_value = 0.1
        self.input_block = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
        )
        self.hidden_block = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )
        self.dropout_block = nn.Dropout(p=p_value)
        self.output_block = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.input_block(x)
        x = self.hidden_block(x)
        x = self.dropout_block(x)
        x = self.output_block(x)
        return x
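
Note on shapes: ModelL maps a 27-feature state vector to 3 action scores through two ReLU layers of width 170 plus dropout. A quick shape check with an equivalent stand-in Sequential (illustration only, not an import of the package class):

# Shape check for the linear model: 27 features in, 3 action scores out.
# The Sequential mirrors ModelL's declared layer sizes purely for illustration.
import torch
import torch.nn as nn

torch.manual_seed(1970)
net = nn.Sequential(
    nn.Linear(27, 170), nn.ReLU(),
    nn.Linear(170, 170), nn.ReLU(),
    nn.Dropout(p=0.1),
    nn.Linear(170, 3),
)
state = torch.rand(27)      # single un-batched state, as passed in AIAgent.get_move()
print(net(state).shape)     # torch.Size([3])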
ai_snake_lab/ai/models/ModelRNN.py
ADDED
@@ -0,0 +1,43 @@
"""
Modules/ModelRNN.py

AI Snake Game Simulator
Author: Nadim-Daniel Ghaznavi
Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
GitHub: https://github.com/NadimGhaznavi/ai
License: GPL 3.0
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class ModelRNN(nn.Module):
    def __init__(self, seed: int):
        super(ModelRNN, self).__init__()
        torch.manual_seed(seed)
        input_size = 27
        hidden_size = 200
        output_size = 3
        rnn_layers = 4
        rnn_dropout = 0.2
        self.m_in = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
        )
        self.m_rnn = nn.RNN(
            input_size=hidden_size,
            hidden_size=hidden_size,
            nonlinearity="tanh",
            num_layers=rnn_layers,
            dropout=rnn_dropout,
        )
        self.m_out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.m_in(x)
        inputs = x.view(1, -1, 200)
        x, h_n = self.m_rnn(inputs)
        x = self.m_out(x)
        return x[len(x) - 1]
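
Note on the RNN reshape: forward() expands the 200-wide features with view(1, -1, 200) so that nn.RNN (default layout, not batch_first) sees a (seq_len=1, batch, features) tensor, then keeps only the last sequence step of the output. A shape walk-through with stand-in layers, assuming a single un-batched state:

# Shape walk-through for the ModelRNN forward pass. Stand-in layers, illustration only.
import torch
import torch.nn as nn

torch.manual_seed(1970)
m_in = nn.Sequential(nn.Linear(27, 200), nn.ReLU())
m_rnn = nn.RNN(input_size=200, hidden_size=200, num_layers=4, nonlinearity="tanh", dropout=0.2)
m_out = nn.Linear(200, 3)

x = torch.rand(27)              # one un-batched state
x = m_in(x)                     # -> (200,)
inputs = x.view(1, -1, 200)     # -> (1, 1, 200): seq_len=1, batch=1, features=200
out, h_n = m_rnn(inputs)        # out: (1, 1, 200)
q = m_out(out)[len(out) - 1]    # keep the last sequence step -> (1, 3)
print(q.shape)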