ai-snake-lab 0.1.0__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_snake_lab/AISim.py +243 -86
- ai_snake_lab/ai/AIAgent.py +34 -31
- ai_snake_lab/ai/AITrainer.py +9 -5
- ai_snake_lab/ai/ReplayMemory.py +61 -24
- ai_snake_lab/ai/models/ModelL.py +7 -4
- ai_snake_lab/ai/models/ModelRNN.py +1 -1
- ai_snake_lab/constants/DDb4EPlot.py +20 -0
- ai_snake_lab/constants/DDef.py +1 -1
- ai_snake_lab/constants/DDir.py +4 -1
- ai_snake_lab/constants/DFields.py +6 -0
- ai_snake_lab/constants/DFile.py +2 -1
- ai_snake_lab/constants/DLabels.py +24 -4
- ai_snake_lab/constants/DLayout.py +16 -2
- ai_snake_lab/constants/DModelL.py +4 -0
- ai_snake_lab/constants/DSim.py +20 -0
- ai_snake_lab/game/GameBoard.py +36 -22
- ai_snake_lab/game/SnakeGame.py +17 -0
- ai_snake_lab/ui/Db4EPlot.py +160 -0
- ai_snake_lab/utils/AISim.tcss +81 -38
- {ai_snake_lab-0.1.0.dist-info → ai_snake_lab-0.4.3.dist-info}/METADATA +39 -5
- ai_snake_lab-0.4.3.dist-info/RECORD +31 -0
- {ai_snake_lab-0.1.0.dist-info → ai_snake_lab-0.4.3.dist-info}/WHEEL +1 -1
- {ai_snake_lab-0.1.0.dist-info → ai_snake_lab-0.4.3.dist-info/licenses}/LICENSE +2 -0
- ai_snake_lab-0.1.0.dist-info/RECORD +0 -28
- {ai_snake_lab-0.1.0.dist-info → ai_snake_lab-0.4.3.dist-info}/entry_points.txt +0 -0
ai_snake_lab/ai/ReplayMemory.py
CHANGED
@@ -10,10 +10,14 @@ ai/ReplayMemory.py
|
|
10
10
|
This file contains the ReplayMemory class.
|
11
11
|
"""
|
12
12
|
|
13
|
+
import os
|
13
14
|
from collections import deque
|
14
15
|
import random, sys
|
16
|
+
import sqlite3, pickle
|
15
17
|
|
16
18
|
from constants.DReplayMemory import MEM_TYPE
|
19
|
+
from constants.DFile import DFile
|
20
|
+
from constants.DDir import DDir
|
17
21
|
|
18
22
|
|
19
23
|
class ReplayMemory:
|
@@ -27,6 +31,11 @@ class ReplayMemory:
|
|
27
31
|
self.max_states = 15000
|
28
32
|
self.max_shuffle_games = 40
|
29
33
|
self.max_games = 500
|
34
|
+
self.db_file = os.path.join(DDir.AI_SNAKE_LAB, DDir.DB, DFile.REPLAY_DB)
|
35
|
+
|
36
|
+
# Delete the replay memory file, if it exists
|
37
|
+
if os.path.exists(self.db_file):
|
38
|
+
os.remove(self.db_file)
|
30
39
|
|
31
40
|
if self._mem_type == MEM_TYPE.SHUFFLE:
|
32
41
|
# States are stored in a deque and a random sample will be returned
|
@@ -35,35 +44,50 @@ class ReplayMemory:
|
|
35
44
|
elif self._mem_type == MEM_TYPE.RANDOM_GAME:
|
36
45
|
# All of the states for a game are stored, in order, in a deque.
|
37
46
|
# A complete game will be returned
|
38
|
-
self.memories = deque(maxlen=self.max_shuffle_games)
|
39
47
|
self.cur_memory = []
|
40
48
|
|
41
|
-
|
42
|
-
|
43
|
-
|
49
|
+
# Connect to SQLite
|
50
|
+
self.conn = sqlite3.connect(self.db_file, check_same_thread=False)
|
51
|
+
self.cursor = self.conn.cursor()
|
52
|
+
self.init_db()
|
44
53
|
|
45
|
-
def
|
46
|
-
|
54
|
+
def __len__(self):
    """Return how many items the replay memory currently holds.

    For MEM_TYPE.RANDOM_GAME, __init__ does not create ``self.memories``;
    completed games live in the SQLite ``games`` table, so they are counted
    there. Other memory types keep using the ``self.memories`` deque.
    """
    if self._mem_type == MEM_TYPE.RANDOM_GAME:
        return self.get_num_games()
    return len(self.memories)
|
47
56
|
|
48
|
-
|
49
|
-
|
50
|
-
|
57
|
+
def append(self, transition):
    """Add a transition to the game currently being played.

    Transitions are buffered in ``self.cur_memory``. When a transition's
    ``done`` flag (its fifth element) is truthy, the whole buffered game is
    pickled and inserted as one row of the ``games`` table. The insert is
    committed, and the buffer is then reset for the next game.

    Args:
        transition: a 5-tuple ``(state, action, reward, next_state, done)``.

    Raises:
        NotImplementedError: if the memory type is not MEM_TYPE.RANDOM_GAME,
            the only type the SQLite backend supports.
        ValueError: if ``transition`` does not have exactly five elements.
    """
    if self._mem_type != MEM_TYPE.RANDOM_GAME:
        raise NotImplementedError(
            "Only RANDOM_GAME memory type is implemented for SQLite backend"
        )

    # Unpack before buffering so a malformed transition is rejected without
    # leaving a corrupt entry in the current game.
    _, _, _, _, done = transition
    self.cur_memory.append(transition)

    if done:
        # Serialize the full game with pickle (binary, not JSON) and store it
        # as a single row.
        serialized = pickle.dumps(self.cur_memory)
        self.cursor.execute(
            "INSERT INTO games (transitions) VALUES (?)", (serialized,)
        )
        self.conn.commit()
        self.cur_memory = []
|
51
75
|
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
self.cur_memory.append(transition)
|
56
|
-
state, action, reward, next_state, done = transition
|
57
|
-
if done:
|
58
|
-
self.memories.append(self.cur_memory)
|
59
|
-
self.cur_memory = []
|
76
|
+
def close(self):
    """Release the SQLite connection that backs this replay memory."""
    connection = self.conn
    connection.close()
|
60
79
|
|
61
80
|
def get_random_game(self):
    """Return a random full game from the database.

    Returns:
        The unpickled list of transitions of one stored game, chosen uniformly
        at random, or ``False`` when fewer than ``self.min_games`` games (or
        no games at all) have been stored.
    """
    (num_games,) = self.cursor.execute("SELECT COUNT(*) FROM games").fetchone()
    if num_games == 0 or num_games < self.min_games:
        return False

    # Pick a uniformly random row by position. SQLite walks the primary-key
    # index instead of every id being copied into a Python list on each call.
    offset = random.randrange(num_games)
    row = self.cursor.execute(
        "SELECT transitions FROM games ORDER BY id LIMIT 1 OFFSET ?", (offset,)
    ).fetchone()
    if row is None:
        return False
    # pickle.loads is acceptable only because these rows are written by this
    # class into its own local database file; never load untrusted data here.
    return pickle.loads(row[0])
|
67
91
|
|
68
92
|
def get_random_states(self):
|
69
93
|
mem_size = len(self.memories)
|
@@ -78,8 +102,21 @@ class ReplayMemory:
|
|
78
102
|
elif self._mem_type == MEM_TYPE.RANDOM_GAME:
|
79
103
|
return self.get_random_game()
|
80
104
|
|
81
|
-
def
|
82
|
-
|
105
|
+
def get_num_games(self):
    """Count the completed games currently persisted in the ``games`` table."""
    (count,) = self.cursor.execute("SELECT COUNT(*) FROM games").fetchone()
    return count
|
109
|
+
|
110
|
+
def init_db(self):
    """Create the ``games`` table if it does not already exist.

    Each row holds one complete game: an auto-incrementing ``id`` and the
    pickled list of that game's transitions. The column is declared BLOB
    because append() stores ``pickle.dumps(...)`` bytes, not text.
    """
    self.cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS games (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            transitions BLOB NOT NULL
        )
        """
    )
    self.conn.commit()
|
83
120
|
|
84
121
|
def mem_type(self, mem_type=None):
|
85
122
|
if mem_type is not None:
|
ai_snake_lab/ai/models/ModelL.py
CHANGED
@@ -12,15 +12,18 @@ import torch
|
|
12
12
|
import torch.nn as nn
|
13
13
|
import torch.nn.functional as F
|
14
14
|
|
15
|
+
from ai_snake_lab.constants.DSim import DSim
|
16
|
+
from ai_snake_lab.constants.DModelL import DModelL
|
17
|
+
|
15
18
|
|
16
19
|
class ModelL(nn.Module):
|
17
20
|
def __init__(self, seed: int):
|
18
21
|
super(ModelL, self).__init__()
|
19
22
|
torch.manual_seed(seed)
|
20
|
-
input_size =
|
21
|
-
hidden_size =
|
22
|
-
output_size =
|
23
|
-
p_value =
|
23
|
+
input_size = DSim.STATE_SIZE # Size of the "state" as tracked by the GameBoard
|
24
|
+
hidden_size = DModelL.HIDDEN_SIZE
|
25
|
+
output_size = DSim.OUTPUT_SIZE
|
26
|
+
p_value = DModelL.P_VALUE
|
24
27
|
self.input_block = nn.Sequential(
|
25
28
|
nn.Linear(input_size, hidden_size),
|
26
29
|
nn.ReLU(),
|
@@ -0,0 +1,20 @@
|
|
1
|
+
"""
|
2
|
+
constants/DDb4EPlot.py
|
3
|
+
|
4
|
+
AI Snake Game Simulator
|
5
|
+
Author: Nadim-Daniel Ghaznavi
|
6
|
+
Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
|
7
|
+
GitHub: https://github.com/NadimGhaznavi/ai
|
8
|
+
License: GPL 3.0
|
9
|
+
"""
|
10
|
+
|
11
|
+
from utils.ConstGroup import ConstGroup
|
12
|
+
|
13
|
+
|
14
|
+
class Plot(ConstGroup):
    """Constants for the Db4EPlot widget."""

    # Data-thinning methods accepted by Db4EPlot(thin_method=...):
    #   AVERAGE - average the series into at most MAX_DATA_POINTS bins when drawing
    #   SLIDING - keep only the newest MAX_DATA_POINTS points
    AVERAGE: str = "average"
    SLIDING: str = "sliding"
    # Upper bound on the number of points a plot keeps or draws.
    MAX_DATA_POINTS: int = 200
|
ai_snake_lab/constants/DDef.py
CHANGED
ai_snake_lab/constants/DDir.py
CHANGED
@@ -14,5 +14,11 @@ from utils.ConstGroup import ConstGroup
|
|
14
14
|
class DField(ConstGroup):
    """Field names: simulation-loop states and stats-dictionary keys."""

    # States the simulation loop can be in.
    PAUSED: str = "paused"
    RUNNING: str = "running"
    STOPPED: str = "stopped"

    # Keys used in the stats dictionary.
    GAME_SCORE: str = "game_score"
    GAME_NUM: str = "game_num"
|
ai_snake_lab/constants/DFile.py
CHANGED
@@ -8,27 +8,47 @@ constants/DLabels.py
|
|
8
8
|
License: GPL 3.0
|
9
9
|
"""
|
10
10
|
|
11
|
-
from
|
11
|
+
from ai_snake_lab.ai.models.ModelL import ModelL
|
12
|
+
from ai_snake_lab.ai.models.ModelRNN import ModelRNN
|
13
|
+
|
14
|
+
from ai_snake_lab.utils.ConstGroup import ConstGroup
|
12
15
|
|
13
16
|
|
14
17
|
class DLabel(ConstGroup):
    """User-visible text labels."""

    AVERAGE: str = "Average"
    CURRENT: str = "Current"
    CURRENT_EPSILON: str = "Current Epsilon"
    DEFAULTS: str = "Defaults"
    EPSILON: str = "Epsilon"
    EPSILON_DECAY: str = "Epsilon Decay"
    EPSILON_INITIAL: str = "Initial Epsilon"
    EPSILON_MIN: str = "Minimum Epsilon"
    GAME: str = "Game"
    GAMES: str = "Games"
    GAME_SCORE: str = "Game Score"
    GAME_NUM: str = "Game Number"
    HIGHSCORE: str = "Highscore"
    MEM_TYPE: str = "Memory Type"
    MIN_EPSILON: str = "Minimum Epsilon"
    MODEL_LINEAR: str = "Linear"
    MODEL_RNN: str = "RNN"
    MODEL_TYPE: str = "Model Type"
    MOVE_DELAY: str = "Move Delay"
    PAUSE: str = "Pause"
    QUIT: str = "Quit"
    RESTART: str = "Restart"  # defined once (a second identical definition was removed)
    RUNTIME: str = "Runtime"
    RUNTIME_VALUES: str = "Runtime Values"
    SCORE: str = "Score"
    SETTINGS: str = "Configuration Settings"
    START: str = "Start"
    STORED_GAMES: str = "Stored Games"
    UPDATE: str = "Update"

    # Maps a model to its display label.
    # NOTE(review): the keys are inconsistent. ModelL is keyed by str(ModelL),
    # while ModelRNN is keyed by the class object itself, so a lookup that
    # works for one model fails for the other. Align the keys with how the
    # callers look them up (not visible here).
    MODEL_TYPE_TABLE: dict = {
        str(ModelL): MODEL_LINEAR,
        ModelRNN: MODEL_RNN,
    }
|
@@ -14,16 +14,29 @@ from utils.ConstGroup import ConstGroup
|
|
14
14
|
class DLayout(ConstGroup):
|
15
15
|
"""Layout"""
|
16
16
|
|
17
|
+
BUTTON_BOX: str = "button_box"
|
17
18
|
BUTTON_PAUSE: str = "button_pause"
|
18
19
|
BUTTON_QUIT: str = "button_quit"
|
20
|
+
BUTTON_RESTART: str = "button_restart"
|
19
21
|
BUTTON_ROW: str = "button_row"
|
20
22
|
BUTTON_START: str = "button_start"
|
21
|
-
|
23
|
+
BUTTON_DEFAULTS: str = "button_defaults"
|
22
24
|
BUTTON_UPDATE: str = "button_update"
|
23
25
|
CUR_EPSILON: str = "cur_epsilon"
|
24
26
|
CUR_MEM_TYPE: str = "cur_mem_type"
|
27
|
+
CUR_MODEL_TYPE: str = "cur_model_type"
|
28
|
+
FILLER_1: str = "filler_1"
|
29
|
+
FILLER_2: str = "filler_2"
|
30
|
+
FILLER_3: str = "filler_3"
|
31
|
+
FILLER_4: str = "filler_4"
|
32
|
+
FILLER_5: str = "filler_5"
|
33
|
+
FILLER_6: str = "filler_6"
|
34
|
+
FILLER_7: str = "filler_7"
|
35
|
+
FILLER_8: str = "filler_8"
|
25
36
|
GAME_BOARD: str = "game_board"
|
26
37
|
GAME_BOX: str = "game_box"
|
38
|
+
GAME_SCORE: str = "game_score"
|
39
|
+
GAME_SCORE_PLOT: str = "game_score_plot"
|
27
40
|
EPSILON_DECAY: str = "epsilon_decay"
|
28
41
|
EPSILON_INITIAL: str = "initial_epsilon"
|
29
42
|
EPSILON_MIN: str = "epsilon_min"
|
@@ -31,8 +44,9 @@ class DLayout(ConstGroup):
|
|
31
44
|
LABEL: str = "label"
|
32
45
|
LABEL_SETTINGS: str = "label_settings"
|
33
46
|
MOVE_DELAY: str = "move_delay"
|
34
|
-
|
47
|
+
NUM_GAMES: str = "num_games"
|
35
48
|
RUNTIME_BOX: str = "runtime_box"
|
49
|
+
RUNTIME: str = "runtime"
|
36
50
|
SCORE: str = "score"
|
37
51
|
SETTINGS_BOX: str = "settings_box"
|
38
52
|
TITLE: str = "title"
|
@@ -0,0 +1,20 @@
|
|
1
|
+
"""
|
2
|
+
constants/DGameBoard.py
|
3
|
+
|
4
|
+
AI Snake Game Simulator
|
5
|
+
Author: Nadim-Daniel Ghaznavi
|
6
|
+
Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
|
7
|
+
GitHub: https://github.com/NadimGhaznavi/ai
|
8
|
+
License: GPL 3.0
|
9
|
+
"""
|
10
|
+
|
11
|
+
from utils.ConstGroup import ConstGroup
|
12
|
+
|
13
|
+
|
14
|
+
class DSim(ConstGroup):
    """Constants describing the simulator's model input and output dimensions."""

    # Number of features in the state vector built by the GameBoard (get_state()).
    STATE_SIZE: int = 30
    # Number of actions available to the snake: go straight, turn left, or turn right.
    OUTPUT_SIZE: int = 3
|
ai_snake_lab/game/GameBoard.py
CHANGED
@@ -68,6 +68,19 @@ class GameBoard(ScrollView):
|
|
68
68
|
def board_size(self) -> int:
|
69
69
|
return self._board_size
|
70
70
|
|
71
|
+
def get_binary(self, bits_needed, some_int):
    """Return ``some_int`` as a fixed-width, MSB-first list of 0/1 ints.

    Used by get_state() to encode the snake length as binary state features.

    Args:
        bits_needed: Width of the returned list.
        some_int: Non-negative integer, or any value accepted by ``int()``.

    Returns:
        A list of exactly ``bits_needed`` ints. Values too large to fit
        saturate at all ones, so callers that index a fixed number of bits
        never read a wrong or missing bit.

    Raises:
        ValueError: if ``some_int`` is negative or not convertible to int.
    """
    value = int(some_int)
    if value < 0:
        raise ValueError(f"some_int must be non-negative, got {value}")
    if bits_needed <= 0:
        return []
    # Clamp to the largest value representable in bits_needed bits.
    value = min(value, (1 << bits_needed) - 1)
    return [int(bit) for bit in format(value, f"0{bits_needed}b")]
|
83
|
+
|
71
84
|
def get_state(self):
|
72
85
|
|
73
86
|
head = self.snake_head
|
@@ -80,7 +93,7 @@ class GameBoard(ScrollView):
|
|
80
93
|
dir_r = direction == Direction.RIGHT
|
81
94
|
dir_u = direction == Direction.UP
|
82
95
|
dir_d = direction == Direction.DOWN
|
83
|
-
|
96
|
+
slb = self.get_binary(7, len(self.snake_body))
|
84
97
|
state = [
|
85
98
|
# 1. Snake collision straight ahead
|
86
99
|
(dir_r and self.is_snake_collision(point_r))
|
@@ -97,47 +110,48 @@ class GameBoard(ScrollView):
|
|
97
110
|
or (dir_u and self.is_snake_collision(point_l))
|
98
111
|
or (dir_r and self.is_snake_collision(point_u))
|
99
112
|
or (dir_l and self.is_snake_collision(point_d)),
|
100
|
-
# 4.
|
101
|
-
0,
|
102
|
-
# 5. Wall collision straight ahead
|
113
|
+
# 4. Wall collision straight ahead
|
103
114
|
(dir_r and self.is_wall_collision(point_r))
|
104
115
|
or (dir_l and self.is_wall_collision(point_l))
|
105
116
|
or (dir_u and self.is_wall_collision(point_u))
|
106
117
|
or (dir_d and self.is_wall_collision(point_d)),
|
107
|
-
#
|
118
|
+
# 5. Wall collision to the right
|
108
119
|
(dir_u and self.is_wall_collision(point_r))
|
109
120
|
or (dir_d and self.is_wall_collision(point_l))
|
110
121
|
or (dir_l and self.is_wall_collision(point_u))
|
111
122
|
or (dir_r and self.is_wall_collision(point_d)),
|
112
|
-
#
|
123
|
+
# 6. Wall collision to the left
|
113
124
|
(dir_d and self.is_wall_collision(point_r))
|
114
125
|
or (dir_u and self.is_wall_collision(point_l))
|
115
126
|
or (dir_r and self.is_wall_collision(point_u))
|
116
127
|
or (dir_l and self.is_wall_collision(point_d)),
|
117
|
-
#
|
118
|
-
0,
|
119
|
-
# 9, 10, 11, 12. Last move direction
|
128
|
+
# 7 - 10. Last move direction
|
120
129
|
dir_l,
|
121
130
|
dir_r,
|
122
131
|
dir_u,
|
123
132
|
dir_d,
|
124
|
-
#
|
125
|
-
|
126
|
-
|
127
|
-
self.food.
|
128
|
-
self.food.
|
129
|
-
self.food.
|
130
|
-
self.food.y > self.snake_head.y, # Food down
|
131
|
-
self.food.x == self.snake_head.x,
|
133
|
+
# 11 - 19. Food location
|
134
|
+
self.food.x < self.snake_head.x, # 11. Food left
|
135
|
+
self.food.x > self.snake_head.x, # 12. Food right
|
136
|
+
self.food.y < self.snake_head.y, # 13. Food up
|
137
|
+
self.food.y > self.snake_head.y, # 14. Food down
|
138
|
+
self.food.x == self.snake_head.x, # 15.
|
132
139
|
self.food.x == self.snake_head.x
|
133
|
-
and self.food.y > self.snake_head.y, # Food ahead
|
140
|
+
and self.food.y > self.snake_head.y, # 16. Food ahead
|
134
141
|
self.food.x == self.snake_head.x
|
135
|
-
and self.food.y < self.snake_head.y, # Food behind
|
136
|
-
self.food.y == self.snake_head.y,
|
142
|
+
and self.food.y < self.snake_head.y, # 17. Food behind
|
137
143
|
self.food.y == self.snake_head.y
|
138
|
-
and self.food.x > self.snake_head.x, # Food above
|
144
|
+
and self.food.x > self.snake_head.x, # 18. Food above
|
139
145
|
self.food.y == self.snake_head.y
|
140
|
-
and self.food.x < self.snake_head.x, # Food below
|
146
|
+
and self.food.x < self.snake_head.x, # 19. Food below
|
147
|
+
# 20 - 26. Snake length in binary
|
148
|
+
slb[0],
|
149
|
+
slb[1],
|
150
|
+
slb[2],
|
151
|
+
slb[3],
|
152
|
+
slb[4],
|
153
|
+
slb[5],
|
154
|
+
slb[6],
|
141
155
|
]
|
142
156
|
|
143
157
|
# 24, 25, 26 and 27. Previous direction of the snake
|
ai_snake_lab/game/SnakeGame.py
CHANGED
@@ -54,6 +54,9 @@ class SnakeGame:
|
|
54
54
|
self.game_board.update_snake(snake=self.snake, direction=self.direction)
|
55
55
|
self.game_board.update_food(food=self.food)
|
56
56
|
|
57
|
+
# Track the distance from the snake head to the food to feed the reward system
|
58
|
+
self.distance_to_food = self.game_board.board_size() // 2
|
59
|
+
|
57
60
|
# The current game score
|
58
61
|
self.game_score = 0
|
59
62
|
|
@@ -126,6 +129,12 @@ class SnakeGame:
|
|
126
129
|
game_over = True
|
127
130
|
reward = -10
|
128
131
|
|
132
|
+
# Set a negative reward if the snake head is adjacent to the snake body.
|
133
|
+
# This is to discourage snake collisions.
|
134
|
+
for segment in self.snake[1:]:
|
135
|
+
if abs(self.head.x - segment.x) < 2 and abs(self.head.y - segment.y) < 2:
|
136
|
+
reward -= -1
|
137
|
+
|
129
138
|
if game_over == True:
|
130
139
|
# Game is over: Snake or wall collision or exceeded max moves
|
131
140
|
self.game_reward += reward
|
@@ -142,6 +151,14 @@ class SnakeGame:
|
|
142
151
|
else:
|
143
152
|
self.snake.pop()
|
144
153
|
|
154
|
+
## 5. See if we're closer to the food than the last move, or further away
|
155
|
+
cur_distance = abs(self.head.x - self.food.x) + abs(self.head.y - self.food.y)
|
156
|
+
if cur_distance < self.distance_to_food:
|
157
|
+
reward += 2
|
158
|
+
elif cur_distance > self.distance_to_food:
|
159
|
+
reward -= 2
|
160
|
+
self.distance_to_food = cur_distance
|
161
|
+
|
145
162
|
self.game_reward += reward
|
146
163
|
self.game_board.update_snake(snake=self.snake, direction=self.direction)
|
147
164
|
self.game_board.update_food(food=self.food)
|
@@ -0,0 +1,160 @@
|
|
1
|
+
"""
|
2
|
+
db4e/Modules/Db4EPlot.py
|
3
|
+
|
4
|
+
Database 4 Everything
|
5
|
+
Author: Nadim-Daniel Ghaznavi
|
6
|
+
Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
|
7
|
+
GitHub: https://github.com/NadimGhaznavi/db4e
|
8
|
+
License: GPL 3.0
|
9
|
+
"""
|
10
|
+
|
11
|
+
import math
|
12
|
+
from collections import deque
|
13
|
+
from textual_plot import PlotWidget, HiResMode, LegendLocation
|
14
|
+
from textual.app import ComposeResult
|
15
|
+
|
16
|
+
from ai_snake_lab.constants.DDb4EPlot import Plot
|
17
|
+
from ai_snake_lab.constants.DLabels import DLabel
|
18
|
+
|
19
|
+
|
20
|
+
MAX_DATA_POINTS = Plot.MAX_DATA_POINTS


class Db4EPlot(PlotWidget):
    """
    A widget for plotting data based on TextualPlot's PlotWidget.

    The widget keeps its own x ("days") and y ("values") series. db4e_plot()
    redraws them as the series in green plus a moving average in red.
    ``thin_method`` controls how the series is kept readable:

    * ``Plot.SLIDING``: only the newest MAX_DATA_POINTS points are kept.
    * ``Plot.AVERAGE``: all points are kept and averaged into at most
      MAX_DATA_POINTS bins at draw time (see reduce_data()).
    * anything else: the series is plotted as-is.
    """

    def __init__(self, title, id, thin_method=None):
        super().__init__(title, id, allow_pan_and_zoom=False)
        self._plot_id = id
        self._title = title
        self._thin_method = thin_method
        if thin_method == Plot.SLIDING:
            # Bounded buffers: appending past MAX_DATA_POINTS drops the oldest point.
            self._all_days = deque(maxlen=MAX_DATA_POINTS)
            self._all_values = deque(maxlen=MAX_DATA_POINTS)
        else:
            # Filled by load_data(), or lazily by add_data().
            self._all_days = None
            self._all_values = None

    def compose(self) -> ComposeResult:
        # NOTE(review): this yields a second, default-constructed PlotWidget
        # as a child of this one, and it does not call super().compose().
        # Left unchanged; confirm against textual_plot that this is intended.
        yield PlotWidget()

    def load_data(self, days, values, units):
        """Replace the stored series and set the y-axis label.

        Args:
            days: x values.
            values: y values, same length as ``days``.
            units: optional unit string; if truthy it is shown as "title (units)".
        """
        if self._thin_method == Plot.SLIDING:
            # Preserve the SLIDING bound: never hold more than MAX_DATA_POINTS points.
            self._all_days = deque(days, maxlen=MAX_DATA_POINTS)
            self._all_values = deque(values, maxlen=MAX_DATA_POINTS)
        else:
            self._all_days = days
            self._all_values = values
        if units:
            self.set_ylabel(f"{self._title} ({units})")
        else:
            self.set_ylabel(self._title)

    def add_data(self, day, value):
        """Append one (x, y) point to the stored series."""
        if self._all_days is None or self._all_values is None:
            # No series yet (non-SLIDING plot without a load_data() call).
            self._all_days = []
            self._all_values = []
        self._all_days.append(day)
        self._all_values.append(value)

    def set_xlabel(self, label):
        """Set the x-axis label (delegates to PlotWidget)."""
        return super().set_xlabel(label)

    def db4e_plot(self, days=None, values=None) -> None:
        """Clear the widget and draw a series plus its moving average.

        Args:
            days, values: explicit series to draw. When either is None, the
                stored series is drawn instead.
        """
        if days is not None and values is not None:
            plot_days, plot_values = days, values
        else:
            plot_days, plot_values = self._all_days, self._all_values
        self.clear()
        if plot_days is None or plot_values is None or len(plot_days) == 0:
            return

        if self._thin_method == Plot.AVERAGE:
            reduced_days, reduced_values = self.reduce_data(plot_days, plot_values)
        else:
            # Use the series selected above, so an explicit (days, values)
            # window such as the one from update_time_range() is honoured.
            reduced_days, reduced_values = list(plot_days), list(plot_values)

        self.plot(
            x=reduced_days,
            y=reduced_values,
            hires_mode=HiResMode.BRAILLE,
            line_style="green",
            label=DLabel.CURRENT,
        )

        # Overlay a moving average (window of about 5% of the points) to wash
        # out spikes and show when the AI's scores level off.
        window = max(1, len(reduced_values) // 20)
        if len(reduced_values) > window:
            smoothed = [
                sum(reduced_values[i : i + window]) / window
                for i in range(len(reduced_values) - window + 1)
            ]
            smoothed_days = reduced_days[window - 1 :]
            self.plot(
                x=smoothed_days,
                y=smoothed,
                hires_mode=HiResMode.BRAILLE,
                line_style="red",  # distinct color for trend
                label=DLabel.AVERAGE,
            )
            self.show_legend(location=LegendLocation.TOPLEFT)

    def reduce_data2(self, times, values):
        """Older step-based variant of reduce_data(); db4e_plot() does not call it.

        Keeps every ``step``-th time and averages ``values`` in bins of ``step``.
        """
        # Reduce the total number of data points, otherwise the plot gets "blurry"
        step = max(1, len(times) // MAX_DATA_POINTS)

        # Reduce times with step
        reduced_times = times[::step]

        # Bin values by step (average)
        reduced_values = [
            sum(values[i : i + step]) / len(values[i : i + step])
            for i in range(0, len(values), step)
        ]
        results = reduced_times[: len(reduced_values)], reduced_values
        return results

    def reduce_data(self, times, values):
        """Reduce times and values into <= MAX_DATA_POINTS bins.

        Each bin's value is the average of the values in the bin.
        Each bin's time is the last time in the bin, so the last bin's time is
        times[-1].

        Accepts any sequences, including deques and other non-sliceable
        iterables, and returns lists.

        Raises:
            ValueError: if ``times`` and ``values`` differ in length.
        """
        if times is None or values is None:
            return [], []
        times, values = list(times), list(values)
        if not times or not values:
            return [], []
        if len(times) != len(values):
            raise ValueError("times and values must be same length")

        step = max(1, math.ceil(len(times) / MAX_DATA_POINTS))

        reduced_times = []
        reduced_values = []
        for i in range(0, len(times), step):
            chunk_vals = values[i : i + step]
            # Average of the bin (works for floats or Decimal).
            reduced_values.append(sum(chunk_vals) / len(chunk_vals))
            # Representative time: the last item in the bin.
            reduced_times.append(times[min(i + step, len(times)) - 1])

        return reduced_times, reduced_values

    def update_time_range(self, selected_time):
        """Redraw showing only the newest ``selected_time`` points (-1 = no change)."""
        if selected_time == -1:
            return

        selected_time = int(selected_time)
        # Copy to lists: deques (SLIDING mode) do not support slicing.
        all_days = [] if self._all_days is None else list(self._all_days)
        all_values = [] if self._all_values is None else list(self._all_values)
        if selected_time > len(all_days):
            new_times, new_values = all_days, all_values
        else:
            new_times = all_days[-selected_time:]
            new_values = all_values[-selected_time:]
        self.db4e_plot(new_times, new_values)
|