ai-snake-lab 0.4.8__tar.gz → 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/PKG-INFO +33 -4
  2. ai_snake_lab-0.5.0/README.md +96 -0
  3. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/ai/AIAgent.py +10 -16
  4. ai_snake_lab-0.5.0/ai_snake_lab/ai/ReplayMemory.py +254 -0
  5. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/ai/models/ModelRNN.py +9 -6
  6. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/constants/DDef.py +1 -1
  7. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/constants/DLabels.py +2 -0
  8. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/constants/DLayout.py +5 -0
  9. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/constants/DModelLRNN.py +3 -3
  10. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/constants/DReplayMemory.py +13 -4
  11. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/constants/DSim.py +1 -1
  12. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/game/GameBoard.py +117 -0
  13. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/game/SnakeGame.py +3 -3
  14. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/ui/AISim.py +49 -16
  15. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/ui/AISim.tcss +16 -3
  16. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/pyproject.toml +1 -1
  17. ai_snake_lab-0.4.8/README.md +0 -67
  18. ai_snake_lab-0.4.8/ai_snake_lab/ai/ReplayMemory.py +0 -148
  19. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/LICENSE +0 -0
  20. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/ai/AITrainer.py +0 -0
  21. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/ai/EpsilonAlgo.py +0 -0
  22. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/ai/models/ModelL.py +0 -0
  23. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/constants/DDb4EPlot.py +0 -0
  24. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/constants/DDir.py +0 -0
  25. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/constants/DEpsilon.py +0 -0
  26. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/constants/DFields.py +0 -0
  27. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/constants/DFile.py +0 -0
  28. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/constants/DModelL.py +0 -0
  29. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/constants/__init__.py +0 -0
  30. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/game/GameElements.py +0 -0
  31. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/ui/Db4EPlot.py +0 -0
  32. {ai_snake_lab-0.4.8 → ai_snake_lab-0.5.0}/ai_snake_lab/utils/ConstGroup.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ai-snake-lab
- Version: 0.4.8
+ Version: 0.5.0
  Summary: Interactive reinforcement learning sandbox for experimenting with AI agents in a classic Snake Game environment.
  License: GPL-3.0
  License-File: LICENSE
@@ -35,6 +35,10 @@ Project-URL: Documentation, https://snakelab.osoyalce.com/
  Project-URL: Source, https://github.com/NadimGhaznavi/ai_snake_lab
  Description-Content-Type: text/markdown
 
+ # AI Snake Lab
+
+ ---
+
  # Introduction
 
  **AI Snake Lab** is an interactive reinforcement learning sandbox for experimenting with AI agents in a classic Snake Game environment — featuring a live Textual TUI interface, flexible replay memory database, and modular model definitions.
@@ -95,10 +99,35 @@ ai-snake-lab
 
  ---
 
- # Links and Acknowledgements
+ # Technical Docs
+
+ - [Database Schema Documentation](/pages/db_schema.html)
+ - [Project Layout](/pages/project_layout.html)
+
+ ---
+
+ # Acknowledgements
 
- This code is based on a YouTube tutorial, [Python + PyTorch + Pygame Reinforcement Learning – Train an AI to Play Snake](https://www.youtube.com/watch?v=L8ypSXwyBds&t=1042s&ab_channel=freeCodeCamp.org) by Patrick Loeber. You can access his original code [here](https://github.com/patrickloeber/snake-ai-pytorch) on GitHub. Thank you Patrick!!! You are amazing!!!!
+ The original code for this project was based on a YouTube tutorial, [Python + PyTorch + Pygame Reinforcement Learning – Train an AI to Play Snake](https://www.youtube.com/watch?v=L8ypSXwyBds&t=1042s&ab_channel=freeCodeCamp.org) by Patrick Loeber. You can access his original code [here](https://github.com/patrickloeber/snake-ai-pytorch) on GitHub. Thank you Patrick!!! You are amazing!!!! This project is a port of his Pygame and Matplotlib solution.
 
- Thanks also go out to Will McGugan and the [Textual](https://textual.textualize.io/) team. Textual is an amazing framework. Talk about *rapid Application Development*. Porting this took less than a day.
+ Thanks also go out to Will McGugan and the [Textual](https://textual.textualize.io/) team. Textual is an amazing framework. Talk about *Rapid Application Development*. Porting this from a Pygame and Matplotlib solution to Textual took less than a day.
 
  ---
+
+ # Inspiration
+
+ Creating an artificial intelligence agent, letting it loose and watching how it performs is an amazing process. It's not unlike having children, except on a much, much, much smaller scale, at least today! Watching the AI-driven Snake Game is mesmerizing. I'm constantly thinking of ways I could improve it. I credit Patrick Loeber for giving me a fun project to explore the AI space.
+
+ Much of my career has been as a Linux systems administrator. My comfort zone is on the command line. I've never worked as a programmer, and certainly not as a front-end developer. [Textual](https://textual.textualize.io/), as a framework for building rich *Terminal User Interfaces*, is exactly my speed, and when I saw [Dolphie](https://github.com/charles-001/dolphie), I was blown away. Built-in, real-time plots of MySQL metrics: Amazing!
+
+ Richard S. Sutton is also an inspiration to me. His thoughts on *Reinforcement Learning* are a slow-motion revolution. His criticism of the existing AI landscape, with its focus on engineering a specific AI to do a specific task and then considering the job done, is spot on. His vision for an AI agent that does continuous, non-linear learning remains the next frontier on the path to *Artificial General Intelligence*.
+
+ ---
+
+ # Links
+
+ - Patrick Loeber's [YouTube Tutorial](https://www.youtube.com/watch?v=L8ypSXwyBds&t=1042s&ab_channel=freeCodeCamp.org)
+ - Will McGugan's [Textual](https://textual.textualize.io/) *Rapid Application Development* framework
+ - [Dolphie](https://github.com/charles-001/dolphie): *A single pane of glass for real-time analytics into MySQL/MariaDB & ProxySQL*
+ - Richard Sutton's [Homepage](http://www.incompleteideas.net/)
+ - Richard Sutton [quotes](/pages/richard-sutton.html) and other materials.
@@ -0,0 +1,96 @@
+ # AI Snake Lab
+
+ ---
+
+ # Introduction
+
+ **AI Snake Lab** is an interactive reinforcement learning sandbox for experimenting with AI agents in a classic Snake Game environment — featuring a live Textual TUI interface, flexible replay memory database, and modular model definitions.
+
+ ---
+
+ # 🚀 Features
+
+ - 🐍 **Classic Snake environment** with customizable grid and rules
+ - 🧠 **AI agent interface** supporting multiple architectures (Linear, RNN, CNN)
+ - 🎮 **Textual-based simulator** for live visualization and metrics
+ - 💾 **SQLite-backed replay memory** for storing frames, episodes, and runs
+ - 🧩 **Experiment metadata tracking** — models, hyperparameters, state-map versions
+ - 📊 **Built-in plotting** for scores and learning progress
+
+ ---
+
+ # 🧰 Tech Stack
+
+ | Component | Description |
+ |------------|--------------|
+ | **Python 3.11+** | Core language |
+ | **Textual** | Terminal UI framework |
+ | **SQLite3** | Lightweight replay memory + experiment store |
+ | **PyTorch** *(optional)* | Deep learning backend for models |
+ | **Plotext / Matplotlib** | Visualization tools |
+
+ ---
+
+ # Installation
+
+ This project is on [PyPI](https://pypi.org/project/ai-snake-lab/). You can install the *AI Snake Lab* software using `pip`.
+
+ ## Create a Sandbox
+
+ ```shell
+ python3 -m venv snake_venv
+ . snake_venv/bin/activate
+ ```
+
+ ## Install the AI Snake Lab
+
+ After you have activated your *venv* environment:
+
+ ```shell
+ pip install ai-snake-lab
+ ```
+
+ ---
+
+ # Running the AI Snake Lab
+
+ From within your *venv* environment:
+
+ ```shell
+ ai-snake-lab
+ ```
+
+ ---
+
+ # Technical Docs
+
+ - [Database Schema Documentation](/pages/db_schema.html)
+ - [Project Layout](/pages/project_layout.html)
+
+ ---
+
+ # Acknowledgements
+
+ The original code for this project was based on a YouTube tutorial, [Python + PyTorch + Pygame Reinforcement Learning – Train an AI to Play Snake](https://www.youtube.com/watch?v=L8ypSXwyBds&t=1042s&ab_channel=freeCodeCamp.org) by Patrick Loeber. You can access his original code [here](https://github.com/patrickloeber/snake-ai-pytorch) on GitHub. Thank you Patrick!!! You are amazing!!!! This project is a port of his Pygame and Matplotlib solution.
+
+ Thanks also go out to Will McGugan and the [Textual](https://textual.textualize.io/) team. Textual is an amazing framework. Talk about *Rapid Application Development*. Porting this from a Pygame and Matplotlib solution to Textual took less than a day.
+
+ ---
+
+ # Inspiration
+
+ Creating an artificial intelligence agent, letting it loose and watching how it performs is an amazing process. It's not unlike having children, except on a much, much, much smaller scale, at least today! Watching the AI-driven Snake Game is mesmerizing. I'm constantly thinking of ways I could improve it. I credit Patrick Loeber for giving me a fun project to explore the AI space.
+
+ Much of my career has been as a Linux systems administrator. My comfort zone is on the command line. I've never worked as a programmer, and certainly not as a front-end developer. [Textual](https://textual.textualize.io/), as a framework for building rich *Terminal User Interfaces*, is exactly my speed, and when I saw [Dolphie](https://github.com/charles-001/dolphie), I was blown away. Built-in, real-time plots of MySQL metrics: Amazing!
+
+ Richard S. Sutton is also an inspiration to me. His thoughts on *Reinforcement Learning* are a slow-motion revolution. His criticism of the existing AI landscape, with its focus on engineering a specific AI to do a specific task and then considering the job done, is spot on. His vision for an AI agent that does continuous, non-linear learning remains the next frontier on the path to *Artificial General Intelligence*.
+
+ ---
+
+ # Links
+
+ - Patrick Loeber's [YouTube Tutorial](https://www.youtube.com/watch?v=L8ypSXwyBds&t=1042s&ab_channel=freeCodeCamp.org)
+ - Will McGugan's [Textual](https://textual.textualize.io/) *Rapid Application Development* framework
+ - [Dolphie](https://github.com/charles-001/dolphie): *A single pane of glass for real-time analytics into MySQL/MariaDB & ProxySQL*
+ - Richard Sutton's [Homepage](http://www.incompleteideas.net/)
+ - Richard Sutton [quotes](/pages/richard-sutton.html) and other materials.
@@ -60,28 +60,22 @@ class AIAgent:
      def played_game(self, score):
          self.epsilon_algo.played_game()
 
-     def remember(self, state, action, reward, next_state, done):
+     def remember(self, state, action, reward, next_state, done, score=None):
          # Store the state, action, reward, next_state, and done in memory
-         self.memory.append((state, action, reward, next_state, done))
+         self.memory.append((state, action, reward, next_state, done, score))
 
      def set_optimizer(self, optimizer):
          self.trainer.set_optimizer(optimizer)
 
      def train_long_memory(self):
-         # Train on 5 games
-         max_games = 2
-         # Get a random full game
-         while max_games > 0:
-             max_games -= 1
-             game = self.memory.get_random_game()
-             if not game:
-                 return  # no games to train on yet
-
-             for count, (state, action, reward, next_state, done) in enumerate(
-                 game, start=1
-             ):
-                 # print(f"Move #{count}: {action}")
-                 self.trainer.train_step(state, action, reward, next_state, [done])
+         # Ask ReplayMemory for data
+         training_data = self.memory.get_training_data(n_games=1)
+         if not training_data:
+             return  # either no memory or user chose None
+
+         for state, action, reward, next_state, done, *_ in training_data:
+             self.trainer.train_step(state, action, reward, next_state, [done])
 
      def train_short_memory(self, state, action, reward, next_state, done):
+         # Always train on the current frame
          self.trainer.train_step(state, action, reward, next_state, [done])
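
The net effect of this change is that the sampling policy moves out of the agent: `remember()` now forwards the final score with every transition, and `train_long_memory()` trains on whatever `ReplayMemory.get_training_data()` hands back. A hedged sketch of one play/train cycle under the new API — the `get_move` and `get_state` names and the `agent`/`game` wiring are assumptions inferred from the AISim.py diff further down, not verified signatures:

```python
def training_step(agent, game):
    """One play/train cycle under the 0.5.0 API (illustrative sketch)."""
    old_state = game.get_state()                     # assumed accessor
    move = agent.get_move(old_state)                 # assumed accessor
    reward, game_over, score = game.play_step(move)  # signature shown in AISim.py
    new_state = game.get_state()

    # Learn from the current frame immediately...
    agent.train_short_memory(old_state, move, reward, new_state, game_over)
    # ...and bank it, score included, so ReplayMemory can record the game.
    agent.remember(old_state, move, reward, new_state, game_over, score=score)

    if game_over:
        # Replays one random stored game by default; a no-op when the
        # memory type is set to None.
        agent.train_long_memory()
```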
@@ -0,0 +1,254 @@
+ """
+ ai/ReplayMemory.py
+
+ AI Snake Game Simulator
+ Author: Nadim-Daniel Ghaznavi
+ Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
+ GitHub: https://github.com/NadimGhaznavi/ai
+ License: GPL 3.0
+
+ This file contains the ReplayMemory class.
+ """
+
+ import os
+ import random
+ import sqlite3, pickle
+ import tempfile
+
+ from ai_snake_lab.constants.DReplayMemory import MEM_TYPE
+ from ai_snake_lab.constants.DDef import DDef
+
+
+ class ReplayMemory:
+
+     def __init__(self, seed: int):
+         random.seed(seed)
+         self.batch_size = 250
+         # Valid options: shuffle, random_game or none
+         self._mem_type = MEM_TYPE.RANDOM_GAME
+         self.min_games = 1
+
+         # All of the states for a game are stored, in order.
+         self.cur_memory = []
+
+         # Create a temporary file for the DB
+         self._tmpfile = tempfile.NamedTemporaryFile(suffix=DDef.DOT_DB, delete=False)
+         self.db_file = self._tmpfile.name
+
+         # Connect to SQLite
+         self.conn = sqlite3.connect(self.db_file, check_same_thread=False)
+
+         # Get a cursor
+         self.cursor = self.conn.cursor()
+
+         # We don't need the file handle anymore
+         self._tmpfile.close()
+
+         # Initialize the schema
+         self.init_db()
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.close()
+
+     def __del__(self):
+         try:
+             self.close()
+         except Exception:
+             pass  # avoid errors on interpreter shutdown
+
+     def append(self, transition, final_score=None):
+         """Add a transition to the current game."""
+         old_state, move, reward, new_state, done, final_score = transition
+
+         self.cur_memory.append((old_state, move, reward, new_state, done))
+
+         if done:
+             if final_score is None:
+                 raise ValueError("final_score must be provided when the game ends")
+
+             total_frames = len(self.cur_memory)
+
+             # Record the game
+             self.cursor.execute(
+                 "INSERT INTO games (score, total_frames) VALUES (?, ?)",
+                 (final_score, total_frames),
+             )
+             game_id = self.cursor.lastrowid
+
+             # Record the frames
+             for i, (state, action, reward, next_state, done) in enumerate(
+                 self.cur_memory
+             ):
+                 self.cursor.execute(
+                     """
+                     INSERT INTO frames (game_id, frame_index, state, action, reward, next_state, done)
+                     VALUES (?, ?, ?, ?, ?, ?, ?)
+                     """,
+                     (
+                         game_id,
+                         i,
+                         pickle.dumps(state),
+                         pickle.dumps(action),
+                         reward,
+                         pickle.dumps(next_state),
+                         done,
+                     ),
+                 )
+
+             self.conn.commit()
+             self.cur_memory = []
+
+     def close(self):
+         """Close the database connection."""
+         if getattr(self, "conn", None):
+             self.conn.close()
+             self.conn = None
+         if getattr(self, "db_file", None) and os.path.exists(self.db_file):
+             os.remove(self.db_file)
+             self.db_file = None
+
+     def get_average_game_length(self):
+         self.cursor.execute("SELECT AVG(total_frames) FROM games")
+         avg = self.cursor.fetchone()[0]
+         return int(avg) if avg else 0
+
+     def get_random_frames(self, n=None):
+         if n is None:
+             n = self.get_average_game_length() or 32  # fallback if no data
+
+         self.cursor.execute(
+             "SELECT state, action, reward, next_state, done "
+             "FROM frames ORDER BY RANDOM() LIMIT ?",
+             (n,),
+         )
+         rows = self.cursor.fetchall()
+
+         frames = [
+             (
+                 pickle.loads(state_blob),
+                 pickle.loads(action),
+                 float(reward),
+                 pickle.loads(next_state_blob),
+                 bool(done),
+             )
+             for state_blob, action, reward, next_state_blob, done in rows
+         ]
+         return frames
+
+     def get_random_game(self):
+         self.cursor.execute("SELECT id FROM games")
+         all_ids = [row[0] for row in self.cursor.fetchall()]
+         if not all_ids or len(all_ids) < self.min_games:
+             return False
+
+         rand_id = random.choice(all_ids)
+         self.cursor.execute(
+             "SELECT state, action, reward, next_state, done "
+             "FROM frames WHERE game_id = ? ORDER BY frame_index ASC",
+             (rand_id,),
+         )
+         rows = self.cursor.fetchall()
+         if not rows:
+             return False
+
+         game = [
+             (
+                 pickle.loads(state_blob),
+                 pickle.loads(action),
+                 float(reward),
+                 pickle.loads(next_state_blob),
+                 bool(done),
+             )
+             for state_blob, action, reward, next_state_blob, done in rows
+         ]
+         return game
+
+     def get_num_games(self):
+         """Return number of games stored in the database."""
+         self.cursor.execute("SELECT COUNT(*) FROM games")
+         return self.cursor.fetchone()[0]
+
+     def get_training_data(self, n_games=None, n_frames=None):
+         """
+         Returns a list of transitions for training based on the current memory type.
+
+         - n_games: used for RANDOM_GAME (how many full games to sample)
+         - n_frames: used for SHUFFLE (how many frames to sample)
+         - Returns empty list if memory type is NONE or if database/memory is empty
+         """
+         mem_type = self.mem_type()
+
+         print(f"SELECTED memory type: {mem_type}")
+         if mem_type == MEM_TYPE.NONE:
+             return []
+
+         elif mem_type == MEM_TYPE.RANDOM_GAME:
+             n_games = n_games or 1
+             training_data = []
+             for _ in range(n_games):
+                 game = self.get_random_game()
+                 if game:
+                     training_data.extend(game)
+             return training_data
+
+         elif mem_type == MEM_TYPE.SHUFFLE:
+             n_frames = n_frames or self.get_average_game_length()
+             frames = self.get_random_frames(n=n_frames)
+             return frames
+
+         else:
+             raise ValueError(f"Unknown memory type: {mem_type}")
+
+     def init_db(self):
+         self.cursor.execute(
+             """
+             CREATE TABLE IF NOT EXISTS games (
+                 id INTEGER PRIMARY KEY AUTOINCREMENT,
+                 score INTEGER NOT NULL,
+                 total_frames INTEGER NOT NULL
+             );
+             """
+         )
+         self.conn.commit()
+
+         self.cursor.execute(
+             """
+             CREATE TABLE IF NOT EXISTS frames (
+                 id INTEGER PRIMARY KEY AUTOINCREMENT,
+                 game_id INTEGER NOT NULL,
+                 frame_index INTEGER NOT NULL,
+                 state BLOB NOT NULL,
+                 action BLOB NOT NULL,
+                 reward INTEGER NOT NULL,
+                 next_state BLOB NOT NULL,
+                 done INTEGER NOT NULL,  -- 0 or 1
+                 FOREIGN KEY (game_id) REFERENCES games(id)
+             );
+             """
+         )
+         self.conn.commit()
+
+         self.cursor.execute(
+             """
+             CREATE UNIQUE INDEX IF NOT EXISTS idx_game_frame ON frames (game_id, frame_index);
+             """
+         )
+         self.conn.commit()
+
+         self.cursor.execute(
+             """
+             CREATE INDEX IF NOT EXISTS idx_frames_game_id ON frames (game_id);
+             """
+         )
+         self.conn.commit()
+
+     def mem_type(self, mem_type=None):
+         if mem_type is not None:
+             self._mem_type = mem_type
+         return self._mem_type
+
+     def set_memory(self, memory):
+         self.memory = memory
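
The schema above stores one `games` row per finished game (final score plus frame count) and one `frames` row per transition, so a full game can be replayed in order or individual frames sampled at random. A hedged usage sketch against this API, with placeholder 27-float states standing in for real `GameBoard.get_state()` output:

```python
from ai_snake_lab.ai.ReplayMemory import ReplayMemory
from ai_snake_lab.constants.DReplayMemory import MEM_TYPE

with ReplayMemory(seed=1970) as memory:
    dummy = [0.0] * 27  # placeholder state vector

    # A tiny two-frame game; the score rides along in every transition and
    # must be real on the final (done=True) frame.
    memory.append((dummy, [1, 0, 0], 0, dummy, False, None))
    memory.append((dummy, [0, 1, 0], -10, dummy, True, 3))

    print(memory.get_num_games())  # -> 1

    # RANDOM_GAME replays one full stored game, in frame order.
    memory.mem_type(MEM_TYPE.RANDOM_GAME)
    batch = memory.get_training_data(n_games=1)
    print(len(batch))  # -> 2 transitions
```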
@@ -12,16 +12,19 @@ import torch
  import torch.nn as nn
  import torch.nn.functional as F
 
+ from ai_snake_lab.constants.DSim import DSim
+ from ai_snake_lab.constants.DModelLRNN import DModelRNN
+
 
  class ModelRNN(nn.Module):
      def __init__(self, seed: int):
          super(ModelRNN, self).__init__()
          torch.manual_seed(seed)
-         input_size = 30
-         hidden_size = 200
-         output_size = 3
-         rnn_layers = 4
-         rnn_dropout = 0.2
+         input_size = DSim.STATE_SIZE
+         hidden_size = DModelRNN.HIDDEN_SIZE
+         output_size = DSim.OUTPUT_SIZE
+         rnn_layers = DModelRNN.RNN_LAYERS
+         rnn_dropout = DModelRNN.RNN_DROPOUT
          self.m_in = nn.Sequential(
              nn.Linear(input_size, hidden_size),
              nn.ReLU(),
@@ -37,7 +40,7 @@ class ModelRNN(nn.Module):
 
      def forward(self, x):
          x = self.m_in(x)
-         inputs = x.view(1, -1, 200)
+         inputs = x.view(1, -1, DModelRNN.HIDDEN_SIZE)
          x, h_n = self.m_rnn(inputs)
          x = self.m_out(x)
          return x[len(x) - 1]
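
With the magic numbers replaced by constants, the model's dimensions are now driven by the same DSim values as the state map. A quick, hedged smoke test of the wiring (assuming the imports resolve as in the diff; the output shape follows from the view/RNN/linear chain in `forward()`):

```python
import torch

from ai_snake_lab.ai.models.ModelRNN import ModelRNN
from ai_snake_lab.constants.DSim import DSim

model = ModelRNN(seed=1970)
batch = torch.rand(8, DSim.STATE_SIZE)  # 8 frames, 27 features each

# forward() lifts the batch to (1, 8, HIDDEN_SIZE) for the RNN and returns
# the single sequence element: one row of Q-values per frame.
q_values = model(batch)
print(q_values.shape)  # -> torch.Size([8, 3]), i.e. (batch, OUTPUT_SIZE)
```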
@@ -14,6 +14,6 @@ from ai_snake_lab.utils.ConstGroup import ConstGroup
  class DDef(ConstGroup):
      """Defaults"""
 
-     APP_TITLE: str = "AI Snake Game Lab"
+     APP_TITLE: str = "AI Snake Lab"
      DOT_DB: str = ".db"
      MOVE_DELAY: float = 0.0
@@ -30,12 +30,14 @@ class DLabel(ConstGroup):
      GAME_SCORE: str = "Game Score"
      GAME_NUM: str = "Game Number"
      HIGHSCORE: str = "Highscore"
+     HIGHSCORES: str = "Highscores"
      MEM_TYPE: str = "Memory Type"
      MIN_EPSILON: str = "Minimum Epsilon"
      MODEL_LINEAR: str = "Linear"
      MODEL_RNN: str = "RNN"
      MODEL_TYPE: str = "Model Type"
      MOVE_DELAY: str = "Move Delay"
+     N_SLASH_A: str = "N/A"
      PAUSE: str = "Pause"
      QUIT: str = "Quit"
      RESTART: str = "Restart"
@@ -37,12 +37,17 @@ class DLayout(ConstGroup):
      GAME_BOX: str = "game_box"
      GAME_SCORE: str = "game_score"
      GAME_SCORE_PLOT: str = "game_score_plot"
+     HIGHSCORES: str = "highscores"
+     HIGHSCORES_BOX: str = "highscores_box"
+     HIGHSCORES_HEADER: str = "highscores_header"
      EPSILON_DECAY: str = "epsilon_decay"
      EPSILON_INITIAL: str = "initial_epsilon"
      EPSILON_MIN: str = "epsilon_min"
      INPUT_10: str = "input_10"
      LABEL: str = "label"
      LABEL_SETTINGS: str = "label_settings"
+     LABEL_SETTINGS_12: str = "label_settings_12"
+     MEM_TYPE: str = "memory_type"
      MOVE_DELAY: str = "move_delay"
      NUM_GAMES: str = "num_games"
      RUNTIME_BOX: str = "runtime_box"
@@ -15,6 +15,6 @@ class DModelRNN(ConstGroup):
      """RNN Model Defaults"""
 
      LEARNING_RATE: float = 0.0007
-     INPUT_SIZE: int = 400
-     MAX_MEMORIES: int = 20
-     MAX_MEMORY: int = 100000
+     HIDDEN_SIZE: int = 200
+     RNN_LAYERS: int = 4
+     RNN_DROPOUT: float = 0.2
@@ -14,12 +14,21 @@ from ai_snake_lab.utils.ConstGroup import ConstGroup
  class MEM_TYPE(ConstGroup):
      """Replay Memory Type"""
 
-     SHUFFLE: str = "shuffle"
-     SHUFFLE_LABEL: str = "Shuffled set"
+     NONE: str = "none"
+     NONE_LABEL: str = "None"
      RANDOM_GAME: str = "random_game"
-     RANDOM_GAME_LABEL: str = "Random game"
+     RANDOM_GAME_LABEL: str = "Random Game"
+     SHUFFLE: str = "shuffle"
+     SHUFFLE_LABEL: str = "Random Frames"
 
      MEM_TYPE_TABLE: dict = {
-         SHUFFLE: SHUFFLE_LABEL,
+         NONE: NONE_LABEL,
          RANDOM_GAME: RANDOM_GAME_LABEL,
+         SHUFFLE: SHUFFLE_LABEL,
      }
+
+     MEMORY_TYPES: list = [
+         (NONE_LABEL, NONE),
+         (RANDOM_GAME_LABEL, RANDOM_GAME),
+         (SHUFFLE_LABEL, SHUFFLE),
+     ]
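
MEMORY_TYPES is deliberately shaped as (label, value) pairs — the form Textual's Select takes in the AISim.py diff below — while MEM_TYPE_TABLE maps a stored value back to its display label. A small sketch (assuming ConstGroup exposes these as ordinary class attributes):

```python
from ai_snake_lab.constants.DReplayMemory import MEM_TYPE

# The Select options: label shown to the user, value stored on selection.
for label, value in MEM_TYPE.MEMORY_TYPES:
    print(f"{value:12s} -> {label}")

# Reverse lookup, as used when echoing the active type in the runtime panel.
assert MEM_TYPE.MEM_TYPE_TABLE[MEM_TYPE.SHUFFLE] == MEM_TYPE.SHUFFLE_LABEL
```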
@@ -15,6 +15,6 @@ class DSim(ConstGroup):
      """Simulation Constants"""
 
      # Size of the statemap, this is from the GameBoard class
-     STATE_SIZE: int = 30
+     STATE_SIZE: int = 27
      # The number of "choices" the snake has: go forward, left or right.
      OUTPUT_SIZE: int = 3
@@ -82,6 +82,123 @@ class GameBoard(ScrollView):
          return out_list
 
      def get_state(self):
+         head = self.snake_head
+         direction = self.direction
+
+         # Adjacent points
+         point_l = Offset(head.x - 1, head.y)
+         point_r = Offset(head.x + 1, head.y)
+         point_u = Offset(head.x, head.y - 1)
+         point_d = Offset(head.x, head.y + 1)
+
+         # Direction flags
+         dir_l = direction == Direction.LEFT
+         dir_r = direction == Direction.RIGHT
+         dir_u = direction == Direction.UP
+         dir_d = direction == Direction.DOWN
+
+         # Length encoded in 7-bit binary
+         slb = self.get_binary(7, len(self.snake_body))
+
+         # Normalized distances to walls (0 = touching the wall)
+         width = height = self.board_size()
+         dist_left = head.x / width
+         dist_right = (width - head.x - 1) / width
+         dist_up = head.y / height
+         dist_down = (height - head.y - 1) / height
+
+         # Relative food direction (normalized)
+         dx = self.food.x - head.x
+         dy = self.food.y - head.y
+         food_dx = dx / max(1, width)
+         food_dy = dy / max(1, height)
+
+         # Free space straight ahead
+         free_ahead = 0
+         probe = Offset(head.x, head.y)
+         while (
+             0 <= probe.x < width
+             and 0 <= probe.y < height
+             and not self.is_snake_collision(probe)
+         ):
+             free_ahead += 1
+             if dir_r:
+                 probe = Offset(probe.x + 1, probe.y)
+             elif dir_l:
+                 probe = Offset(probe.x - 1, probe.y)
+             elif dir_u:
+                 probe = Offset(probe.x, probe.y - 1)
+             elif dir_d:
+                 probe = Offset(probe.x, probe.y + 1)
+         free_ahead = free_ahead / max(width, height)  # normalize
+
+         # Local free cell count (0-4, scaled to 0-1)
+         adjacent_points = [point_l, point_r, point_u, point_d]
+         local_free = (
+             sum(
+                 1
+                 for p in adjacent_points
+                 if not self.is_wall_collision(p) and not self.is_snake_collision(p)
+             )
+             / 4.0
+         )
+
+         # Optional context (if tracked elsewhere)
+         recent_growth = getattr(self, "recent_growth", 0.0)
+         time_since_food = getattr(self, "steps_since_food", 0.0) / 100.0  # normalize
+
+         # --- STATE VECTOR (27 features) ---
+         state = [
+             # 1-3. Snake collision directions
+             (dir_r and self.is_snake_collision(point_r))
+             or (dir_l and self.is_snake_collision(point_l))
+             or (dir_u and self.is_snake_collision(point_u))
+             or (dir_d and self.is_snake_collision(point_d)),
+             (dir_u and self.is_snake_collision(point_r))
+             or (dir_d and self.is_snake_collision(point_l))
+             or (dir_l and self.is_snake_collision(point_u))
+             or (dir_r and self.is_snake_collision(point_d)),
+             (dir_d and self.is_snake_collision(point_r))
+             or (dir_u and self.is_snake_collision(point_l))
+             or (dir_r and self.is_snake_collision(point_u))
+             or (dir_l and self.is_snake_collision(point_d)),
+             # 4-6. Wall collision directions
+             (dir_r and self.is_wall_collision(point_r))
+             or (dir_l and self.is_wall_collision(point_l))
+             or (dir_u and self.is_wall_collision(point_u))
+             or (dir_d and self.is_wall_collision(point_d)),
+             (dir_u and self.is_wall_collision(point_r))
+             or (dir_d and self.is_wall_collision(point_l))
+             or (dir_l and self.is_wall_collision(point_u))
+             or (dir_r and self.is_wall_collision(point_d)),
+             (dir_d and self.is_wall_collision(point_r))
+             or (dir_u and self.is_wall_collision(point_l))
+             or (dir_r and self.is_wall_collision(point_u))
+             or (dir_l and self.is_wall_collision(point_d)),
+             # 7-10. Direction flags
+             dir_l,
+             dir_r,
+             dir_u,
+             dir_d,
+             # 11-12. Food relative direction
+             food_dx,
+             food_dy,
+             # 13-19. Snake length bits
+             *slb,
+             # 20-23. Distances to walls
+             dist_left,
+             dist_right,
+             dist_up,
+             dist_down,
+             # 24-27. Free space and context
+             free_ahead,
+             local_free,
+             recent_growth,
+             time_since_food,
+         ]
+
+         return [float(x) for x in state]
+
+     def get_state2(self):
 
          head = self.snake_head
          direction = self.direction
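
The rewritten state map packs three snake-collision flags, three wall-collision flags, four direction flags, two food deltas, seven length bits, four wall distances, and four new scalars (free_ahead, local_free, recent_growth, time_since_food) into one flat float vector. A back-of-the-envelope tally confirming it matches the new DSim.STATE_SIZE of 27:

```python
# Feature budget for the 0.5.0 state vector (counts read off get_state() above).
features = {
    "snake_collision_flags": 3,  # danger ahead / right / left
    "wall_collision_flags": 3,
    "direction_flags": 4,        # dir_l, dir_r, dir_u, dir_d
    "food_deltas": 2,            # food_dx, food_dy
    "length_bits": 7,            # get_binary(7, len(snake_body))
    "wall_distances": 4,         # dist_left / right / up / down
    "free_ahead": 1,
    "local_free": 1,
    "recent_growth": 1,
    "time_since_food": 1,
}
assert sum(features.values()) == 27  # == DSim.STATE_SIZE
```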
@@ -155,9 +155,9 @@ class SnakeGame:
 
          ## 6. Set a negative reward if the snake head is adjacent to the snake body.
          # This is to discourage snake collisions.
-         for segment in self.snake[1:]:
-             if abs(self.head.x - segment.x) < 2 and abs(self.head.y - segment.y) < 2:
-                 reward -= -2
+         # for segment in self.snake[1:]:
+         #     if abs(self.head.x - segment.x) < 2 and abs(self.head.y - segment.y) < 2:
+         #         reward -= -2
 
          self.game_reward += reward
          self.game_board.update_snake(snake=self.snake, direction=self.direction)
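
Worth noting while reading the block being disabled: `reward -= -2` subtracts a negative, so it *added* 2 per nearby body segment — the opposite of the comment's stated intent. A two-line illustration of the double negative (plain Python, not package code):

```python
reward = 0
reward -= -2  # subtracting -2 adds 2
assert reward == 2
```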
@@ -14,9 +14,8 @@ import sys, os
  from datetime import datetime, timedelta
 
  from textual.app import App, ComposeResult
- from textual.widgets import Label, Input, Button, Static
+ from textual.widgets import Label, Input, Button, Static, Log, Select
  from textual.containers import Vertical, Horizontal
- from textual.reactive import var
  from textual.theme import Theme
 
  from ai_snake_lab.constants.DDef import DDef
@@ -32,10 +31,13 @@ from ai_snake_lab.constants.DDb4EPlot import Plot
  from ai_snake_lab.ai.AIAgent import AIAgent
  from ai_snake_lab.ai.EpsilonAlgo import EpsilonAlgo
+
  from ai_snake_lab.game.GameBoard import GameBoard
  from ai_snake_lab.game.SnakeGame import SnakeGame
+
  from ai_snake_lab.ui.Db4EPlot import Db4EPlot
 
+
  RANDOM_SEED = 1970
 
  snake_lab_theme = Theme(
@@ -67,17 +69,15 @@ class AISim(App):
 
      ## Runtime values
      # Current epsilon value (degrades in real-time)
-     cur_epsilon_widget = Label("N/A", id=DLayout.CUR_EPSILON)
-     # Current memory type
-     cur_mem_type_widget = Label("N/A", id=DLayout.CUR_MEM_TYPE)
+     cur_epsilon_widget = Label(DLabel.N_SLASH_A, id=DLayout.CUR_EPSILON)
      # Current model type
-     cur_model_type_widget = Label("N/A", id=DLayout.CUR_MODEL_TYPE)
+     cur_model_type_widget = Label(DLabel.N_SLASH_A, id=DLayout.CUR_MODEL_TYPE)
      # Time delay between moves
      cur_move_delay = DDef.MOVE_DELAY
      # Number of stored games in the ReplayMemory
-     cur_num_games_widget = Label("N/A", id=DLayout.NUM_GAMES)
+     cur_num_games_widget = Label(DLabel.N_SLASH_A, id=DLayout.NUM_GAMES)
      # Elapsed time
-     cur_runtime_widget = Label("N/A", id=DLayout.RUNTIME)
+     cur_runtime_widget = Label(DLabel.N_SLASH_A, id=DLayout.RUNTIME)
 
      # Initial Settings for Epsilon
      initial_epsilon_input = Input(
@@ -185,6 +185,13 @@
                  ),
                  self.move_delay_input,
              ),
+             Horizontal(
+                 Label(
+                     f"{DLabel.MEM_TYPE}",
+                     classes=DLayout.LABEL_SETTINGS_12,
+                 ),
+                 Select(MEM_TYPE.MEMORY_TYPES, compact=True, id=DLayout.MEM_TYPE),
+             ),
              id=DLayout.SETTINGS_BOX,
          )
@@ -202,7 +209,7 @@
              ),
              Horizontal(
                  Label(f"{DLabel.MEM_TYPE}", classes=DLayout.LABEL),
-                 self.cur_mem_type_widget,
+                 Label(DLabel.N_SLASH_A, id=DLayout.CUR_MEM_TYPE),
              ),
              Horizontal(
                  Label(f"{DLabel.STORED_GAMES}", classes=DLayout.LABEL),
@@ -236,7 +243,11 @@
          )
 
          # Empty fillers
-         yield Static(id=DLayout.FILLER_1)
+         yield Vertical(
+             Static(id=DLayout.HIGHSCORES_HEADER),
+             Log(highlight=False, auto_scroll=True, id=DLayout.HIGHSCORES),
+             id=DLayout.HIGHSCORES_BOX,
+         )
          yield Static(id=DLayout.FILLER_2)
          yield Static(id=DLayout.FILLER_3)
@@ -252,10 +263,14 @@
          settings_box.border_title = DLabel.SETTINGS
          runtime_box = self.query_one(f"#{DLayout.RUNTIME_BOX}", Vertical)
          runtime_box.border_title = DLabel.RUNTIME_VALUES
-         self.cur_mem_type_widget.update(
-             MEM_TYPE.MEM_TYPE_TABLE[self.agent.memory.mem_type()]
-         )
-         self.cur_num_games_widget.update(str(self.agent.memory.get_num_games()))
+         highscore_box = self.query_one(f"#{DLayout.HIGHSCORES_BOX}", Vertical)
+         highscore_box.border_title = DLabel.HIGHSCORES
+         cur_mem_type_widget = self.query_one(f"#{DLayout.CUR_MEM_TYPE}", Label)
+         cur_mem_type_widget.update(DLabel.N_SLASH_A)
+         highscores_header = self.query_one(f"#{DLayout.HIGHSCORES_HEADER}", Static)
+         highscores_header.update(f" [b #3e99af]{DLabel.GAME:6s}{DLabel.SCORE:6s}[/]")
+         memory_type_widget = self.query_one(f"#{DLayout.MEM_TYPE}")
+         memory_type_widget.value = MEM_TYPE.RANDOM_GAME
          # Initial state is that the app is stopped
          self.add_class(DField.STOPPED)
          # Register the theme
@@ -305,6 +320,8 @@
          game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
          game_box.border_title = ""
          game_box.border_subtitle = ""
+         highscores = self.query_one(f"#{DLayout.HIGHSCORES}", Log)
+         highscores.clear()
 
          # Recreate events and get a new thread
          self.stop_event = threading.Event()
@@ -324,13 +341,21 @@
              self.remove_class(DField.PAUSED)
              self.cur_move_delay = float(self.move_delay_input.value)
              self.cur_model_type_widget.update(self.agent.model_type())
+             memory_type_widget = self.query_one(f"#{DLayout.MEM_TYPE}")
+             self.agent.memory.mem_type(memory_type_widget.value)
+             cur_mem_type_widget = self.query_one(f"#{DLayout.CUR_MEM_TYPE}", Label)
+             cur_mem_type_widget.update(
+                 MEM_TYPE.MEM_TYPE_TABLE[memory_type_widget.value]
+             )
 
-         # Reset button was pressed
+         # Defaults button was pressed
          elif button_id == DLayout.BUTTON_DEFAULTS:
              self.initial_epsilon_input.value = str(DEpsilon.EPSILON_INITIAL)
              self.epsilon_decay_input.value = str(DEpsilon.EPSILON_DECAY)
              self.epsilon_min_input.value = str(DEpsilon.EPSILON_MIN)
              self.move_delay_input.value = str(DDef.MOVE_DELAY)
+             memory_type_widget = self.query_one(f"#{DLayout.MEM_TYPE}")
+             memory_type_widget.value = MEM_TYPE.RANDOM_GAME
 
          # Quit button was pressed
          elif button_id == DLayout.BUTTON_QUIT:
@@ -339,6 +364,8 @@
          # Update button was pressed
          elif button_id == DLayout.BUTTON_UPDATE:
              self.cur_move_delay = float(self.move_delay_input.value)
+             memory_type_widget = self.query_one(f"#{DLayout.MEM_TYPE}")
+             self.agent.memory.mem_type(memory_type_widget.value)
 
      def start_sim(self):
          self.snake_game.reset()
@@ -351,6 +378,8 @@
          game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
          game_box.border_title = f"{DLabel.GAME} #{self.epoch}"
          start_time = datetime.now()
+         self.cur_num_games_widget.update(str(self.agent.memory.get_num_games()))
+         highscores = self.query_one(f"#{DLayout.HIGHSCORES}", Log)
 
          while not self.stop_event.is_set():
              if self.pause_event.is_set():
@@ -363,6 +392,8 @@
                  reward, game_over, score = snake_game.play_step(move)
                  if score > highscore:
                      highscore = score
+                     # Update the UI
+                     highscores.write_line(f"{self.epoch:6d} {score:6d}")
                  game_box.border_subtitle = (
                      f"{DLabel.HIGHSCORE}: {highscore}, {DLabel.SCORE}: {score}"
                  )
@@ -378,7 +409,9 @@
                  game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
                  game_box.border_title = f"{DLabel.GAME} #{self.epoch}"
                  # Remember the last move
-                 agent.remember(old_state, move, reward, new_state, game_over)
+                 agent.remember(
+                     old_state, move, reward, new_state, game_over, score=score
+                 )
                  # Train long memory
                  agent.train_long_memory()
                  # Reset the game
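
The UI wiring above reduces to two new widgets: a Select fed directly from MEM_TYPE.MEMORY_TYPES whose value is pushed into `ReplayMemory.mem_type()`, and a Log that receives a `write_line()` whenever the highscore is beaten. A stripped-down, hedged Textual sketch of the same pattern (the widget IDs and inlined options are illustrative, not package constants):

```python
from textual.app import App, ComposeResult
from textual.widgets import Log, Select

# Mirrors MEM_TYPE.MEMORY_TYPES: (label, value) pairs.
MEMORY_TYPES = [("None", "none"), ("Random Game", "random_game"), ("Random Frames", "shuffle")]

class MiniSim(App):
    def compose(self) -> ComposeResult:
        yield Select(MEMORY_TYPES, compact=True, id="memory_type")
        yield Log(highlight=False, auto_scroll=True, id="highscores")

    def on_mount(self) -> None:
        # Default the memory type on mount, as AISim does.
        self.query_one("#memory_type", Select).value = "random_game"
        # Highscores are appended as right-aligned "game score" rows.
        self.query_one("#highscores", Log).write_line(f"{1:6d} {0:6d}")

if __name__ == "__main__":
    MiniSim().run()
```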
@@ -2,7 +2,7 @@ Screen {
      layout: grid;
      grid-size: 3 4;
      grid-rows: 3 7 6 11 10;
-     grid-columns: 32 46 30;
+     grid-columns: 32 46 32;
  }
 
  #title {
@@ -34,6 +34,7 @@ Screen {
  }
 
  #runtime_box {
+     height: 100%;
      border-title-color: #5fc442;
      border-title-style: bold;
      border: round #0c323e;
@@ -41,7 +42,14 @@ Screen {
      background: black;
  }
 
- #filler_1 {
+ #highscores_box {
+     row-span: 2;
+     border-title-color: #5fc442;
+     border-title-style: bold;
+     border: round #0c323e;
+     padding: 0 1;
+     background: black;
+
  }
 
  #filler_2 {
@@ -54,7 +62,7 @@ Screen {
      dock: bottom;
      border: round #0c323e;
      height: 15;
-     width: 108;
+     width: 110;
      background: black
  }
 
@@ -121,6 +129,11 @@ Button {
      width: 18;
  }
 
+ .label_settings_12 {
+     color: #5fc442;
+     width: 12;
+ }
+
  .paused #button_pause {
      display: none;
  }
@@ -1,6 +1,6 @@
  [project]
  name = "ai-snake-lab"
- version = "0.4.8"
+ version = "0.5.0"
  description = "Interactive reinforcement learning sandbox for experimenting with AI agents in a classic Snake Game environment."
  authors = [{ name = "Nadim-Daniel Ghaznavi", email = "nghaznavi@gmail.com" }]
  license = { text = "GPL-3.0" }
@@ -1,67 +0,0 @@
- # Introduction
-
- **AI Snake Lab** is an interactive reinforcement learning sandbox for experimenting with AI agents in a classic Snake Game environment — featuring a live Textual TUI interface, flexible replay memory database, and modular model definitions.
-
- ---
-
- # 🚀 Features
-
- - 🐍 **Classic Snake environment** with customizable grid and rules
- - 🧠 **AI agent interface** supporting multiple architectures (Linear, RNN, CNN)
- - 🎮 **Textual-based simulator** for live visualization and metrics
- - 💾 **SQLite-backed replay memory** for storing frames, episodes, and runs
- - 🧩 **Experiment metadata tracking** — models, hyperparameters, state-map versions
- - 📊 **Built-in plotting** for hashrate, scores, and learning progress
-
- ---
-
- # 🧰 Tech Stack
-
- | Component | Description |
- |------------|--------------|
- | **Python 3.11+** | Core language |
- | **Textual** | Terminal UI framework |
- | **SQLite3** | Lightweight replay memory + experiment store |
- | **PyTorch** *(optional)* | Deep learning backend for models |
- | **Plotext / Matplotlib** | Visualization tools |
-
- ---
-
- # Installation
-
- This project is on [PyPI](https://pypi.org/project/ai-snake-lab/). You can install the *AI Snake Lab* software using `pip`.
-
- ## Create a Sandbox
-
- ```shell
- python3 -m venv snake_venv
- . snake_venv/bin/activate
- ```
-
- ## Install the AI Snake Lab
-
- After you have activated your *venv* environment:
-
- ```shell
- pip install ai-snake-lab
- ```
-
- ---
-
- # Running the AI Snake Lab
-
- From within your *venv* environment:
-
- ```shell
- ai-snake-lab
- ```
-
- ---
-
- # Links and Acknowledgements
-
- This code is based on a YouTube tutorial, [Python + PyTorch + Pygame Reinforcement Learning – Train an AI to Play Snake](https://www.youtube.com/watch?v=L8ypSXwyBds&t=1042s&ab_channel=freeCodeCamp.org) by Patrick Loeber. You can access his original code [here](https://github.com/patrickloeber/snake-ai-pytorch) on GitHub. Thank you Patrick!!! You are amazing!!!!
-
- Thanks also go out to Will McGugan and the [Textual](https://textual.textualize.io/) team. Textual is an amazing framework. Talk about *rapid Application Development*. Porting this took less than a day.
-
- ---
@@ -1,148 +0,0 @@
- """
- ai/ReplayMemory.py
-
- AI Snake Game Simulator
- Author: Nadim-Daniel Ghaznavi
- Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
- GitHub: https://github.com/NadimGhaznavi/ai
- License: GPL 3.0
-
- This file contains the ReplayMemory class.
- """
-
- import os
- from collections import deque
- import random
- import sqlite3, pickle
- import tempfile
- import shutil
-
- from ai_snake_lab.constants.DReplayMemory import MEM_TYPE
- from ai_snake_lab.constants.DDef import DDef
-
-
- class ReplayMemory:
-
-     def __init__(self, seed: int):
-         random.seed(seed)
-         self.batch_size = 250
-         # Valid options: shuffle, random_game, targeted_score, random_targeted_score
-         self._mem_type = MEM_TYPE.RANDOM_GAME
-         self.min_games = 1
-         self.max_states = 15000
-         self.max_shuffle_games = 40
-         self.max_games = 500
-
-         if self._mem_type == MEM_TYPE.SHUFFLE:
-             # States are stored in a deque and a random sample will be returned
-             self.memories = deque(maxlen=self.max_states)
-
-         elif self._mem_type == MEM_TYPE.RANDOM_GAME:
-             # All of the states for a game are stored, in order, in a deque.
-             # A complete game will be returned
-             self.cur_memory = []
-
-         # Get a temporary directory for the DB file
-         self._tmpfile = tempfile.NamedTemporaryFile(suffix=DDef.DOT_DB, delete=False)
-         self.db_file = self._tmpfile.name
-
-         # Connect to SQLite
-         self.conn = sqlite3.connect(self.db_file, check_same_thread=False)
-
-         # Get a cursor
-         self.cursor = self.conn.cursor()
-
-         # We don't need the file handle anymore
-         self._tmpfile.close()
-
-         # Intialize the schema
-         self.init_db()
-
-     def __enter__(self):
-         return self
-
-     def __exit__(self, exc_type, exc_val, exc_tb):
-         self.close()
-
-     def __del__(self):
-         try:
-             self.close()
-         except Exception:
-             pass  # avoid errors on interpreter shutdown
-
-     def append(self, transition):
-         """Add a transition to the current game."""
-         if self._mem_type != MEM_TYPE.RANDOM_GAME:
-             raise NotImplementedError(
-                 "Only RANDOM_GAME memory type is implemented for SQLite backend"
-             )
-
-         self.cur_memory.append(transition)
-         _, _, _, _, done = transition
-
-         if done:
-             # Serialize the full game to JSON
-             serialized = pickle.dumps(self.cur_memory)
-             self.cursor.execute(
-                 "INSERT INTO games (transitions) VALUES (?)", (serialized,)
-             )
-             self.conn.commit()
-             self.cur_memory = []
-
-     def close(self):
-         """Close the database connection."""
-         if getattr(self, "conn", None):
-             self.conn.close()
-             self.conn = None
-         if getattr(self, "db_file", None) and os.path.exists(self.db_file):
-             os.remove(self.db_file)
-             self.db_file = None
-
-     def get_random_game(self):
-         """Return a random full game from the database."""
-         self.cursor.execute("SELECT id FROM games")
-         all_ids = [row[0] for row in self.cursor.fetchall()]
-         if len(all_ids) >= self.min_games:
-             rand_id = random.choice(all_ids)
-             self.cursor.execute("SELECT transitions FROM games WHERE id=?", (rand_id,))
-             row = self.cursor.fetchone()
-             if row:
-                 return pickle.loads(row[0])
-         return False
-
-     def get_random_states(self):
-         mem_size = len(self.memories)
-         if mem_size < self.batch_size:
-             return self.memories
-         return random.sample(self.memories, self.batch_size)
-
-     def get_memory(self):
-         if self._mem_type == MEM_TYPE.SHUFFLE:
-             return self.get_random_states()
-
-         elif self._mem_type == MEM_TYPE.RANDOM_GAME:
-             return self.get_random_game()
-
-     def get_num_games(self):
-         """Return number of games stored in the database."""
-         self.cursor.execute("SELECT COUNT(*) FROM games")
-         return self.cursor.fetchone()[0]
-
-     def init_db(self):
-         self.cursor.execute(
-             """
-             CREATE TABLE IF NOT EXISTS games (
-                 id INTEGER PRIMARY KEY AUTOINCREMENT,
-                 transitions TEXT NOT NULL
-             )
-             """
-         )
-         self.conn.commit()
-
-     def mem_type(self, mem_type=None):
-         if mem_type is not None:
-             self._mem_type = mem_type
-         return self._mem_type
-
-     def set_memory(self, memory):
-         self.memory = memory