ai-snake-lab 0.1.0__tar.gz → 0.4.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/LICENSE +2 -0
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/PKG-INFO +39 -5
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/README.md +36 -3
- ai_snake_lab-0.4.3/ai_snake_lab/AISim.py +431 -0
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/ai/AIAgent.py +34 -31
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/ai/AITrainer.py +9 -5
- ai_snake_lab-0.4.3/ai_snake_lab/ai/ReplayMemory.py +127 -0
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/ai/models/ModelL.py +7 -4
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/ai/models/ModelRNN.py +1 -1
- ai_snake_lab-0.4.3/ai_snake_lab/constants/DDb4EPlot.py +20 -0
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DDef.py +1 -1
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DDir.py +4 -1
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DFields.py +6 -0
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DFile.py +2 -1
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DLabels.py +24 -4
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DLayout.py +16 -2
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DModelL.py +4 -0
- ai_snake_lab-0.4.3/ai_snake_lab/constants/DSim.py +20 -0
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/game/GameBoard.py +36 -22
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/game/SnakeGame.py +17 -0
- ai_snake_lab-0.4.3/ai_snake_lab/ui/Db4EPlot.py +160 -0
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/utils/AISim.tcss +81 -38
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/pyproject.toml +1 -1
- ai_snake_lab-0.1.0/ai_snake_lab/AISim.py +0 -274
- ai_snake_lab-0.1.0/ai_snake_lab/ai/ReplayMemory.py +0 -90
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/ai/EpsilonAlgo.py +0 -0
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DEpsilon.py +0 -0
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DModelLRNN.py +0 -0
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DReplayMemory.py +0 -0
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/__init__.py +0 -0
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/game/GameElements.py +0 -0
- {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/utils/ConstGroup.py +0 -0
@@ -672,3 +672,5 @@ may consider it more useful to permit linking proprietary applications with
|
|
672
672
|
the library. If this is what you want to do, use the GNU Lesser General
|
673
673
|
Public License instead of this License. But first, please read
|
674
674
|
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
675
|
+
|
676
|
+
|
@@ -1,8 +1,9 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: ai-snake-lab
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.4.3
|
4
4
|
Summary: Interactive reinforcement learning sandbox for experimenting with AI agents in a classic Snake Game environment.
|
5
5
|
License: GPL-3.0
|
6
|
+
License-File: LICENSE
|
6
7
|
Keywords: AI,Reinforcement Learning,Textual,Snake,Simulation,SQLite,Python
|
7
8
|
Author: Nadim-Daniel Ghaznavi
|
8
9
|
Author-email: nghaznavi@gmail.com
|
@@ -63,8 +64,41 @@ Description-Content-Type: text/markdown
|
|
63
64
|
|
64
65
|
---
|
65
66
|
|
66
|
-
#
|
67
|
+
# Installation
|
68
|
+
|
69
|
+
This project is on [PyPI](https://pypi.org/project/ai-snake-lab/). You can install the *AI Snake Lab* software using `pip`.
|
70
|
+
|
71
|
+
## Create a Sandbox
|
72
|
+
|
73
|
+
```shell
|
74
|
+
python3 -m venv snake_venv
|
75
|
+
. snake_venv/bin/activate
|
76
|
+
```
|
77
|
+
|
78
|
+
## Install the AI Snake Lab
|
79
|
+
|
80
|
+
After you have activated your *venv* environment:
|
81
|
+
|
82
|
+
```shell
|
83
|
+
pip install ai-snake-lab
|
84
|
+
```
|
85
|
+
|
86
|
+
---
|
87
|
+
|
88
|
+
# Running the AI Snake Lab
|
89
|
+
|
90
|
+
From within your *venv* environment:
|
91
|
+
|
92
|
+
```shell
|
93
|
+
ai-snake-lab
|
94
|
+
```
|
95
|
+
|
96
|
+
---
|
97
|
+
|
98
|
+
# Links and Acknowledgements
|
99
|
+
|
100
|
+
This code is based on a YouTube tutorial, [Python + PyTorch + Pygame Reinforcement Learning – Train an AI to Play Snake](https://www.youtube.com/watch?v=L8ypSXwyBds&t=1042s&ab_channel=freeCodeCamp.org) by Patrick Loeber. You can access his original code [here](https://github.com/patrickloeber/snake-ai-pytorch) on GitHub. Thank you, Patrick!!! You are amazing!!!!
|
101
|
+
|
102
|
+
Thanks also go out to Will McGugan and the [Textual](https://textual.textualize.io/) team. Textual is an amazing framework. Talk about *Rapid Application Development*: porting this project to Textual took less than a day.
|
67
103
|
|
68
|
-
* [Project Layout](/pages/project_layout.html)
|
69
|
-
* [SQLite3 Schema](/pages/db_schema.html)
|
70
104
|
---
|
@@ -27,8 +27,41 @@
|
|
27
27
|
|
28
28
|
---
|
29
29
|
|
30
|
-
#
|
30
|
+
# Installation
|
31
|
+
|
32
|
+
This project is on [PyPI](https://pypi.org/project/ai-snake-lab/). You can install the *AI Snake Lab* software using `pip`.
|
33
|
+
|
34
|
+
## Create a Sandbox
|
35
|
+
|
36
|
+
```shell
|
37
|
+
python3 -m venv snake_venv
|
38
|
+
. snake_venv/bin/activate
|
39
|
+
```
|
40
|
+
|
41
|
+
## Install the AI Snake Lab
|
42
|
+
|
43
|
+
After you have activated your *venv* environment:
|
44
|
+
|
45
|
+
```shell
|
46
|
+
pip install ai-snake-lab
|
47
|
+
```
|
48
|
+
|
49
|
+
---
|
50
|
+
|
51
|
+
# Running the AI Snake Lab
|
52
|
+
|
53
|
+
From within your *venv* environment:
|
54
|
+
|
55
|
+
```shell
|
56
|
+
ai-snake-lab
|
57
|
+
```
|
58
|
+
|
59
|
+
---
|
60
|
+
|
61
|
+
# Links and Acknowledgements
|
62
|
+
|
63
|
+
This code is based on a YouTube tutorial, [Python + PyTorch + Pygame Reinforcement Learning – Train an AI to Play Snake](https://www.youtube.com/watch?v=L8ypSXwyBds&t=1042s&ab_channel=freeCodeCamp.org) by Patrick Loeber. You can access his original code [here](https://github.com/patrickloeber/snake-ai-pytorch) on GitHub. Thank you, Patrick!!! You are amazing!!!!
|
64
|
+
|
65
|
+
Thanks also go out to Will McGugan and the [Textual](https://textual.textualize.io/) team. Textual is an amazing framework. Talk about *Rapid Application Development*: porting this project to Textual took less than a day.
|
31
66
|
|
32
|
-
* [Project Layout](/pages/project_layout.html)
|
33
|
-
* [SQLite3 Schema](/pages/db_schema.html)
|
34
67
|
---
|
@@ -0,0 +1,431 @@
|
|
1
|
+
"""
|
2
|
+
AISim.py
|
3
|
+
|
4
|
+
AI Snake Game Simulator
|
5
|
+
Author: Nadim-Daniel Ghaznavi
|
6
|
+
Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
|
7
|
+
GitHub: https://github.com/NadimGhaznavi/ai
|
8
|
+
License: GPL 3.0
|
9
|
+
"""
|
10
|
+
|
11
|
+
import threading
|
12
|
+
import time
|
13
|
+
import sys, os
|
14
|
+
from datetime import datetime, timedelta
|
15
|
+
|
16
|
+
from textual.app import App, ComposeResult
|
17
|
+
from textual.widgets import Label, Input, Button, Static
|
18
|
+
from textual.containers import Vertical, Horizontal
|
19
|
+
from textual.reactive import var
|
20
|
+
from textual.theme import Theme
|
21
|
+
|
22
|
+
from ai_snake_lab.constants.DDef import DDef
|
23
|
+
from ai_snake_lab.constants.DEpsilon import DEpsilon
|
24
|
+
from ai_snake_lab.constants.DFields import DField
|
25
|
+
from ai_snake_lab.constants.DFile import DFile
|
26
|
+
from ai_snake_lab.constants.DLayout import DLayout
|
27
|
+
from ai_snake_lab.constants.DLabels import DLabel
|
28
|
+
from ai_snake_lab.constants.DReplayMemory import MEM_TYPE
|
29
|
+
from ai_snake_lab.constants.DDir import DDir
|
30
|
+
from ai_snake_lab.constants.DDb4EPlot import Plot
|
31
|
+
|
32
|
+
|
33
|
+
from ai_snake_lab.ai.AIAgent import AIAgent
|
34
|
+
from ai_snake_lab.ai.EpsilonAlgo import EpsilonAlgo
|
35
|
+
from ai_snake_lab.game.GameBoard import GameBoard
|
36
|
+
from ai_snake_lab.game.SnakeGame import SnakeGame
|
37
|
+
from ai_snake_lab.ui.Db4EPlot import Db4EPlot
|
38
|
+
|
39
|
+
# Seed passed to both EpsilonAlgo and AIAgent when the app is constructed.
RANDOM_SEED = 1970

# Textual CSS variables overridden by the custom theme below.
_THEME_VARIABLES = {
    "block-cursor-text-style": "none",
    "footer-key-foreground": "#88C0D0",
    "input-selection-background": "#81a1c1 35%",
}

# Dark colour theme for the simulator. AISim.on_mount registers it and then
# selects it by its name, "db4e".
snake_lab_theme = Theme(
    name="db4e",
    dark=True,
    # Base palette
    primary="#88C0D0",
    secondary="#1f6a83ff",
    accent="#B48EAD",
    foreground="#31b8e6",
    background="black",
    surface="black",
    panel="#000000",
    # Status colours
    success="#A3BE8C",
    warning="#EBCB8B",
    error="#BF616A",
    variables=_THEME_VARIABLES,
)
|
60
|
+
|
61
|
+
|
62
|
+
class AISim(App):
    """A Textual app that has an AI Agent playing the Snake Game.

    The training loop (``start_sim``) runs on a daemon background thread that
    the Start button launches; the Textual event loop handles the widgets and
    buttons. Setting ``stop_event`` ends the loop; while ``pause_event`` is
    set the loop idles.
    """

    TITLE = DDef.APP_TITLE
    CSS_PATH = os.path.join(DDir.UTILS, DFile.CSS_FILE)

    ## Runtime values
    # Current epsilon value (degrades in real-time)
    cur_epsilon_widget = Label("N/A", id=DLayout.CUR_EPSILON)
    # Current memory type
    cur_mem_type_widget = Label("N/A", id=DLayout.CUR_MEM_TYPE)
    # Current model type
    cur_model_type_widget = Label("N/A", id=DLayout.CUR_MODEL_TYPE)
    # Time delay between moves (seconds, passed to time.sleep)
    cur_move_delay = DDef.MOVE_DELAY
    # Number of stored games in the ReplayMemory
    cur_num_games_widget = Label("N/A", id=DLayout.NUM_GAMES)
    # Elapsed time
    cur_runtime_widget = Label("N/A", id=DLayout.RUNTIME)

    # Initial settings for epsilon.
    # `restrict` is a regex that the field text must match. The "." is escaped
    # so only a literal decimal point is accepted. An unescaped "." matches any
    # character and let text such as "1a2" through to float().
    initial_epsilon_input = Input(
        restrict=r"0\.[0-9]*",
        compact=True,
        id=DLayout.EPSILON_INITIAL,
        classes=DLayout.INPUT_10,
    )
    epsilon_min_input = Input(
        restrict=r"0\.[0-9]*",
        compact=True,
        id=DLayout.EPSILON_MIN,
        classes=DLayout.INPUT_10,
    )
    epsilon_decay_input = Input(
        restrict=r"0\.[0-9]*",
        compact=True,
        id=DLayout.EPSILON_DECAY,
        classes=DLayout.INPUT_10,
    )
    # Digits with at most one decimal point, e.g. "1", "0.25", ".5".
    move_delay_input = Input(
        restrict=r"[0-9]*\.?[0-9]*",
        compact=True,
        id=DLayout.MOVE_DELAY,
        classes=DLayout.INPUT_10,
    )

    # Buttons
    pause_button = Button(label=DLabel.PAUSE, id=DLayout.BUTTON_PAUSE, compact=True)
    restart_button = Button(
        label=DLabel.RESTART, id=DLayout.BUTTON_RESTART, compact=True
    )
    start_button = Button(label=DLabel.START, id=DLayout.BUTTON_START, compact=True)
    quit_button = Button(label=DLabel.QUIT, id=DLayout.BUTTON_QUIT, compact=True)
    defaults_button = Button(
        label=DLabel.DEFAULTS, id=DLayout.BUTTON_DEFAULTS, compact=True
    )
    update_button = Button(label=DLabel.UPDATE, id=DLayout.BUTTON_UPDATE, compact=True)

    # A dictionary to hold runtime statistics (game numbers and their scores).
    # NOTE(review): this is a class attribute, so it is shared by every AISim
    # instance.
    stats = {
        DField.GAME_SCORE: {
            DField.GAME_NUM: [],
            DField.GAME_SCORE: [],
        }
    }

    game_score_plot = Db4EPlot(
        title=DLabel.GAME_SCORE, id=DLayout.GAME_SCORE_PLOT, thin_method=Plot.SLIDING
    )

    def __init__(self) -> None:
        super().__init__()
        self.game_board = GameBoard(20, id=DLayout.GAME_BOARD)
        self.snake_game = SnakeGame(game_board=self.game_board, id=DLayout.GAME_BOARD)
        self.epsilon_algo = EpsilonAlgo(seed=RANDOM_SEED)
        self.agent = AIAgent(self.epsilon_algo, seed=RANDOM_SEED)
        self.cur_state = DField.STOPPED
        self.game_score_plot._x_label = DLabel.GAME_NUM
        self.game_score_plot._y_label = DLabel.GAME_SCORE

        # Setup the simulator in a background thread (started by Start)
        self.stop_event = threading.Event()
        self.pause_event = threading.Event()
        self.running = DField.STOPPED
        self.simulator_thread = threading.Thread(target=self.start_sim, daemon=True)

    async def action_quit(self) -> None:
        """Quit the application."""
        self.stop_event.set()
        if self.simulator_thread.is_alive():
            self.simulator_thread.join(timeout=2)
        await super().action_quit()

    def compose(self) -> ComposeResult:
        """Create child widgets for the app."""

        # Title bar
        yield Label(DDef.APP_TITLE, id=DLayout.TITLE)

        # Configuration Settings
        yield Vertical(
            Horizontal(
                Label(
                    f"{DLabel.EPSILON_INITIAL}",
                    classes=DLayout.LABEL_SETTINGS,
                ),
                self.initial_epsilon_input,
            ),
            Horizontal(
                Label(
                    f"{DLabel.EPSILON_DECAY}",
                    classes=DLayout.LABEL_SETTINGS,
                ),
                self.epsilon_decay_input,
            ),
            Horizontal(
                Label(f"{DLabel.EPSILON_MIN}", classes=DLayout.LABEL_SETTINGS),
                self.epsilon_min_input,
            ),
            Horizontal(
                Label(
                    f"{DLabel.MOVE_DELAY}",
                    classes=DLayout.LABEL_SETTINGS,
                ),
                self.move_delay_input,
            ),
            id=DLayout.SETTINGS_BOX,
        )

        # The Snake Game
        yield Vertical(
            self.game_board,
            id=DLayout.GAME_BOX,
        )

        # Runtime values
        yield Vertical(
            Horizontal(
                Label(f"{DLabel.EPSILON}", classes=DLayout.LABEL),
                self.cur_epsilon_widget,
            ),
            Horizontal(
                Label(f"{DLabel.MEM_TYPE}", classes=DLayout.LABEL),
                self.cur_mem_type_widget,
            ),
            Horizontal(
                Label(f"{DLabel.STORED_GAMES}", classes=DLayout.LABEL),
                self.cur_num_games_widget,
            ),
            Horizontal(
                Label(f"{DLabel.MODEL_TYPE}", classes=DLayout.LABEL),
                self.cur_model_type_widget,
            ),
            Horizontal(
                Label(f"{DLabel.RUNTIME}", classes=DLayout.LABEL),
                self.cur_runtime_widget,
            ),
            id=DLayout.RUNTIME_BOX,
        )

        # Buttons
        yield Vertical(
            Horizontal(
                self.start_button,
                self.pause_button,
                self.quit_button,
                classes=DLayout.BUTTON_ROW,
            ),
            Horizontal(
                self.defaults_button,
                self.update_button,
                self.restart_button,
                classes=DLayout.BUTTON_ROW,
            ),
        )

        # Empty fillers
        yield Static(id=DLayout.FILLER_1)
        yield Static(id=DLayout.FILLER_2)
        yield Static(id=DLayout.FILLER_3)

        # The game score plot
        yield self.game_score_plot

    def on_mount(self):
        """Populate default settings, border titles and the theme."""
        self.initial_epsilon_input.value = str(DEpsilon.EPSILON_INITIAL)
        self.epsilon_decay_input.value = str(DEpsilon.EPSILON_DECAY)
        self.epsilon_min_input.value = str(DEpsilon.EPSILON_MIN)
        self.move_delay_input.value = str(DDef.MOVE_DELAY)
        settings_box = self.query_one(f"#{DLayout.SETTINGS_BOX}", Vertical)
        settings_box.border_title = DLabel.SETTINGS
        runtime_box = self.query_one(f"#{DLayout.RUNTIME_BOX}", Vertical)
        runtime_box.border_title = DLabel.RUNTIME_VALUES
        self.cur_mem_type_widget.update(
            MEM_TYPE.MEM_TYPE_TABLE[self.agent.memory.mem_type()]
        )
        self.cur_num_games_widget.update(str(self.agent.memory.get_num_games()))
        # Initial state is that the app is stopped
        self.add_class(DField.STOPPED)
        # Register the theme
        self.register_theme(snake_lab_theme)

        # Set the app's theme
        self.theme = "db4e"

    def _read_move_delay(self) -> float:
        """Parse the move-delay input, keeping the current delay if it is invalid.

        The field can legitimately hold "" or "." while the user is editing,
        and float() raises ValueError for both.
        """
        try:
            return float(self.move_delay_input.value)
        except ValueError:
            return self.cur_move_delay

    def on_quit(self):
        """Stop the simulator thread (running or paused) and exit the app."""
        # A paused thread is still alive, so signal it regardless of state.
        self.stop_event.set()
        if self.simulator_thread.is_alive():
            # Bounded wait: the loop only checks stop_event between moves, and
            # the thread is a daemon, so it cannot keep the process alive.
            self.simulator_thread.join(timeout=2)
        # Let Textual shut down and restore the terminal instead of raising
        # SystemExit from inside an event handler.
        self.exit()

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Dispatch the control-panel buttons."""
        button_id = event.button.id

        # Pause button was pressed
        if button_id == DLayout.BUTTON_PAUSE:
            if self.running == DField.STOPPED:
                # Nothing to pause yet. Entering PAUSED here would make the
                # next Start "resume" a thread that was never started.
                return
            self.pause_event.set()
            self.running = DField.PAUSED
            self.remove_class(DField.RUNNING)
            self.add_class(DField.PAUSED)
            self.cur_move_delay = self._read_move_delay()
            self.cur_model_type_widget.update(self.agent.model_type())

        # Restart button was pressed
        elif button_id == DLayout.BUTTON_RESTART:
            self.running = DField.STOPPED
            self.add_class(DField.STOPPED)
            self.remove_class(DField.RUNNING)
            self.remove_class(DField.PAUSED)

            # Signal the thread to stop. start_sim keeps its own references to
            # these Event objects, so replacing the attributes below cannot
            # strand a thread that outlives the join timeout.
            self.stop_event.set()
            # Unpause so the loop reaches its stop check
            self.pause_event.clear()
            # Join the old thread
            if self.simulator_thread.is_alive():
                self.simulator_thread.join(timeout=2)

            # Reset the game and the UI
            self.snake_game.reset()
            self.epoch = 1
            game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
            game_box.border_title = ""
            game_box.border_subtitle = ""

            # Recreate events and get a new (not yet started) thread
            self.stop_event = threading.Event()
            self.pause_event = threading.Event()
            self.simulator_thread = threading.Thread(target=self.start_sim, daemon=True)

        # Start button was pressed
        elif button_id == DLayout.BUTTON_START:
            if self.running == DField.STOPPED:
                self.start_thread()
            elif self.running == DField.PAUSED:
                self.pause_event.clear()
            self.running = DField.RUNNING
            self.add_class(DField.RUNNING)
            self.remove_class(DField.STOPPED)
            self.remove_class(DField.PAUSED)
            self.cur_move_delay = self._read_move_delay()
            self.cur_model_type_widget.update(self.agent.model_type())

        # Defaults button was pressed
        elif button_id == DLayout.BUTTON_DEFAULTS:
            self.initial_epsilon_input.value = str(DEpsilon.EPSILON_INITIAL)
            self.epsilon_decay_input.value = str(DEpsilon.EPSILON_DECAY)
            self.epsilon_min_input.value = str(DEpsilon.EPSILON_MIN)
            self.move_delay_input.value = str(DDef.MOVE_DELAY)

        # Quit button was pressed
        elif button_id == DLayout.BUTTON_QUIT:
            self.on_quit()

        # Update button was pressed
        elif button_id == DLayout.BUTTON_UPDATE:
            self.cur_move_delay = self._read_move_delay()

    def start_sim(self):
        """Run the agent/game training loop until this run's stop event is set.

        Runs on the background simulator thread.
        NOTE(review): widgets are updated directly from this thread; Textual
        documents App.call_from_thread for cross-thread UI updates.
        """
        # Bind this run's events once. The Restart handler replaces
        # self.stop_event / self.pause_event, and re-reading the attributes
        # would let a thread that outlived its join miss its stop signal.
        stop_event = self.stop_event
        pause_event = self.pause_event

        self.snake_game.reset()
        game_board = self.game_board
        agent = self.agent
        snake_game = self.snake_game
        highscore = 0
        self.epoch = 1
        game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
        game_box.border_title = f"{DLabel.GAME} #{self.epoch}"
        # Monotonic clock so the runtime display is immune to wall-clock jumps
        start_time = time.monotonic()

        while not stop_event.is_set():
            if pause_event.is_set():
                # Paused: poll until resumed or stopped
                time.sleep(0.2)
                continue
            # The actual training loop...
            old_state = game_board.get_state()
            move = agent.get_move(old_state)
            reward, game_over, score = snake_game.play_step(move)
            if score > highscore:
                highscore = score
            game_box.border_subtitle = (
                f"{DLabel.HIGHSCORE}: {highscore}, {DLabel.SCORE}: {score}"
            )
            if not game_over:
                ## Keep playing
                time.sleep(self.cur_move_delay)
                new_state = game_board.get_state()
                agent.train_short_memory(old_state, move, reward, new_state, game_over)
                agent.remember(old_state, move, reward, new_state, game_over)
            else:
                ## Game over
                # Read the final state before the board is reset. Previously
                # new_state was only set in the branch above, so a game lost on
                # its first move raised UnboundLocalError and later games stored
                # a stale next-state.
                new_state = game_board.get_state()
                self.epoch += 1
                game_box.border_title = f"{DLabel.GAME} #{self.epoch}"
                # Remember the last move
                agent.remember(old_state, move, reward, new_state, game_over)
                # Train long memory
                agent.train_long_memory()
                # Reset the game
                snake_game.reset()
                # Let the agent know we've finished a game
                agent.played_game(score)
                # Get the current epsilon value
                cur_epsilon = self.epsilon_algo.epsilon()
                if cur_epsilon < 0.0001:
                    self.cur_epsilon_widget.update("0.0000")
                else:
                    self.cur_epsilon_widget.update(str(round(cur_epsilon, 4)))
                # Update the number of stored memories
                self.cur_num_games_widget.update(str(self.agent.memory.get_num_games()))
                # Update the stats object
                self.stats[DField.GAME_SCORE][DField.GAME_NUM].append(self.epoch)
                self.stats[DField.GAME_SCORE][DField.GAME_SCORE].append(score)
                # Update the plot object
                self.game_score_plot.add_data(self.epoch, score)
                self.game_score_plot.db4e_plot()
                # Update the runtime widget
                elapsed_secs = time.monotonic() - start_time
                runtime = minutes_to_uptime(elapsed_secs)
                self.cur_runtime_widget.update(runtime)

    def start_thread(self):
        """Start the (not yet started) simulator thread."""
        self.simulator_thread.start()
|
409
|
+
|
410
|
+
|
411
|
+
def minutes_to_uptime(seconds: float) -> str:
    """Format an elapsed duration as a compact uptime string.

    Despite its name, the function takes a number of *seconds* (int or float;
    fractions are truncated). Only the most significant units are shown, and
    negative durations are clamped to zero:

        45      -> "45s"
        125     -> "2m 5s"
        10920   -> "3h 2m"
        113520  -> "1d 7h 32m"
    """
    total = max(int(seconds), 0)
    days, remainder = divmod(total, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, secs = divmod(remainder, 60)

    if days > 0:
        return f"{days}d {hours}h {minutes}m"
    if hours > 0:
        return f"{hours}h {minutes}m"
    if minutes > 0:
        return f"{minutes}m {secs}s"
    return f"{secs}s"
|
427
|
+
|
428
|
+
|
429
|
+
if __name__ == "__main__":
    # Launch the Textual simulator UI when run as a script.
    AISim().run()
|
@@ -9,13 +9,14 @@ ai/Agent.py
|
|
9
9
|
"""
|
10
10
|
|
11
11
|
import torch
|
12
|
-
from ai.EpsilonAlgo import EpsilonAlgo
|
13
|
-
from ai.ReplayMemory import ReplayMemory
|
14
|
-
from ai.AITrainer import AITrainer
|
15
|
-
from ai.models.ModelL import ModelL
|
16
|
-
from ai.models.ModelRNN import ModelRNN
|
12
|
+
from ai_snake_lab.ai.EpsilonAlgo import EpsilonAlgo
|
13
|
+
from ai_snake_lab.ai.ReplayMemory import ReplayMemory
|
14
|
+
from ai_snake_lab.ai.AITrainer import AITrainer
|
15
|
+
from ai_snake_lab.ai.models.ModelL import ModelL
|
16
|
+
from ai_snake_lab.ai.models.ModelRNN import ModelRNN
|
17
17
|
|
18
|
-
from constants.DReplayMemory import MEM_TYPE
|
18
|
+
from ai_snake_lab.constants.DReplayMemory import MEM_TYPE
|
19
|
+
from ai_snake_lab.constants.DLabels import DLabel
|
19
20
|
|
20
21
|
|
21
22
|
class AIAgent:
|
@@ -23,16 +24,13 @@ class AIAgent:
|
|
23
24
|
def __init__(self, epsilon_algo: EpsilonAlgo, seed: int):
|
24
25
|
self.epsilon_algo = epsilon_algo
|
25
26
|
self.memory = ReplayMemory(seed=seed)
|
26
|
-
self.
|
27
|
-
|
28
|
-
self.trainer = AITrainer(self.
|
27
|
+
# self._model = ModelL(seed=seed)
|
28
|
+
self._model = ModelRNN(seed=seed)
|
29
|
+
self.trainer = AITrainer(model=self._model)
|
29
30
|
|
30
|
-
if type(self.
|
31
|
+
if type(self._model) == ModelRNN:
|
31
32
|
self.memory.mem_type(MEM_TYPE.RANDOM_GAME)
|
32
33
|
|
33
|
-
def get_model(self):
|
34
|
-
return self.model
|
35
|
-
|
36
34
|
def get_move(self, state):
|
37
35
|
random_move = self.epsilon_algo.get_move() # Explore with epsilon
|
38
36
|
if random_move != False:
|
@@ -42,7 +40,7 @@ class AIAgent:
|
|
42
40
|
final_move = [0, 0, 0]
|
43
41
|
if type(state) != torch.Tensor:
|
44
42
|
state = torch.tensor(state, dtype=torch.float) # Convert to a tensor
|
45
|
-
prediction = self.
|
43
|
+
prediction = self._model(state) # Get the prediction
|
46
44
|
move = torch.argmax(prediction).item() # Select the move with the highest value
|
47
45
|
final_move[move] = 1 # Set the move
|
48
46
|
return final_move # Return
|
@@ -50,6 +48,15 @@ class AIAgent:
|
|
50
48
|
def get_optimizer(self):
|
51
49
|
return self.trainer.get_optimizer()
|
52
50
|
|
51
|
+
def model_type(self):
    """Return the display label for the class of the agent's model.

    Returns None when the model is neither a ModelL nor a ModelRNN.
    """
    labels = {
        ModelL: DLabel.MODEL_LINEAR,
        ModelRNN: DLabel.MODEL_RNN,
    }
    return labels.get(type(self._model))
|
56
|
+
|
57
|
+
def model(self):
    """Return the model instance that get_move uses for predictions."""
    current = self._model
    return current
|
59
|
+
|
53
60
|
def played_game(self, score):
|
54
61
|
self.epsilon_algo.played_game()
|
55
62
|
|
@@ -57,27 +64,23 @@ class AIAgent:
|
|
57
64
|
# Store the state, action, reward, next_state, and done in memory
|
58
65
|
self.memory.append((state, action, reward, next_state, done))
|
59
66
|
|
60
|
-
def set_model(self, model):
|
61
|
-
self.model = model
|
62
|
-
|
63
67
|
def set_optimizer(self, optimizer):
|
64
68
|
self.trainer.set_optimizer(optimizer)
|
65
69
|
|
66
70
|
def train_long_memory(self):
|
67
|
-
#
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
for state, action, reward, next_state, done in
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
for state, action, reward, next_state, done in memory[0]:
|
71
|
+
# Train on 5 games
|
72
|
+
max_games = 2
|
73
|
+
# Get a random full game
|
74
|
+
while max_games > 0:
|
75
|
+
max_games -= 1
|
76
|
+
game = self.memory.get_random_game()
|
77
|
+
if not game:
|
78
|
+
return # no games to train on yet
|
79
|
+
|
80
|
+
for count, (state, action, reward, next_state, done) in enumerate(
|
81
|
+
game, start=1
|
82
|
+
):
|
83
|
+
# print(f"Move #{count}: {action}")
|
81
84
|
self.trainer.train_step(state, action, reward, next_state, [done])
|
82
85
|
|
83
86
|
def train_short_memory(self, state, action, reward, next_state, done):
|
@@ -15,23 +15,27 @@ import numpy as np
|
|
15
15
|
import time
|
16
16
|
import sys
|
17
17
|
|
18
|
-
from ai.models.ModelL import ModelL
|
19
|
-
from ai.models.ModelRNN import ModelRNN
|
18
|
+
from ai_snake_lab.ai.models.ModelL import ModelL
|
19
|
+
from ai_snake_lab.ai.models.ModelRNN import ModelRNN
|
20
20
|
|
21
|
-
from constants.DModelL import DModelL
|
22
|
-
from constants.DModelLRNN import DModelRNN
|
21
|
+
from ai_snake_lab.constants.DModelL import DModelL
|
22
|
+
from ai_snake_lab.constants.DModelLRNN import DModelRNN
|
23
23
|
|
24
24
|
|
25
25
|
class AITrainer:
|
26
26
|
|
27
27
|
def __init__(self, model):
|
28
28
|
torch.manual_seed(1970)
|
29
|
-
|
29
|
+
|
30
30
|
# The learning rate needs to be adjusted for the model type
|
31
31
|
if type(model) == ModelL:
|
32
32
|
learning_rate = DModelL.LEARNING_RATE
|
33
33
|
elif type(model) == ModelRNN:
|
34
34
|
learning_rate = DModelRNN.LEARNING_RATE
|
35
|
+
else:
|
36
|
+
raise ValueError(f"Unknown model type: {type(model)}")
|
37
|
+
|
38
|
+
self.model = model
|
35
39
|
self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
|
36
40
|
self.criterion = nn.MSELoss()
|
37
41
|
self.gamma = 0.9
|