ai-snake-lab 0.1.0__tar.gz → 0.4.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/LICENSE +2 -0
  2. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/PKG-INFO +39 -5
  3. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/README.md +36 -3
  4. ai_snake_lab-0.4.3/ai_snake_lab/AISim.py +431 -0
  5. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/ai/AIAgent.py +34 -31
  6. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/ai/AITrainer.py +9 -5
  7. ai_snake_lab-0.4.3/ai_snake_lab/ai/ReplayMemory.py +127 -0
  8. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/ai/models/ModelL.py +7 -4
  9. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/ai/models/ModelRNN.py +1 -1
  10. ai_snake_lab-0.4.3/ai_snake_lab/constants/DDb4EPlot.py +20 -0
  11. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DDef.py +1 -1
  12. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DDir.py +4 -1
  13. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DFields.py +6 -0
  14. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DFile.py +2 -1
  15. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DLabels.py +24 -4
  16. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DLayout.py +16 -2
  17. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DModelL.py +4 -0
  18. ai_snake_lab-0.4.3/ai_snake_lab/constants/DSim.py +20 -0
  19. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/game/GameBoard.py +36 -22
  20. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/game/SnakeGame.py +17 -0
  21. ai_snake_lab-0.4.3/ai_snake_lab/ui/Db4EPlot.py +160 -0
  22. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/utils/AISim.tcss +81 -38
  23. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/pyproject.toml +1 -1
  24. ai_snake_lab-0.1.0/ai_snake_lab/AISim.py +0 -274
  25. ai_snake_lab-0.1.0/ai_snake_lab/ai/ReplayMemory.py +0 -90
  26. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/ai/EpsilonAlgo.py +0 -0
  27. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DEpsilon.py +0 -0
  28. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DModelLRNN.py +0 -0
  29. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/DReplayMemory.py +0 -0
  30. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/constants/__init__.py +0 -0
  31. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/game/GameElements.py +0 -0
  32. {ai_snake_lab-0.1.0 → ai_snake_lab-0.4.3}/ai_snake_lab/utils/ConstGroup.py +0 -0
@@ -672,3 +672,5 @@ may consider it more useful to permit linking proprietary applications with
672
672
  the library. If this is what you want to do, use the GNU Lesser General
673
673
  Public License instead of this License. But first, please read
674
674
  <https://www.gnu.org/licenses/why-not-lgpl.html>.
675
+
676
+
@@ -1,8 +1,9 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: ai-snake-lab
3
- Version: 0.1.0
3
+ Version: 0.4.3
4
4
  Summary: Interactive reinforcement learning sandbox for experimenting with AI agents in a classic Snake Game environment.
5
5
  License: GPL-3.0
6
+ License-File: LICENSE
6
7
  Keywords: AI,Reinforcement Learning,Textual,Snake,Simulation,SQLite,Python
7
8
  Author: Nadim-Daniel Ghaznavi
8
9
  Author-email: nghaznavi@gmail.com
@@ -63,8 +64,41 @@ Description-Content-Type: text/markdown
63
64
 
64
65
  ---
65
66
 
66
- # Links
67
+ # Installation
68
+
69
+ This project is on [PyPI](https://pypi.org/project/ai-snake-lab/). You can install the *AI Snake Lab* software using `pip`.
70
+
71
+ ## Create a Sandbox
72
+
73
+ ```shell
74
+ python3 -m venv snake_venv
75
+ . snake_venv/bin/activate
76
+ ```
77
+
78
+ ## Install the AI Snake Lab
79
+
80
+ After you have activated your *venv* environment:
81
+
82
+ ```shell
83
+ pip install ai-snake-lab
84
+ ```
85
+
86
+ ---
87
+
88
+ # Running the AI Snake Lab
89
+
90
+ From within your *venv* environment:
91
+
92
+ ```shell
93
+ ai-snake-lab
94
+ ```
95
+
96
+ ---
97
+
98
+ # Links and Acknowledgements
99
+
100
+ This code is based on a YouTube tutorial, [Python + PyTorch + Pygame Reinforcement Learning – Train an AI to Play Snake](https://www.youtube.com/watch?v=L8ypSXwyBds&t=1042s&ab_channel=freeCodeCamp.org) by Patrick Loeber. You can access his original code [here](https://github.com/patrickloeber/snake-ai-pytorch) on GitHub. Thank you Patrick!!! You are amazing!!!!
101
+
102
+ Thanks also go out to Will McGugan and the [Textual](https://textual.textualize.io/) team. Textual is an amazing framework. Talk about *Rapid Application Development*. Porting this took less than a day.
67
103
 
68
- * [Project Layout](/pages/project_layout.html)
69
- * [SQLite3 Schema](/pages/db_schema.html)
70
104
  ---
@@ -27,8 +27,41 @@
27
27
 
28
28
  ---
29
29
 
30
- # Links
30
+ # Installation
31
+
32
+ This project is on [PyPI](https://pypi.org/project/ai-snake-lab/). You can install the *AI Snake Lab* software using `pip`.
33
+
34
+ ## Create a Sandbox
35
+
36
+ ```shell
37
+ python3 -m venv snake_venv
38
+ . snake_venv/bin/activate
39
+ ```
40
+
41
+ ## Install the AI Snake Lab
42
+
43
+ After you have activated your *venv* environment:
44
+
45
+ ```shell
46
+ pip install ai-snake-lab
47
+ ```
48
+
49
+ ---
50
+
51
+ # Running the AI Snake Lab
52
+
53
+ From within your *venv* environment:
54
+
55
+ ```shell
56
+ ai-snake-lab
57
+ ```
58
+
59
+ ---
60
+
61
+ # Links and Acknowledgements
62
+
63
+ This code is based on a YouTube tutorial, [Python + PyTorch + Pygame Reinforcement Learning – Train an AI to Play Snake](https://www.youtube.com/watch?v=L8ypSXwyBds&t=1042s&ab_channel=freeCodeCamp.org) by Patrick Loeber. You can access his original code [here](https://github.com/patrickloeber/snake-ai-pytorch) on GitHub. Thank you Patrick!!! You are amazing!!!!
64
+
65
+ Thanks also go out to Will McGugan and the [Textual](https://textual.textualize.io/) team. Textual is an amazing framework. Talk about *Rapid Application Development*. Porting this took less than a day.
31
66
 
32
- * [Project Layout](/pages/project_layout.html)
33
- * [SQLite3 Schema](/pages/db_schema.html)
34
67
  ---
@@ -0,0 +1,431 @@
1
+ """
2
+ AISim.py
3
+
4
+ AI Snake Game Simulator
5
+ Author: Nadim-Daniel Ghaznavi
6
+ Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
7
+ GitHub: https://github.com/NadimGhaznavi/ai
8
+ License: GPL 3.0
9
+ """
10
+
11
+ import threading
12
+ import time
13
+ import sys, os
14
+ from datetime import datetime, timedelta
15
+
16
+ from textual.app import App, ComposeResult
17
+ from textual.widgets import Label, Input, Button, Static
18
+ from textual.containers import Vertical, Horizontal
19
+ from textual.reactive import var
20
+ from textual.theme import Theme
21
+
22
+ from ai_snake_lab.constants.DDef import DDef
23
+ from ai_snake_lab.constants.DEpsilon import DEpsilon
24
+ from ai_snake_lab.constants.DFields import DField
25
+ from ai_snake_lab.constants.DFile import DFile
26
+ from ai_snake_lab.constants.DLayout import DLayout
27
+ from ai_snake_lab.constants.DLabels import DLabel
28
+ from ai_snake_lab.constants.DReplayMemory import MEM_TYPE
29
+ from ai_snake_lab.constants.DDir import DDir
30
+ from ai_snake_lab.constants.DDb4EPlot import Plot
31
+
32
+
33
+ from ai_snake_lab.ai.AIAgent import AIAgent
34
+ from ai_snake_lab.ai.EpsilonAlgo import EpsilonAlgo
35
+ from ai_snake_lab.game.GameBoard import GameBoard
36
+ from ai_snake_lab.game.SnakeGame import SnakeGame
37
+ from ai_snake_lab.ui.Db4EPlot import Db4EPlot
38
+
39
+ RANDOM_SEED = 1970
40
+
41
+ snake_lab_theme = Theme(
42
+ name="db4e",
43
+ primary="#88C0D0",
44
+ secondary="#1f6a83ff",
45
+ accent="#B48EAD",
46
+ foreground="#31b8e6",
47
+ background="black",
48
+ success="#A3BE8C",
49
+ warning="#EBCB8B",
50
+ error="#BF616A",
51
+ surface="black",
52
+ panel="#000000",
53
+ dark=True,
54
+ variables={
55
+ "block-cursor-text-style": "none",
56
+ "footer-key-foreground": "#88C0D0",
57
+ "input-selection-background": "#81a1c1 35%",
58
+ },
59
+ )
60
+
61
+
62
+ class AISim(App):
63
+ """A Textual app that has an AI Agent playing the Snake Game."""
64
+
65
+ TITLE = DDef.APP_TITLE
66
+ CSS_PATH = os.path.join(DDir.UTILS, DFile.CSS_FILE)
67
+
68
+ ## Runtime values
69
+ # Current epsilon value (degrades in real-time)
70
+ cur_epsilon_widget = Label("N/A", id=DLayout.CUR_EPSILON)
71
+ # Current memory type
72
+ cur_mem_type_widget = Label("N/A", id=DLayout.CUR_MEM_TYPE)
73
+ # Current model type
74
+ cur_model_type_widget = Label("N/A", id=DLayout.CUR_MODEL_TYPE)
75
+ # Time delay between moves
76
+ cur_move_delay = DDef.MOVE_DELAY
77
+ # Number of stored games in the ReplayMemory
78
+ cur_num_games_widget = Label("N/A", id=DLayout.NUM_GAMES)
79
+ # Elapsed time
80
+ cur_runtime_widget = Label("N/A", id=DLayout.RUNTIME)
81
+
82
+ # Initial Settings for Epsilon
83
+ initial_epsilon_input = Input(
84
+ restrict=f"0.[0-9]*",
85
+ compact=True,
86
+ id=DLayout.EPSILON_INITIAL,
87
+ classes=DLayout.INPUT_10,
88
+ )
89
+ epsilon_min_input = Input(
90
+ restrict=f"0.[0-9]*",
91
+ compact=True,
92
+ id=DLayout.EPSILON_MIN,
93
+ classes=DLayout.INPUT_10,
94
+ )
95
+ epsilon_decay_input = Input(
96
+ restrict=f"0.[0-9]*",
97
+ compact=True,
98
+ id=DLayout.EPSILON_DECAY,
99
+ classes=DLayout.INPUT_10,
100
+ )
101
+ move_delay_input = Input(
102
+ restrict=f"[0-9]*.[0-9]*",
103
+ compact=True,
104
+ id=DLayout.MOVE_DELAY,
105
+ classes=DLayout.INPUT_10,
106
+ )
107
+
108
+ # Buttons
109
+ pause_button = Button(label=DLabel.PAUSE, id=DLayout.BUTTON_PAUSE, compact=True)
110
+ restart_button = Button(
111
+ label=DLabel.RESTART, id=DLayout.BUTTON_RESTART, compact=True
112
+ )
113
+ start_button = Button(label=DLabel.START, id=DLayout.BUTTON_START, compact=True)
114
+ quit_button = Button(label=DLabel.QUIT, id=DLayout.BUTTON_QUIT, compact=True)
115
+ defaults_button = Button(
116
+ label=DLabel.DEFAULTS, id=DLayout.BUTTON_DEFAULTS, compact=True
117
+ )
118
+ update_button = Button(label=DLabel.UPDATE, id=DLayout.BUTTON_UPDATE, compact=True)
119
+
120
+ # A dictionary to hold runtime statistics
121
+ stats = {
122
+ DField.GAME_SCORE: {
123
+ DField.GAME_NUM: [],
124
+ DField.GAME_SCORE: [],
125
+ }
126
+ }
127
+
128
+ game_score_plot = Db4EPlot(
129
+ title=DLabel.GAME_SCORE, id=DLayout.GAME_SCORE_PLOT, thin_method=Plot.SLIDING
130
+ )
131
+
132
+ def __init__(self) -> None:
133
+ super().__init__()
134
+ self.game_board = GameBoard(20, id=DLayout.GAME_BOARD)
135
+ self.snake_game = SnakeGame(game_board=self.game_board, id=DLayout.GAME_BOARD)
136
+ self.epsilon_algo = EpsilonAlgo(seed=RANDOM_SEED)
137
+ self.agent = AIAgent(self.epsilon_algo, seed=RANDOM_SEED)
138
+ self.cur_state = DField.STOPPED
139
+ self.game_score_plot._x_label = DLabel.GAME_NUM
140
+ self.game_score_plot._y_label = DLabel.GAME_SCORE
141
+
142
+ # Setup the simulator in a background thread
143
+ self.stop_event = threading.Event()
144
+ self.pause_event = threading.Event()
145
+ self.running = DField.STOPPED
146
+ self.simulator_thread = threading.Thread(target=self.start_sim, daemon=True)
147
+
148
+ async def action_quit(self) -> None:
149
+ """Quit the application."""
150
+ self.stop_event.set()
151
+ if self.simulator_thread.is_alive():
152
+ self.simulator_thread.join(timeout=2)
153
+ await super().action_quit()
154
+
155
+ def compose(self) -> ComposeResult:
156
+ """Create child widgets for the app."""
157
+
158
+ # Title bar
159
+ yield Label(DDef.APP_TITLE, id=DLayout.TITLE)
160
+
161
+ # Configuration Settings
162
+ yield Vertical(
163
+ Horizontal(
164
+ Label(
165
+ f"{DLabel.EPSILON_INITIAL}",
166
+ classes=DLayout.LABEL_SETTINGS,
167
+ ),
168
+ self.initial_epsilon_input,
169
+ ),
170
+ Horizontal(
171
+ Label(
172
+ f"{DLabel.EPSILON_DECAY}",
173
+ classes=DLayout.LABEL_SETTINGS,
174
+ ),
175
+ self.epsilon_decay_input,
176
+ ),
177
+ Horizontal(
178
+ Label(f"{DLabel.EPSILON_MIN}", classes=DLayout.LABEL_SETTINGS),
179
+ self.epsilon_min_input,
180
+ ),
181
+ Horizontal(
182
+ Label(
183
+ f"{DLabel.MOVE_DELAY}",
184
+ classes=DLayout.LABEL_SETTINGS,
185
+ ),
186
+ self.move_delay_input,
187
+ ),
188
+ id=DLayout.SETTINGS_BOX,
189
+ )
190
+
191
+ # The Snake Game
192
+ yield Vertical(
193
+ self.game_board,
194
+ id=DLayout.GAME_BOX,
195
+ )
196
+
197
+ # Runtime values
198
+ yield Vertical(
199
+ Horizontal(
200
+ Label(f"{DLabel.EPSILON}", classes=DLayout.LABEL),
201
+ self.cur_epsilon_widget,
202
+ ),
203
+ Horizontal(
204
+ Label(f"{DLabel.MEM_TYPE}", classes=DLayout.LABEL),
205
+ self.cur_mem_type_widget,
206
+ ),
207
+ Horizontal(
208
+ Label(f"{DLabel.STORED_GAMES}", classes=DLayout.LABEL),
209
+ self.cur_num_games_widget,
210
+ ),
211
+ Horizontal(
212
+ Label(f"{DLabel.MODEL_TYPE}", classes=DLayout.LABEL),
213
+ self.cur_model_type_widget,
214
+ ),
215
+ Horizontal(
216
+ Label(f"{DLabel.RUNTIME}", classes=DLayout.LABEL),
217
+ self.cur_runtime_widget,
218
+ ),
219
+ id=DLayout.RUNTIME_BOX,
220
+ )
221
+
222
+ # Buttons
223
+ yield Vertical(
224
+ Horizontal(
225
+ self.start_button,
226
+ self.pause_button,
227
+ self.quit_button,
228
+ classes=DLayout.BUTTON_ROW,
229
+ ),
230
+ Horizontal(
231
+ self.defaults_button,
232
+ self.update_button,
233
+ self.restart_button,
234
+ classes=DLayout.BUTTON_ROW,
235
+ ),
236
+ )
237
+
238
+ # Empty fillers
239
+ yield Static(id=DLayout.FILLER_1)
240
+ yield Static(id=DLayout.FILLER_2)
241
+ yield Static(id=DLayout.FILLER_3)
242
+
243
+ # The game score plot
244
+ yield self.game_score_plot
245
+
246
+ def on_mount(self):
247
+ self.initial_epsilon_input.value = str(DEpsilon.EPSILON_INITIAL)
248
+ self.epsilon_decay_input.value = str(DEpsilon.EPSILON_DECAY)
249
+ self.epsilon_min_input.value = str(DEpsilon.EPSILON_MIN)
250
+ self.move_delay_input.value = str(DDef.MOVE_DELAY)
251
+ settings_box = self.query_one(f"#{DLayout.SETTINGS_BOX}", Vertical)
252
+ settings_box.border_title = DLabel.SETTINGS
253
+ runtime_box = self.query_one(f"#{DLayout.RUNTIME_BOX}", Vertical)
254
+ runtime_box.border_title = DLabel.RUNTIME_VALUES
255
+ self.cur_mem_type_widget.update(
256
+ MEM_TYPE.MEM_TYPE_TABLE[self.agent.memory.mem_type()]
257
+ )
258
+ self.cur_num_games_widget.update(str(self.agent.memory.get_num_games()))
259
+ # Initial state is that the app is stopped
260
+ self.add_class(DField.STOPPED)
261
+ # Register the theme
262
+ self.register_theme(snake_lab_theme)
263
+
264
+ # Set the app's theme
265
+ self.theme = "db4e"
266
+
267
+ def on_quit(self):
268
+ if self.running == DField.RUNNING:
269
+ self.stop_event.set()
270
+ if self.simulator_thread.is_alive():
271
+ self.simulator_thread.join()
272
+ sys.exit(0)
273
+
274
+ def on_button_pressed(self, event: Button.Pressed) -> None:
275
+ button_id = event.button.id
276
+
277
+ # Pause button was pressed
278
+ if button_id == DLayout.BUTTON_PAUSE:
279
+ self.pause_event.set()
280
+ self.running = DField.PAUSED
281
+ self.remove_class(DField.RUNNING)
282
+ self.add_class(DField.PAUSED)
283
+ self.cur_move_delay = float(self.move_delay_input.value)
284
+ self.cur_model_type_widget.update(self.agent.model_type())
285
+
286
+ # Restart button was pressed
287
+ elif button_id == DLayout.BUTTON_RESTART:
288
+ self.running = DField.STOPPED
289
+ self.add_class(DField.STOPPED)
290
+ self.remove_class(DField.PAUSED)
291
+
292
+ # Signal thread to stop
293
+ self.stop_event.set()
294
+ # Unpause so we can exit cleanly
295
+ self.pause_event.clear()
296
+ # Join the old thread
297
+ if self.simulator_thread.is_alive():
298
+ self.simulator_thread.join(timeout=2)
299
+
300
+ # Reset the game and the UI
301
+ self.snake_game.reset()
302
+ score = 0
303
+ highscore = 0
304
+ self.epoch = 1
305
+ game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
306
+ game_box.border_title = ""
307
+ game_box.border_subtitle = ""
308
+
309
+ # Recreate events and get a new thread
310
+ self.stop_event = threading.Event()
311
+ self.pause_event = threading.Event()
312
+ self.simulator_thread = threading.Thread(target=self.start_sim, daemon=True)
313
+
314
+ # Start button was pressed
315
+ elif button_id == DLayout.BUTTON_START:
316
+ if self.running == DField.STOPPED:
317
+ self.start_thread()
318
+ elif self.running == DField.PAUSED:
319
+ self.pause_event.clear()
320
+ self.pause_event.clear()
321
+ self.running = DField.RUNNING
322
+ self.add_class(DField.RUNNING)
323
+ self.remove_class(DField.STOPPED)
324
+ self.remove_class(DField.PAUSED)
325
+ self.cur_move_delay = float(self.move_delay_input.value)
326
+ self.cur_model_type_widget.update(self.agent.model_type())
327
+
328
+ # Defaults button was pressed: restore the input fields to their default values
329
+ elif button_id == DLayout.BUTTON_DEFAULTS:
330
+ self.initial_epsilon_input.value = str(DEpsilon.EPSILON_INITIAL)
331
+ self.epsilon_decay_input.value = str(DEpsilon.EPSILON_DECAY)
332
+ self.epsilon_min_input.value = str(DEpsilon.EPSILON_MIN)
333
+ self.move_delay_input.value = str(DDef.MOVE_DELAY)
334
+
335
+ # Quit button was pressed
336
+ elif button_id == DLayout.BUTTON_QUIT:
337
+ self.on_quit()
338
+
339
+ # Update button was pressed
340
+ elif button_id == DLayout.BUTTON_UPDATE:
341
+ self.cur_move_delay = float(self.move_delay_input.value)
342
+
343
+ def start_sim(self):
344
+ self.snake_game.reset()
345
+ game_board = self.game_board
346
+ agent = self.agent
347
+ snake_game = self.snake_game
348
+ score = 0
349
+ highscore = 0
350
+ self.epoch = 1
351
+ game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
352
+ game_box.border_title = f"{DLabel.GAME} #{self.epoch}"
353
+ start_time = datetime.now()
354
+
355
+ while not self.stop_event.is_set():
356
+ if self.pause_event.is_set():
357
+ self.pause_event.wait()
358
+ time.sleep(0.2)
359
+ continue
360
+ # The actual training loop...
361
+ old_state = game_board.get_state()
362
+ move = agent.get_move(old_state)
363
+ reward, game_over, score = snake_game.play_step(move)
364
+ if score > highscore:
365
+ highscore = score
366
+ game_box.border_subtitle = (
367
+ f"{DLabel.HIGHSCORE}: {highscore}, {DLabel.SCORE}: {score}"
368
+ )
369
+ if not game_over:
370
+ ## Keep playing
371
+ time.sleep(self.cur_move_delay)
372
+ new_state = game_board.get_state()
373
+ agent.train_short_memory(old_state, move, reward, new_state, game_over)
374
+ agent.remember(old_state, move, reward, new_state, game_over)
375
+ else:
376
+ ## Game over
377
+ self.epoch += 1
378
+ game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
379
+ game_box.border_title = f"{DLabel.GAME} #{self.epoch}"
380
+ # Remember the last move
381
+ agent.remember(old_state, move, reward, new_state, game_over)
382
+ # Train long memory
383
+ agent.train_long_memory()
384
+ # Reset the game
385
+ snake_game.reset()
386
+ # Let the agent know we've finished a game
387
+ agent.played_game(score)
388
+ # Get the current epsilon value
389
+ cur_epsilon = self.epsilon_algo.epsilon()
390
+ if cur_epsilon < 0.0001:
391
+ self.cur_epsilon_widget.update("0.0000")
392
+ else:
393
+ self.cur_epsilon_widget.update(str(round(cur_epsilon, 4)))
394
+ # Update the number of stored memories
395
+ self.cur_num_games_widget.update(str(self.agent.memory.get_num_games()))
396
+ # Update the stats object
397
+ self.stats[DField.GAME_SCORE][DField.GAME_NUM].append(self.epoch)
398
+ self.stats[DField.GAME_SCORE][DField.GAME_SCORE].append(score)
399
+ # Update the plot object
400
+ self.game_score_plot.add_data(self.epoch, score)
401
+ self.game_score_plot.db4e_plot()
402
+ # Update the runtime widget
403
+ elapsed_secs = (datetime.now() - start_time).total_seconds()
404
+ runtime = minutes_to_uptime(elapsed_secs)
405
+ self.cur_runtime_widget.update(runtime)
406
+
407
+ def start_thread(self):
408
+ self.simulator_thread.start()
409
+
410
+
411
+ def minutes_to_uptime(seconds: int):
412
+ # Return a string like:
413
+ # 45s
414
+ # 1d 7h 32m
415
+ days, minutes = divmod(int(seconds), 86400)
416
+ hours, minutes = divmod(minutes, 3600)
417
+ minutes, seconds = divmod(minutes, 60)
418
+
419
+ if days > 0:
420
+ return f"{days}d {hours}h {minutes}m"
421
+ elif hours > 0:
422
+ return f"{hours}h {minutes}m"
423
+ elif minutes > 0:
424
+ return f"{minutes}m {seconds}s"
425
+ else:
426
+ return f"{seconds}s"
427
+
428
+
429
+ if __name__ == "__main__":
430
+ app = AISim()
431
+ app.run()
@@ -9,13 +9,14 @@ ai/Agent.py
9
9
  """
10
10
 
11
11
  import torch
12
- from ai.EpsilonAlgo import EpsilonAlgo
13
- from ai.ReplayMemory import ReplayMemory
14
- from ai.AITrainer import AITrainer
15
- from ai.models.ModelL import ModelL
16
- from ai.models.ModelRNN import ModelRNN
12
+ from ai_snake_lab.ai.EpsilonAlgo import EpsilonAlgo
13
+ from ai_snake_lab.ai.ReplayMemory import ReplayMemory
14
+ from ai_snake_lab.ai.AITrainer import AITrainer
15
+ from ai_snake_lab.ai.models.ModelL import ModelL
16
+ from ai_snake_lab.ai.models.ModelRNN import ModelRNN
17
17
 
18
- from constants.DReplayMemory import MEM_TYPE
18
+ from ai_snake_lab.constants.DReplayMemory import MEM_TYPE
19
+ from ai_snake_lab.constants.DLabels import DLabel
19
20
 
20
21
 
21
22
  class AIAgent:
@@ -23,16 +24,13 @@ class AIAgent:
23
24
  def __init__(self, epsilon_algo: EpsilonAlgo, seed: int):
24
25
  self.epsilon_algo = epsilon_algo
25
26
  self.memory = ReplayMemory(seed=seed)
26
- self.model = ModelL(seed=seed)
27
- # self.model = ModelRNN(seed=seed)
28
- self.trainer = AITrainer(self.model)
27
+ # self._model = ModelL(seed=seed)
28
+ self._model = ModelRNN(seed=seed)
29
+ self.trainer = AITrainer(model=self._model)
29
30
 
30
- if type(self.model) == ModelRNN:
31
+ if type(self._model) == ModelRNN:
31
32
  self.memory.mem_type(MEM_TYPE.RANDOM_GAME)
32
33
 
33
- def get_model(self):
34
- return self.model
35
-
36
34
  def get_move(self, state):
37
35
  random_move = self.epsilon_algo.get_move() # Explore with epsilon
38
36
  if random_move != False:
@@ -42,7 +40,7 @@ class AIAgent:
42
40
  final_move = [0, 0, 0]
43
41
  if type(state) != torch.Tensor:
44
42
  state = torch.tensor(state, dtype=torch.float) # Convert to a tensor
45
- prediction = self.model(state) # Get the prediction
43
+ prediction = self._model(state) # Get the prediction
46
44
  move = torch.argmax(prediction).item() # Select the move with the highest value
47
45
  final_move[move] = 1 # Set the move
48
46
  return final_move # Return
@@ -50,6 +48,15 @@ class AIAgent:
50
48
  def get_optimizer(self):
51
49
  return self.trainer.get_optimizer()
52
50
 
51
+ def model_type(self):
52
+ if type(self._model) == ModelL:
53
+ return DLabel.MODEL_LINEAR
54
+ elif type(self._model) == ModelRNN:
55
+ return DLabel.MODEL_RNN
56
+
57
+ def model(self):
58
+ return self._model
59
+
53
60
  def played_game(self, score):
54
61
  self.epsilon_algo.played_game()
55
62
 
@@ -57,27 +64,23 @@ class AIAgent:
57
64
  # Store the state, action, reward, next_state, and done in memory
58
65
  self.memory.append((state, action, reward, next_state, done))
59
66
 
60
- def set_model(self, model):
61
- self.model = model
62
-
63
67
  def set_optimizer(self, optimizer):
64
68
  self.trainer.set_optimizer(optimizer)
65
69
 
66
70
  def train_long_memory(self):
67
- # Get the states, actions, rewards, next_states, and dones from the mini_sample
68
- memory = self.memory.get_memory()
69
- memory_type = self.memory.mem_type()
70
-
71
- if type(self.model) == ModelRNN:
72
- for state, action, reward, next_state, done in memory[0]:
73
- self.trainer.train_step(state, action, reward, next_state, [done])
74
-
75
- elif memory_type == MEM_TYPE.SHUFFLE:
76
- for state, action, reward, next_state, done in memory:
77
- self.trainer.train_step(state, action, reward, next_state, [done])
78
-
79
- else:
80
- for state, action, reward, next_state, done in memory[0]:
71
+ # Train on up to 2 randomly selected games
72
+ max_games = 2
73
+ # Get a random full game
74
+ while max_games > 0:
75
+ max_games -= 1
76
+ game = self.memory.get_random_game()
77
+ if not game:
78
+ return # no games to train on yet
79
+
80
+ for count, (state, action, reward, next_state, done) in enumerate(
81
+ game, start=1
82
+ ):
83
+ # print(f"Move #{count}: {action}")
81
84
  self.trainer.train_step(state, action, reward, next_state, [done])
82
85
 
83
86
  def train_short_memory(self, state, action, reward, next_state, done):
@@ -15,23 +15,27 @@ import numpy as np
15
15
  import time
16
16
  import sys
17
17
 
18
- from ai.models.ModelL import ModelL
19
- from ai.models.ModelRNN import ModelRNN
18
+ from ai_snake_lab.ai.models.ModelL import ModelL
19
+ from ai_snake_lab.ai.models.ModelRNN import ModelRNN
20
20
 
21
- from constants.DModelL import DModelL
22
- from constants.DModelLRNN import DModelRNN
21
+ from ai_snake_lab.constants.DModelL import DModelL
22
+ from ai_snake_lab.constants.DModelLRNN import DModelRNN
23
23
 
24
24
 
25
25
  class AITrainer:
26
26
 
27
27
  def __init__(self, model):
28
28
  torch.manual_seed(1970)
29
- self.model = model
29
+
30
30
  # The learning rate needs to be adjusted for the model type
31
31
  if type(model) == ModelL:
32
32
  learning_rate = DModelL.LEARNING_RATE
33
33
  elif type(model) == ModelRNN:
34
34
  learning_rate = DModelRNN.LEARNING_RATE
35
+ else:
36
+ raise ValueError(f"Unknown model type: {type(model)}")
37
+
38
+ self.model = model
35
39
  self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
36
40
  self.criterion = nn.MSELoss()
37
41
  self.gamma = 0.9