TCAMpy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- TCAMpy.py +1534 -0
- tcampy-0.1.0.dist-info/METADATA +60 -0
- tcampy-0.1.0.dist-info/RECORD +6 -0
- tcampy-0.1.0.dist-info/WHEEL +5 -0
- tcampy-0.1.0.dist-info/licenses/LICENSE +674 -0
- tcampy-0.1.0.dist-info/top_level.txt +1 -0
TCAMpy.py
ADDED
|
@@ -0,0 +1,1534 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import time
|
|
3
|
+
import random
|
|
4
|
+
import hashlib
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import altair as alt
|
|
8
|
+
import streamlit as st
|
|
9
|
+
import matplotlib.pyplot as plt
|
|
10
|
+
import matplotlib.animation as animation
|
|
11
|
+
|
|
12
|
+
from sklearn.metrics import r2_score, mean_absolute_error
|
|
13
|
+
from sklearn.model_selection import train_test_split
|
|
14
|
+
from sklearn.ensemble import RandomForestRegressor
|
|
15
|
+
from streamlit_javascript import st_javascript
|
|
16
|
+
from scipy.ndimage import gaussian_filter
|
|
17
|
+
from scipy.stats import skew, kurtosis
|
|
18
|
+
from functools import wraps
|
|
19
|
+
from tqdm import tqdm
|
|
20
|
+
|
|
21
|
+
class TModel:
|
|
22
|
+
"""
|
|
23
|
+
Class for a cellular automata, modeling tumor growth.
|
|
24
|
+
|
|
25
|
+
Parameters:
|
|
26
|
+
cycles (int): duration of the model given in hours
|
|
27
|
+
side (int): the length of the side of field (10um)
|
|
28
|
+
pmax (int): maximum proliferation potential of RTC
|
|
29
|
+
PA (int): chance for apoptosis of RTC (in percent)
|
|
30
|
+
CCT (int): cell cycle time of cells given in hours
|
|
31
|
+
Dt (float): time step of the model given in days
|
|
32
|
+
PS (int): STC-STC division chance (in percent)
|
|
33
|
+
mu (int): migration capacity of cancer cells
|
|
34
|
+
I (int): strength of the immune cells (1-5)
|
|
35
|
+
M (int): tumor mutation chance (in percent)
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
def __init__(self, cycles, side, pmax, PA, CCT, Dt, PS, mu, I, M):
|
|
39
|
+
# Parameters
|
|
40
|
+
self.cycles = cycles
|
|
41
|
+
self.side = side
|
|
42
|
+
self.pmax = pmax
|
|
43
|
+
self.CCT = CCT
|
|
44
|
+
self.Dt = Dt
|
|
45
|
+
self.mu = mu
|
|
46
|
+
self.I = I
|
|
47
|
+
self.M = M
|
|
48
|
+
|
|
49
|
+
# Single model data
|
|
50
|
+
self.stc_number = []
|
|
51
|
+
self.rtc_number = []
|
|
52
|
+
self.wbc_number = []
|
|
53
|
+
self.immune = []
|
|
54
|
+
self.mutate = []
|
|
55
|
+
self.mutmap = []
|
|
56
|
+
self.field = []
|
|
57
|
+
self.images = []
|
|
58
|
+
|
|
59
|
+
# Multiple models data
|
|
60
|
+
self.stats = []
|
|
61
|
+
self.runs = []
|
|
62
|
+
|
|
63
|
+
# Chances
|
|
64
|
+
self.PP = CCT*Dt/24*100
|
|
65
|
+
self.PM = 100*mu/24
|
|
66
|
+
self.PA = PA
|
|
67
|
+
self.PS = PS
|
|
68
|
+
|
|
69
|
+
# Immune Data
|
|
70
|
+
self.it_ratio = []
|
|
71
|
+
self.kill_day = []
|
|
72
|
+
|
|
73
|
+
# ---------------------------------------------------------------------
|
|
74
|
+
def init_state(self):
|
|
75
|
+
"""
|
|
76
|
+
Creates the initial state with one STC in the middle.
|
|
77
|
+
Creates the field for immune cells and mutations too.
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
self.field = np.zeros((self.side, self.side))
|
|
81
|
+
self.immune = np.zeros((self.side, self.side))
|
|
82
|
+
self.mutate = np.zeros((self.side, self.side))
|
|
83
|
+
self.mutmap = np.zeros((self.side, self.side))
|
|
84
|
+
|
|
85
|
+
self.mod_cell(self.side//2, self.side//2, self.pmax+1)
|
|
86
|
+
|
|
87
|
+
# ---------------------------------------------------------------------
|
|
88
|
+
def find_tumor_cells(self):
|
|
89
|
+
"""
|
|
90
|
+
Saves the coordinates of tumor cells to self.tumor_cells.
|
|
91
|
+
"""
|
|
92
|
+
|
|
93
|
+
# Where are tumor cells?
|
|
94
|
+
coords = np.nonzero(self.field)
|
|
95
|
+
coords = np.transpose(coords)
|
|
96
|
+
|
|
97
|
+
# Shuffle to randomize direction
|
|
98
|
+
np.random.shuffle(coords)
|
|
99
|
+
self.tumor_cells = coords
|
|
100
|
+
|
|
101
|
+
# ---------------------------------------------------------------------
|
|
102
|
+
def count_tumor_cells(self):
|
|
103
|
+
"""
|
|
104
|
+
Saves the number of STCs/RTCs to self.stc_number/self.rtc_number.
|
|
105
|
+
"""
|
|
106
|
+
|
|
107
|
+
# Count RTC and STC
|
|
108
|
+
stc_count = np.count_nonzero(self.field == self.pmax + 1)
|
|
109
|
+
rtc_count = len(self.tumor_cells) - stc_count
|
|
110
|
+
|
|
111
|
+
# Save the current number
|
|
112
|
+
self.stc_number.append(stc_count)
|
|
113
|
+
self.rtc_number.append(rtc_count)
|
|
114
|
+
|
|
115
|
+
# ---------------------------------------------------------------------
|
|
116
|
+
def get_neighbours(self, x, y, neighbour_type):
|
|
117
|
+
"""
|
|
118
|
+
Returns the neighboring coordinates of a given cell in a 2D NumPy matrix.
|
|
119
|
+
|
|
120
|
+
Parameters:
|
|
121
|
+
x, y (int): representing the coordinates of the cell
|
|
122
|
+
neighbour_type (int): type of neighboring cells (1-5)
|
|
123
|
+
|
|
124
|
+
Returns:
|
|
125
|
+
list: a list with the coords of the neighbouring cells
|
|
126
|
+
"""
|
|
127
|
+
|
|
128
|
+
directions = [
|
|
129
|
+
(-1, 0), (1, 0),
|
|
130
|
+
(0, -1), (0, 1),
|
|
131
|
+
(-1,-1), (-1,1),
|
|
132
|
+
(1, -1), (1, 1)]
|
|
133
|
+
|
|
134
|
+
coords = []
|
|
135
|
+
for dx, dy in directions:
|
|
136
|
+
nx, ny = x + dx, y + dy
|
|
137
|
+
if 0 < nx < self.side-1 and 0 < ny < self.side-1:
|
|
138
|
+
coords.append([nx, ny])
|
|
139
|
+
|
|
140
|
+
neighbours = []
|
|
141
|
+
for n in coords:
|
|
142
|
+
match neighbour_type:
|
|
143
|
+
case 1:
|
|
144
|
+
# Return list of empty cells
|
|
145
|
+
if (self.field[n[0],n[1]] == 0 and self.immune[n[0],n[1]] == 0):
|
|
146
|
+
neighbours.append(n)
|
|
147
|
+
case 2:
|
|
148
|
+
# Return list of tumor cells
|
|
149
|
+
if self.field[n[0],n[1]] != 0:
|
|
150
|
+
neighbours.append(n)
|
|
151
|
+
case 3:
|
|
152
|
+
# Return list of immune cells
|
|
153
|
+
if self.immune[n[0],n[1]] != 0:
|
|
154
|
+
neighbours.append(n)
|
|
155
|
+
case 4:
|
|
156
|
+
# Return list of any cells
|
|
157
|
+
if (self.field[n[0],n[1]] != 0 or self.immune[n[0],n[1]] != 0):
|
|
158
|
+
neighbours.append(n)
|
|
159
|
+
case 5:
|
|
160
|
+
# # Return list of free/immune cells
|
|
161
|
+
if self.immune[n[0],n[1]] == 0:
|
|
162
|
+
neighbours.append(n)
|
|
163
|
+
|
|
164
|
+
return neighbours
|
|
165
|
+
|
|
166
|
+
# ---------------------------------------------------------------------
|
|
167
|
+
def cell_step(self, x, y, step_type):
|
|
168
|
+
"""
|
|
169
|
+
The function that makes a single cell do one of the following actions:
|
|
170
|
+
prolif STC - STC, prolif STC - RTC, prolif RTC - RTC, migration (1-4).
|
|
171
|
+
New mutations can appear every time a cell proliferates with M chance.
|
|
172
|
+
|
|
173
|
+
Parameters:
|
|
174
|
+
x, y (int): representing the coordinates of the cell
|
|
175
|
+
step_type (int): type of division or migration (1-4)
|
|
176
|
+
"""
|
|
177
|
+
|
|
178
|
+
# Choose random target position
|
|
179
|
+
free_nb = self.get_neighbours(x, y, 1)
|
|
180
|
+
nx, ny = free_nb[random.randint(1,len(free_nb)) - 1]
|
|
181
|
+
|
|
182
|
+
match step_type:
|
|
183
|
+
case 1:
|
|
184
|
+
# Proliferation STC -> STC + STC
|
|
185
|
+
self.field[nx, ny] = self.pmax+1
|
|
186
|
+
case 2:
|
|
187
|
+
# Proliferation STC -> STC + RTC
|
|
188
|
+
self.field[nx, ny] = self.pmax
|
|
189
|
+
case 3:
|
|
190
|
+
# Proliferation RTC -> RTC + RTC
|
|
191
|
+
self.field[x, y] -= 1
|
|
192
|
+
self.field[nx, ny] = self.field[x, y]
|
|
193
|
+
case 4:
|
|
194
|
+
# Migration
|
|
195
|
+
self.field[nx, ny] = self.field[x, y]
|
|
196
|
+
self.field[x, y] = 0
|
|
197
|
+
|
|
198
|
+
if step_type < 4 and self.field[x, y] == 0:
|
|
199
|
+
self.mutate[x, y] = 0
|
|
200
|
+
|
|
201
|
+
elif step_type < 4:
|
|
202
|
+
# Inherit mother's mutation
|
|
203
|
+
self.mutate[nx, ny] = self.mutate[x, y]
|
|
204
|
+
|
|
205
|
+
# Chance of a new mutation
|
|
206
|
+
if self.M >= random.randint(1, 100):
|
|
207
|
+
mut = random.choice([-1,1])
|
|
208
|
+
self.mutate[nx, ny] = np.clip(self.mutate[nx, ny]+mut, -3, 3)
|
|
209
|
+
|
|
210
|
+
# Mutation influences pp value
|
|
211
|
+
if step_type != 1:
|
|
212
|
+
self.field[nx, ny] = np.clip(self.field[nx, ny]+mut, 1, self.pmax)
|
|
213
|
+
|
|
214
|
+
self.mutmap[nx, ny] = self.mutate[nx, ny]
|
|
215
|
+
else:
|
|
216
|
+
self.mutate[nx, ny] = self.mutate[x, y]
|
|
217
|
+
self.mutmap[nx, ny] = self.mutate[x, y]
|
|
218
|
+
self.mutate[x, y] = 0
|
|
219
|
+
|
|
220
|
+
# ---------------------------------------------------------------------
|
|
221
|
+
def tumor_action(self):
|
|
222
|
+
"""
|
|
223
|
+
This is the function that decides what action a cell will do.
|
|
224
|
+
Either kills the cell or calls the 'cell_step' function.
|
|
225
|
+
This function goes through every single cell in the field.
|
|
226
|
+
"""
|
|
227
|
+
|
|
228
|
+
for cell in self.tumor_cells:
|
|
229
|
+
x, y = cell
|
|
230
|
+
is_stc = (self.field[x, y] == self.pmax + 1)
|
|
231
|
+
|
|
232
|
+
# Probabilities
|
|
233
|
+
probs = np.array([self.PA, self.PP, self.PM, 0], dtype=float)
|
|
234
|
+
if is_stc:
|
|
235
|
+
probs[0] = 0
|
|
236
|
+
if not self.get_neighbours(x, y, 1):
|
|
237
|
+
probs[1:3] = 0
|
|
238
|
+
probs = self.mutate_probs(probs, x, y)
|
|
239
|
+
probs /= probs.sum()
|
|
240
|
+
|
|
241
|
+
# Choose action
|
|
242
|
+
choice = np.random.choice(4, p=probs)
|
|
243
|
+
|
|
244
|
+
if choice == 0: # apoptosis
|
|
245
|
+
self.field[x, y] = 0
|
|
246
|
+
self.mutate[x, y] = 0
|
|
247
|
+
|
|
248
|
+
elif choice == 1: # proliferation
|
|
249
|
+
if is_stc and np.random.rand() < self.PS/100:
|
|
250
|
+
self.cell_step(x, y, 1) # STC-STC division
|
|
251
|
+
elif is_stc:
|
|
252
|
+
self.cell_step(x, y, 2) # STC-RTC division
|
|
253
|
+
else:
|
|
254
|
+
self.cell_step(x, y, 3) # RTC-RTC division
|
|
255
|
+
|
|
256
|
+
elif choice == 2: # migration
|
|
257
|
+
self.cell_step(x, y, 4)
|
|
258
|
+
|
|
259
|
+
# ---------------------------------------------------------------------
|
|
260
|
+
def mutate_probs(self, chances, x, y):
|
|
261
|
+
"""
|
|
262
|
+
The function that changes the cell action chances
|
|
263
|
+
based on the current mutation status of the cell.
|
|
264
|
+
|
|
265
|
+
Parameters:
|
|
266
|
+
chances (list of float): the base action chances
|
|
267
|
+
x, y (int): representing coordinates of the cell
|
|
268
|
+
"""
|
|
269
|
+
|
|
270
|
+
mut_state = self.mutate[x, y]/2
|
|
271
|
+
|
|
272
|
+
if mut_state > 0:
|
|
273
|
+
mut_state += 1
|
|
274
|
+
chances[0] = chances[0]/mut_state # Decreased chance for apoptosis
|
|
275
|
+
chances[1] = chances[1]*mut_state # Increased proliferation chance
|
|
276
|
+
elif mut_state < 0:
|
|
277
|
+
mut_state -= 1
|
|
278
|
+
chances[0] = chances[0]*abs(mut_state) # Increased chance for apoptosis
|
|
279
|
+
chances[1] = chances[1]/abs(mut_state) # Decreased proliferation chance
|
|
280
|
+
|
|
281
|
+
if chances.sum() <= 100:
|
|
282
|
+
chances[3] = 100 - chances.sum()
|
|
283
|
+
return chances
|
|
284
|
+
|
|
285
|
+
# ---------------------------------------------------------------------
|
|
286
|
+
def immune_response(self, offset = 10, alpha = 0.002, it_targ = 0.1, infil = 0.3):
|
|
287
|
+
"""
|
|
288
|
+
The function that simulates immune cells.
|
|
289
|
+
Spawns, moves and activates immune cells.
|
|
290
|
+
|
|
291
|
+
Parameters:
|
|
292
|
+
offset (int): distance of spawnpoints ("frame") from the tumor
|
|
293
|
+
alpha (float): controls strength (slope) of immune exhaustion
|
|
294
|
+
it_targ (float): desired mean immune/tumor ratio during simulation
|
|
295
|
+
infil (float): "searching/infiltrating" threshold for wbcs (0-1)
|
|
296
|
+
"""
|
|
297
|
+
|
|
298
|
+
# Current tumor cell locations
|
|
299
|
+
self.find_tumor_cells()
|
|
300
|
+
tumor_size = len(self.tumor_cells)
|
|
301
|
+
if tumor_size == 0:
|
|
302
|
+
# Count existing immune cells
|
|
303
|
+
coords = np.argwhere(self.immune > 0)
|
|
304
|
+
self.immune_cells = coords
|
|
305
|
+
immune_size = len(coords)
|
|
306
|
+
for x, y in coords:
|
|
307
|
+
self.immune[x, y] -= 1
|
|
308
|
+
self.wbc_number.append(immune_size)
|
|
309
|
+
return
|
|
310
|
+
|
|
311
|
+
tumor_x = [x for x, _ in self.tumor_cells]
|
|
312
|
+
tumor_y = [y for _, y in self.tumor_cells]
|
|
313
|
+
|
|
314
|
+
# Find spawnpoints (a "frame" around tumor)
|
|
315
|
+
frame = []
|
|
316
|
+
left = max(1, min(tumor_x) - offset)
|
|
317
|
+
right = min(self.side - 2, max(tumor_x) + offset)
|
|
318
|
+
top = max(1, min(tumor_y) - offset)
|
|
319
|
+
bottom = min(self.side - 2, max(tumor_y) + offset)
|
|
320
|
+
|
|
321
|
+
for j in range(left, right + 1):
|
|
322
|
+
frame.append([top, j])
|
|
323
|
+
frame.append([bottom, j])
|
|
324
|
+
for i in range(top, bottom + 1):
|
|
325
|
+
frame.append([i, left])
|
|
326
|
+
frame.append([i, right])
|
|
327
|
+
self.spawnpoints = np.array(frame)
|
|
328
|
+
|
|
329
|
+
# Immune exhaustion = time-dependent decline
|
|
330
|
+
IE = 1.0 / (1.0 + alpha * self.cycles)
|
|
331
|
+
IE = max(IE, 0.2)
|
|
332
|
+
|
|
333
|
+
# Saturating spawn (sigmoid-like), delayed onse
|
|
334
|
+
spawn = self.I * (tumor_size / (tumor_size + self.I * 100)) * IE
|
|
335
|
+
|
|
336
|
+
coords = np.nonzero(self.immune)
|
|
337
|
+
it_ratio = len(np.transpose(coords)) / tumor_size
|
|
338
|
+
for cell in self.spawnpoints:
|
|
339
|
+
x, y = cell
|
|
340
|
+
if np.random.rand() < spawn / 50:
|
|
341
|
+
if self.immune[x, y] == 0 and self.field[x, y] == 0 and it_ratio <= it_targ:
|
|
342
|
+
# Immune cell lifespan: I-1 weeks to I+1 weeks
|
|
343
|
+
min_life = min(24, (self.I-1)*168)
|
|
344
|
+
max_life = (self.I+1)*168
|
|
345
|
+
self.immune[x, y] = np.random.randint(min_life, max_life)
|
|
346
|
+
|
|
347
|
+
# Chemoattractant map for tumor density
|
|
348
|
+
self.chemo = (self.field > 0).astype(float)
|
|
349
|
+
self.chemo = gaussian_filter(self.chemo, sigma=5)
|
|
350
|
+
self.chemo = self.chemo / np.max(self.chemo)
|
|
351
|
+
|
|
352
|
+
# Immune action
|
|
353
|
+
coords = np.nonzero(self.immune)
|
|
354
|
+
kills_per_hour = 0
|
|
355
|
+
self.immune_cells = np.transpose(coords)
|
|
356
|
+
|
|
357
|
+
# Temporary immune grid
|
|
358
|
+
new_immune = np.zeros_like(self.immune)
|
|
359
|
+
|
|
360
|
+
for (x, y) in self.immune_cells:
|
|
361
|
+
strength = self.immune[x, y]
|
|
362
|
+
if strength <= 0:
|
|
363
|
+
continue
|
|
364
|
+
|
|
365
|
+
# Kill prob on contact: (0.15 - 0.3, if I=5, IE = 0)
|
|
366
|
+
tumor_nb = self.get_neighbours(x, y, 2)
|
|
367
|
+
if tumor_nb:
|
|
368
|
+
tx, ty = random.choice(tumor_nb)
|
|
369
|
+
kill = (0.05*self.I) * np.exp(-0.25*self.mutate[tx,ty]) * IE
|
|
370
|
+
kill = min(kill, 0.3)
|
|
371
|
+
if np.random.rand() < kill:
|
|
372
|
+
self.field[tx, ty] = 0
|
|
373
|
+
self.mutate[tx, ty] = 0
|
|
374
|
+
kills_per_hour += 1
|
|
375
|
+
|
|
376
|
+
# # Multiple moves/cycle as immune cells are faster
|
|
377
|
+
moves = int(1 + self.I * (1 - self.chemo[x, y]))
|
|
378
|
+
for _ in range(moves):
|
|
379
|
+
free_nb = self.get_neighbours(x, y, 1)
|
|
380
|
+
if not free_nb:
|
|
381
|
+
strength -= 1
|
|
382
|
+
break
|
|
383
|
+
|
|
384
|
+
# Biased movement towards tumor density (chemotaxis)
|
|
385
|
+
t_dens = [self.chemo[i, j] for (i, j) in free_nb]
|
|
386
|
+
if sum(t_dens) > 0:
|
|
387
|
+
weights = np.array(t_dens) / sum(t_dens)
|
|
388
|
+
tx, ty = free_nb[np.random.choice(len(free_nb), p=weights)]
|
|
389
|
+
else:
|
|
390
|
+
tx, ty = random.choice(free_nb)
|
|
391
|
+
x, y = tx, ty
|
|
392
|
+
strength -= 1
|
|
393
|
+
|
|
394
|
+
if strength > 0:
|
|
395
|
+
new_immune[x, y] = strength
|
|
396
|
+
self.immune = new_immune
|
|
397
|
+
|
|
398
|
+
# Save number of immune cells
|
|
399
|
+
immune_size = len(self.immune_cells)
|
|
400
|
+
self.wbc_number.append(immune_size)
|
|
401
|
+
self.it_ratio.append(immune_size / tumor_size)
|
|
402
|
+
|
|
403
|
+
# Infiltrating immune cells
|
|
404
|
+
wbc_infil = sum(1 for (x,y) in self.immune_cells if self.chemo[x,y] >= infil)
|
|
405
|
+
if immune_size > 0:
|
|
406
|
+
self.kill_day.append(kills_per_hour / max(1, wbc_infil) * 24)
|
|
407
|
+
|
|
408
|
+
# ---------------------------------------------------------------------
|
|
409
|
+
def animate(self, mode):
|
|
410
|
+
"""
|
|
411
|
+
Creates and returns animation of the growth.
|
|
412
|
+
|
|
413
|
+
Parameters:
|
|
414
|
+
mode (int): create figure, save frame or display animation. (1-3)
|
|
415
|
+
|
|
416
|
+
Returns:
|
|
417
|
+
ArtistAnimation: the animation of the growth (optional)
|
|
418
|
+
"""
|
|
419
|
+
|
|
420
|
+
if mode == 1:
|
|
421
|
+
# Create the figure
|
|
422
|
+
self.fig, self.ax = plt.subplots()
|
|
423
|
+
self.ax.imshow(self.field)
|
|
424
|
+
self.ax.set_title(str(self.cycles)+ " hour cell growth")
|
|
425
|
+
self.ax.set_xlabel(str(self.side*10) + " micrometers")
|
|
426
|
+
self.ax.set_ylabel(str(self.side*10) + " micrometers")
|
|
427
|
+
elif mode == 2:
|
|
428
|
+
# Save the current frame
|
|
429
|
+
growth = self.ax.imshow(self.field, animated=True)
|
|
430
|
+
immune_coords = np.argwhere(self.immune > 0)
|
|
431
|
+
immune = self.ax.scatter(immune_coords[:,1], immune_coords[:,0], c='blue', s=10)
|
|
432
|
+
self.images.append([growth, immune])
|
|
433
|
+
elif mode == 3:
|
|
434
|
+
# Display the animation
|
|
435
|
+
return animation.ArtistAnimation(self.fig, self.images, interval=50, blit=True)
|
|
436
|
+
|
|
437
|
+
# ---------------------------------------------------------------------
|
|
438
|
+
def save_field_to_excel(self, file_name):
|
|
439
|
+
"""
|
|
440
|
+
Saves the current state of self.field to an excel file.
|
|
441
|
+
|
|
442
|
+
Parameters:
|
|
443
|
+
file_name (str): name of the excel file
|
|
444
|
+
"""
|
|
445
|
+
|
|
446
|
+
pd.DataFrame(self.field).to_excel(file_name, index=False)
|
|
447
|
+
|
|
448
|
+
# ---------------------------------------------------------------------
|
|
449
|
+
def mod_cell(self, x, y, value):
|
|
450
|
+
"""
|
|
451
|
+
Modifies cell value. (Create initial state before this!)
|
|
452
|
+
|
|
453
|
+
Parameters:
|
|
454
|
+
x, y (int): representing coordinates of the cell
|
|
455
|
+
value (int): the new value at the given position
|
|
456
|
+
"""
|
|
457
|
+
|
|
458
|
+
self.field[y][x] = value
|
|
459
|
+
|
|
460
|
+
# ---------------------------------------------------------------------
|
|
461
|
+
def get_prolif_potentials(self):
|
|
462
|
+
"""
|
|
463
|
+
Returns a dictionary of proliferation potential numbers.
|
|
464
|
+
|
|
465
|
+
Returns:
|
|
466
|
+
dict: a dictionary of the proliferation potentials
|
|
467
|
+
"""
|
|
468
|
+
|
|
469
|
+
nonzero_field = np.array(self.field)[np.array(self.field) > 0]
|
|
470
|
+
unique, counts = np.unique(nonzero_field, return_counts=True)
|
|
471
|
+
prolif_potents = {}
|
|
472
|
+
|
|
473
|
+
for i in range(1, self.pmax + 2):
|
|
474
|
+
prolif_potents[i] = 0
|
|
475
|
+
for val, count in zip(unique, counts):
|
|
476
|
+
prolif_potents[int(val)] = count
|
|
477
|
+
|
|
478
|
+
return prolif_potents
|
|
479
|
+
|
|
480
|
+
# ---------------------------------------------------------------------
|
|
481
|
+
def get_statistics(self):
|
|
482
|
+
"""
|
|
483
|
+
Returns various statistical properties of the model.
|
|
484
|
+
|
|
485
|
+
Returns:
|
|
486
|
+
dict: a dictionary of the statistical properties
|
|
487
|
+
"""
|
|
488
|
+
|
|
489
|
+
nonzero_field = self.field[self.field > 0]
|
|
490
|
+
|
|
491
|
+
# Statistics
|
|
492
|
+
if nonzero_field.size != 0:
|
|
493
|
+
stats = {
|
|
494
|
+
"Min": nonzero_field.min(),
|
|
495
|
+
"Max": nonzero_field.max(),
|
|
496
|
+
"Mean": nonzero_field.mean(),
|
|
497
|
+
"Std": nonzero_field.std(),
|
|
498
|
+
"Median": np.median(nonzero_field),
|
|
499
|
+
"Skew": skew(nonzero_field.ravel()),
|
|
500
|
+
"Kurtosis": kurtosis(nonzero_field.ravel()),
|
|
501
|
+
"Final STC": self.stc_number[self.cycles-1],
|
|
502
|
+
"Final RTC": self.rtc_number[self.cycles-1],
|
|
503
|
+
"Final WBC": self.wbc_number[self.cycles-1],
|
|
504
|
+
"Tumor Size": nonzero_field.size,
|
|
505
|
+
"Confluence": nonzero_field.size/self.field.size*100,
|
|
506
|
+
}
|
|
507
|
+
|
|
508
|
+
if self.I > 0:
|
|
509
|
+
stats.update({
|
|
510
|
+
"Mean I/T" : sum(self.it_ratio)/len(self.it_ratio),
|
|
511
|
+
"Mean k/d" : sum(self.kill_day)/len(self.kill_day)
|
|
512
|
+
})
|
|
513
|
+
|
|
514
|
+
# Proliferation potentials
|
|
515
|
+
stats.update(self.get_prolif_potentials())
|
|
516
|
+
|
|
517
|
+
# Cell Numbers
|
|
518
|
+
checkpoints = np.linspace(0, self.cycles - 1, int(self.cycles/10) + 1, dtype=int)
|
|
519
|
+
for idx in checkpoints:
|
|
520
|
+
hour = (idx + 1)
|
|
521
|
+
stats[f"{hour}h_STC"] = self.stc_number[idx]
|
|
522
|
+
stats[f"{hour}h_RTC"] = self.rtc_number[idx]
|
|
523
|
+
stats[f"{hour}h_WBC"] = self.wbc_number[idx]
|
|
524
|
+
|
|
525
|
+
else: stats = {"Status": "Extinct"}
|
|
526
|
+
return stats
|
|
527
|
+
|
|
528
|
+
# ---------------------------------------------------------------------
|
|
529
|
+
def save_statistics(self, file_name):
|
|
530
|
+
"""
|
|
531
|
+
Saves various statistical properties of the model to an excel file.
|
|
532
|
+
|
|
533
|
+
Parameters:
|
|
534
|
+
file_name (str): name of the excel file
|
|
535
|
+
"""
|
|
536
|
+
|
|
537
|
+
stats_dict = self.get_statistics()
|
|
538
|
+
df = pd.DataFrame([stats_dict])
|
|
539
|
+
df.to_excel(file_name, index=False)
|
|
540
|
+
|
|
541
|
+
# ---------------------------------------------------------------------
|
|
542
|
+
def measure_runtime(func):
|
|
543
|
+
# Decorator to measure completion time
|
|
544
|
+
@wraps(func)
|
|
545
|
+
def wrapper(*args, **kwargs):
|
|
546
|
+
start_time = time.time()
|
|
547
|
+
result = func(*args, **kwargs)
|
|
548
|
+
end_time = time.time()
|
|
549
|
+
runtime = end_time - start_time
|
|
550
|
+
print("Model completion time (s): " + str(runtime))
|
|
551
|
+
return result
|
|
552
|
+
return wrapper
|
|
553
|
+
|
|
554
|
+
# ---------------------------------------------------------------------
|
|
555
|
+
@measure_runtime
|
|
556
|
+
def run_model(self, plot, animate, stats):
|
|
557
|
+
"""
|
|
558
|
+
The function that runs a single entire simulation.
|
|
559
|
+
For animation matplotlib backend cannot be inline!
|
|
560
|
+
|
|
561
|
+
Parameters:
|
|
562
|
+
plot (bool): set to true to display the plots of the model
|
|
563
|
+
animate (bool): set to true to enable matplotlib animation
|
|
564
|
+
stats (bool): set to true to print statistics of the model
|
|
565
|
+
"""
|
|
566
|
+
|
|
567
|
+
# Create initial state
|
|
568
|
+
if len(self.field) == 0: self.init_state()
|
|
569
|
+
self.find_tumor_cells()
|
|
570
|
+
if len(self.immune) == 0:
|
|
571
|
+
self.immune = np.zeros((self.side, self.side))
|
|
572
|
+
if len(self.mutate) == 0:
|
|
573
|
+
self.mutate = np.zeros((self.side, self.side))
|
|
574
|
+
self.mutmap = np.zeros((self.side, self.side))
|
|
575
|
+
|
|
576
|
+
self.stc_number = []
|
|
577
|
+
self.rtc_number = []
|
|
578
|
+
self.wbc_number = []
|
|
579
|
+
|
|
580
|
+
if animate: self.animate(1)
|
|
581
|
+
|
|
582
|
+
# Growth loop
|
|
583
|
+
for c in tqdm(range(self.cycles), desc="Running simulation..."):
|
|
584
|
+
self.tumor_action()
|
|
585
|
+
self.immune_response()
|
|
586
|
+
self.find_tumor_cells()
|
|
587
|
+
self.count_tumor_cells()
|
|
588
|
+
if animate: self.animate(2)
|
|
589
|
+
|
|
590
|
+
# Store the results
|
|
591
|
+
self.store_model()
|
|
592
|
+
|
|
593
|
+
# Output settings
|
|
594
|
+
if plot: self.plot_run(len(self.runs))
|
|
595
|
+
if animate: self.ani = self.animate(3)
|
|
596
|
+
if stats:
|
|
597
|
+
df = pd.DataFrame(self.stats)
|
|
598
|
+
base_cols = self.separate_columns(df)[0]
|
|
599
|
+
print(df[base_cols])
|
|
600
|
+
|
|
601
|
+
# ---------------------------------------------------------------------
|
|
602
|
+
@measure_runtime
|
|
603
|
+
def run_multimodel(self, count, init_field, plot, stats):
|
|
604
|
+
"""
|
|
605
|
+
Runs the model multiple times and returns a DataFrame of statistics.
|
|
606
|
+
|
|
607
|
+
Parameters:
|
|
608
|
+
count (int): number of times to run the simulation
|
|
609
|
+
init_field (np.array): custom initial state of field/run
|
|
610
|
+
plot (bool): set to true to display the plots of the model
|
|
611
|
+
stats (bool): set to true to print statistics of the model
|
|
612
|
+
|
|
613
|
+
Returns:
|
|
614
|
+
pd.DataFrame: collected statistics from each run
|
|
615
|
+
"""
|
|
616
|
+
|
|
617
|
+
stats = []
|
|
618
|
+
|
|
619
|
+
for i in range(count):
|
|
620
|
+
self.field = init_field.copy()
|
|
621
|
+
self.immune = []
|
|
622
|
+
self.mutate = []
|
|
623
|
+
self.mutmap = []
|
|
624
|
+
self.run_model(plot = False, animate = False, stats = False)
|
|
625
|
+
stats.append(self.get_statistics())
|
|
626
|
+
all_stats = pd.DataFrame(stats)
|
|
627
|
+
|
|
628
|
+
if plot:
|
|
629
|
+
self.plot_averages(all_stats)
|
|
630
|
+
if stats:
|
|
631
|
+
df = pd.DataFrame(self.stats)
|
|
632
|
+
base_cols = self.separate_columns(df)[0]
|
|
633
|
+
print(df[base_cols])
|
|
634
|
+
|
|
635
|
+
return all_stats
|
|
636
|
+
|
|
637
|
+
# ---------------------------------------------------------------------
|
|
638
|
+
def store_model(self):
|
|
639
|
+
"""
|
|
640
|
+
Stores the results of the previous model executions.
|
|
641
|
+
"""
|
|
642
|
+
|
|
643
|
+
result = {}
|
|
644
|
+
|
|
645
|
+
result["immune"] = self.immune
|
|
646
|
+
result["mutate"] = self.mutate
|
|
647
|
+
result["mutmap"] = self.mutmap
|
|
648
|
+
result["field"] = self.field
|
|
649
|
+
result["stc"] = self.stc_number
|
|
650
|
+
result["rtc"] = self.rtc_number
|
|
651
|
+
result["wbc"] = self.wbc_number
|
|
652
|
+
result["pp"] = self.get_prolif_potentials().values()
|
|
653
|
+
|
|
654
|
+
# Stores data for plotting
|
|
655
|
+
self.runs.append(result)
|
|
656
|
+
|
|
657
|
+
# Stores data for statistics
|
|
658
|
+
self.stats.append(self.get_statistics())
|
|
659
|
+
|
|
660
|
+
# ---------------------------------------------------------------------
|
|
661
|
+
def separate_columns(self, data):
|
|
662
|
+
"""
|
|
663
|
+
Separates the statistics DataFrame columns into logical groups:
|
|
664
|
+
base stats, STC, RTC, WBC counts, and proliferation potentials.
|
|
665
|
+
|
|
666
|
+
Parameters:
|
|
667
|
+
data (pd.DataFrame): Your data in a pandas dataframe format
|
|
668
|
+
|
|
669
|
+
Returns:
|
|
670
|
+
tuple of list[str]: A tuple containing 5 lists of column names:
|
|
671
|
+
- base: Columns with general statistical properties
|
|
672
|
+
- stc: Columns with STC counts at each time point
|
|
673
|
+
- rtc: Columns with RTC counts at each time point
|
|
674
|
+
- wbc: Columns with WBC counts at each time point
|
|
675
|
+
- pp: Columns for proliferation potential values
|
|
676
|
+
"""
|
|
677
|
+
|
|
678
|
+
base = [col for col in data.columns if not str(col).isdigit()
|
|
679
|
+
and "_STC" not in str(col)
|
|
680
|
+
and "_RTC" not in str(col)
|
|
681
|
+
and "_WBC" not in str(col)]
|
|
682
|
+
stc = sorted([col for col in data.columns if "_STC" in str(col)],
|
|
683
|
+
key=lambda x: int(str(x).split("h")[0]))
|
|
684
|
+
rtc = sorted([col for col in data.columns if "_RTC" in str(col)],
|
|
685
|
+
key=lambda x: int(str(x).split("h")[0]))
|
|
686
|
+
wbc = sorted([col for col in data.columns if "_WBC" in str(col)],
|
|
687
|
+
key=lambda x: int(str(x).split("h")[0]))
|
|
688
|
+
pp = sorted([col for col in data.columns if isinstance(col, int)])
|
|
689
|
+
|
|
690
|
+
return base, stc, rtc, wbc, pp
|
|
691
|
+
|
|
692
|
+
# ---------------------------------------------------------------------
|
|
693
|
+
def plot_run(self, run):
|
|
694
|
+
"""
|
|
695
|
+
Creates growth and cell number plots, proliferation potential histograms.
|
|
696
|
+
|
|
697
|
+
Paramteres:
|
|
698
|
+
run (int): which model execution to plot
|
|
699
|
+
|
|
700
|
+
Returns:
|
|
701
|
+
matplotlib.figure.Figure: the generated plots of the specific run
|
|
702
|
+
"""
|
|
703
|
+
|
|
704
|
+
# Create the figue and axis
|
|
705
|
+
fig, axs = plt.subplots(2, 2, figsize=(14,14))
|
|
706
|
+
|
|
707
|
+
tumor = axs[0, 0].imshow(self.runs[run-1]["field"], vmin=0, vmax=self.pmax+1)
|
|
708
|
+
fig.colorbar(tumor, ax=axs[0, 0])
|
|
709
|
+
|
|
710
|
+
immune_coords = np.argwhere(self.runs[run-1]["immune"] > 0)
|
|
711
|
+
axs[0, 0].scatter(immune_coords[:,1], immune_coords[:,0],
|
|
712
|
+
c='blue', marker='v', s=10)
|
|
713
|
+
|
|
714
|
+
axs[0, 1].plot(self.runs[run-1]["stc"], 'C1', label='STC')
|
|
715
|
+
axs[0, 1].plot(self.runs[run-1]["rtc"], 'C2', label='RTC')
|
|
716
|
+
axs[0, 1].plot(self.runs[run-1]["wbc"], 'C3', label='WBC')
|
|
717
|
+
axs[0, 1].legend()
|
|
718
|
+
|
|
719
|
+
mutmap = axs[1, 0].imshow(self.runs[run-1]["mutmap"],
|
|
720
|
+
cmap="RdBu_r", vmin=-3, vmax=3, interpolation="bicubic")
|
|
721
|
+
fig.colorbar(mutmap, ax=axs[1, 0])
|
|
722
|
+
|
|
723
|
+
axs[1, 1].bar(range(1, self.pmax + 2), self.runs[run-1]["pp"], edgecolor='black')
|
|
724
|
+
|
|
725
|
+
# Titles/labels of the plots
|
|
726
|
+
titles = [str(self.cycles)+ "h cell growth", "Cell count",
|
|
727
|
+
"Mutation history", "Final PP values"]
|
|
728
|
+
labs_x = [str(self.side*10) + " um", "Time (h)",
|
|
729
|
+
str(self.side*10) + " um", "Proliferation potentials"]
|
|
730
|
+
labs_y = [str(self.side*10) + " um", "Cell numbers",
|
|
731
|
+
str(self.side*10) + " um", "Number of appearance"]
|
|
732
|
+
|
|
733
|
+
fig.suptitle("Simulation " + str(run) + " Results", fontsize = 16)
|
|
734
|
+
for i, ax in enumerate(axs.flat):
|
|
735
|
+
ax.set_title(titles[i])
|
|
736
|
+
ax.set_xlabel(labs_x[i])
|
|
737
|
+
ax.set_ylabel(labs_y[i])
|
|
738
|
+
|
|
739
|
+
# ---------------------------------------------------------------------
|
|
740
|
+
def plot_averages(self, data):
|
|
741
|
+
"""
|
|
742
|
+
The function that plots the averages of multiple model results.
|
|
743
|
+
Works with the results of the 'run_multimodel' function.
|
|
744
|
+
|
|
745
|
+
Parameters:
|
|
746
|
+
data (pd.DataFrame): Your data in a pandas dataframe format
|
|
747
|
+
|
|
748
|
+
Returns:
|
|
749
|
+
matplotlib.figure.Figure: The plots of the averages with SD values
|
|
750
|
+
"""
|
|
751
|
+
|
|
752
|
+
base_cols, stc_cols, rtc_cols, wbc_cols, pp_cols = self.separate_columns(data)
|
|
753
|
+
|
|
754
|
+
avg_stc = data[stc_cols].mean()
|
|
755
|
+
std_stc = data[stc_cols].std()
|
|
756
|
+
avg_rtc = data[rtc_cols].mean()
|
|
757
|
+
std_rtc = data[rtc_cols].std()
|
|
758
|
+
avg_wbc = data[wbc_cols].mean()
|
|
759
|
+
std_wbc = data[wbc_cols].std()
|
|
760
|
+
avg_pp = data[pp_cols].mean()
|
|
761
|
+
std_pp = data[pp_cols].std()
|
|
762
|
+
|
|
763
|
+
fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(14, 5))
|
|
764
|
+
timepoints = np.linspace(0, self.cycles - 1, int(self.cycles/10) + 1)
|
|
765
|
+
|
|
766
|
+
ax1.plot(timepoints, avg_stc, label='STC', color='C1')
|
|
767
|
+
ax1.fill_between(timepoints, avg_stc - std_stc, avg_stc + std_stc,
|
|
768
|
+
color='C1', alpha=0.3)
|
|
769
|
+
ax1.plot(timepoints, avg_rtc, label='RTC', color='C2')
|
|
770
|
+
ax1.fill_between(timepoints, avg_rtc - std_rtc, avg_rtc + std_rtc,
|
|
771
|
+
color='C2', alpha=0.3)
|
|
772
|
+
ax1.plot(timepoints, avg_wbc, label='WBC', color='C3')
|
|
773
|
+
ax1.fill_between(timepoints, avg_wbc - std_wbc, avg_wbc + std_wbc,
|
|
774
|
+
color='C3', alpha=0.3)
|
|
775
|
+
|
|
776
|
+
ax1.set_title("Average Tumor Cell Count")
|
|
777
|
+
ax1.set_xlabel("Model Time (hours)")
|
|
778
|
+
ax1.set_ylabel("Number of Cells")
|
|
779
|
+
ax1.legend()
|
|
780
|
+
|
|
781
|
+
ax2.bar(pp_cols, avg_pp, yerr=std_pp, capsize=5, edgecolor='black')
|
|
782
|
+
ax2.set_title("Average Proliferation Potential Distribution")
|
|
783
|
+
ax2.set_xlabel("Proliferation Potential")
|
|
784
|
+
ax2.set_ylabel("Average Count")
|
|
785
|
+
|
|
786
|
+
fig.suptitle("Averages of " + str(len(self.stats)) + " Models", fontsize = 16)
|
|
787
|
+
plt.tight_layout()
|
|
788
|
+
|
|
789
|
+
|
|
790
|
+
class TDashboard:
    """
    Class for a Streamlit dashboard providing a GUI for the model.

    Parameters:
    model (TModel): The created model you want a dashboard for
    """

    def __init__(self, model):
        self.model = model

    # ---------------------------------------------------------------------
    def run_dashboard(self):
        """
        Build the entire Streamlit dashboard for the model.

        The page has two tabs: "SIMULATION" (parameters, initial state,
        execution and results) and "MACHINE LEARNING" (dataset generation,
        model training and prediction).
        """

        st.set_page_config(layout="wide")
        st.markdown("<h1 style='text-align: center;'>TCAMpy</h1>", unsafe_allow_html=True)
        # Browser window width in px, read through a JS component. It is
        # presumably mirrored in st.session_state["screen_width"] via the key
        # (get_plot_height reads it from there) - confirm with streamlit_javascript.
        self.screen_width = st_javascript("window.innerWidth", key="screen_width")

        tab1, tab2 = st.tabs(["SIMULATION", "MACHINE LEARNING"])
        with tab1:
            # Relative widths of: controls | spacer | results.
            self.columns = [4, 1, 12]
            self.col1, _, self.col3 = st.columns(self.columns)

            with self.col1:
                self._initialize()
                self._modify_cell()
                self._execute_model()
            with self.col3:
                self._visualize_run("Last Simulation", len(self.model.runs))
                self._show_statistics()
                self._reset_save_stats()

        with tab2:
            col1, col2 = st.columns(2)

            with col1:
                self._simdata_generator()
            with col2:
                self._train_and_predict()

    # ---------------------------------------------------------------------
    def print_title(self, title):
        """
        Print text as a centered <h2> title on the dashboard.

        Parameters:
        title (string): The text to print (inserted as raw HTML, so only
            pass trusted strings)
        """

        st.markdown(
            f"<h2 style='text-align: center;'>{title}</h2>",
            unsafe_allow_html=True
        )

    # ---------------------------------------------------------------------
    def get_plot_height(self, col, scaler, default_width=1200):
        """
        Calculate the height of plots based on the screen width, the
        relative width of a main column and a scaler.

        Parameters:
        col (int): 1-based index of the main column in self.columns
        scaler (float): multiplier applied to the column width
        default_width (int): screen width in px assumed while the browser
            width is still unknown (the JS component reports 0/None before
            it has run in the browser)

        Returns:
        int: the plot height in pixels
        """

        # Without a fallback the first render multiplies None (TypeError)
        # or 0 (zero-height charts).
        screen_width = (
            st.session_state.get("screen_width")
            or getattr(self, "screen_width", None)
            or default_width
        )
        col_width_px = screen_width * (self.columns[col - 1] / sum(self.columns))
        return int(col_width_px * scaler)

    # ---------------------------------------------------------------------
    def _initialize(self):
        """
        Render the parameter sliders, copy their values onto the model and
        (re)initialize the model's starting state.

        The initial arrays (field, immune, mutate, mutmap) are cached in
        st.session_state so they survive Streamlit reruns. They are rebuilt
        only on the first run or when the slider configuration changes.
        """

        self.print_title("Model Parameters")

        self.model.cycles = st.slider("Model Duration (hours)", 50, 5000, value=self.model.cycles)
        self.model.side = st.slider("Field Side Length (10um)", 10, 200, value=self.model.side)
        self.model.pmax = st.slider("Max Proliferation Potential", 1, 20, value=self.model.pmax)
        self.model.PA = st.slider("Apoptosis Chance (RTC) (%)", 0, 100, value=self.model.PA)
        self.model.CCT = st.slider("Cell Cycle Time (hours)", 1, 48, value=self.model.CCT)
        self.model.Dt = st.slider("Time Step (days)", 0.01, 1.0, value=self.model.Dt, step=0.01)
        self.model.PS = st.slider("STC-STC Division Chance (%)", 0, 100, value=self.model.PS)
        self.model.mu = st.slider("Migration Capacity", 0, 10, value=self.model.mu)
        self.model.I = st.slider("Immune Strength", 0, 10, value=self.model.I)
        self.model.M = st.slider("Mutation Chance", 0, 50, value=self.model.M)

        # Quantities derived from the slider values above.
        self.model.PP = int(self.model.CCT * self.model.Dt / 24 * 100)
        self.model.PM = 100 * self.model.mu / 24

        init_config = (
            self.model.side, self.model.cycles, self.model.pmax,
            self.model.PA, self.model.CCT, self.model.Dt, self.model.PS,
            self.model.mu, self.model.I, self.model.M
        )
        # md5 is only used as a cheap change detector, not for security.
        config_hash = hashlib.md5(str(init_config).encode()).hexdigest()

        # Restore data from earlier runs (used for plotting)
        if "model_runs" in st.session_state:
            self.model.runs = st.session_state.model_runs

        # Restore data from earlier runs (used for statistics)
        if "model_stats" in st.session_state:
            self.model.stats = st.session_state.model_stats

        if (
            "initialized" not in st.session_state
            or st.session_state.get("init_config_hash") != config_hash
        ):
            self.model.init_state()
            st.session_state.field = self.model.field.copy()
            st.session_state.immune = self.model.immune.copy()
            st.session_state.mutate = self.model.mutate.copy()
            st.session_state.mutmap = self.model.mutmap.copy()
            st.session_state.initialized = True
            st.session_state.init_config_hash = config_hash

    # ---------------------------------------------------------------------
    def _modify_cell(self):
        """
        Let the user overwrite a single cell of the cached initial field and
        display the resulting initial state as a heatmap.
        """

        self.print_title("Initial State")

        x_coord = st.number_input("X Coordinate", 0, self.model.side - 1, value=self.model.side // 2)
        y_coord = st.number_input("Y Coordinate", 0, self.model.side - 1, value=self.model.side // 2)
        cell_value = st.number_input("Cell Value", 0, self.model.pmax + 1, value=self.model.pmax + 1)
        plots_height = self.get_plot_height(1, 0.9)

        if st.button("Modify Cell"):
            # Edit a copy of the cached field so the change persists across reruns.
            self.model.field = st.session_state.field.copy()
            self.model.mod_cell(x_coord, y_coord, cell_value)
            st.session_state.field = self.model.field.copy()
            st.success(f"Cell modified at ({x_coord}, {y_coord}) to {cell_value}")

        field = st.session_state.field
        heatmap = self._create_heatmap(
            plots_height, "Initial state", "viridis",
            "PP", 0, self.model.pmax+1, field
        )

        st.altair_chart(heatmap, use_container_width=True)

    # ---------------------------------------------------------------------
    def _execute_model(self):
        """
        Run the model the requested number of times. Every run starts from
        the cached initial state. All runs and statistics are stored in
        st.session_state.
        """

        self.print_title("Execution")

        rep = st.number_input("How many simulations?", 1)

        if st.button("Run Model"):
            with st.spinner("Running simulations..."):
                for _ in range(int(rep)):
                    # Reset the model to the cached initial state before each run.
                    self.model.field = st.session_state.field.copy()
                    self.model.immune = st.session_state.immune.copy()
                    self.model.mutate = st.session_state.mutate.copy()
                    self.model.mutmap = st.session_state.mutmap.copy()
                    self.model.run_model(plot=False, animate=False, stats=False)

            st.session_state.model_runs = self.model.runs
            st.session_state.model_stats = self.model.stats

    # ---------------------------------------------------------------------
    def _visualize_run(self, title, run):
        """
        Visualize one stored model execution. It shows:
        - a tumor heatmap with the immune cells overlaid;
        - the mutation map;
        - the proliferation potential distribution;
        - the cell counts over time.

        Parameters:
        title (string): title of the visualization
        run (int): 1-based index of the model execution to plot
        """

        # Also guard against an empty run list; runs[run - 1] would raise IndexError.
        if "model_runs" not in st.session_state or not self.model.runs:
            st.warning("Simulation results will appear here...")
            return
        self.print_title(title)

        # --- Get selected run ---
        selected = self.model.runs[run - 1]
        immune = selected["immune"]
        mutmap = selected["mutmap"]
        field = selected["field"]
        stc = selected["stc"]
        rtc = selected["rtc"]
        wbc = selected["wbc"]
        pp = selected["pp"]

        # --- Create charts ---
        plots_height = self.get_plot_height(3, 0.4)

        tumor_heatmap = self._create_heatmap(
            plots_height, "Tumor growth", "viridis",
            "PP", 0, self.model.pmax+1, field, immune
        )
        mutation_map = self._create_heatmap(
            plots_height, "Mutation history", "redblue",
            "M", -3, 3, mutmap
        )

        bar_chart = self._create_bar_chart(plots_height, list(pp))
        line_chart = self._create_line_chart(plots_height, stc, rtc, wbc)

        # --- Layout rules ---
        col1, col2 = st.columns([4, 5])
        with col1:
            st.altair_chart(tumor_heatmap, use_container_width=True)
            st.altair_chart(mutation_map, use_container_width=True)
        with col2:
            st.altair_chart(bar_chart, use_container_width=True)
            st.altair_chart(line_chart, use_container_width=True)

    # ---------------------------------------------------------------------
    def _create_heatmap(
        self, h, title, cmap, ctitle,
        vmin, vmax, heatmap, scatter=None
    ):
        """
        Create an Altair heatmap, optionally with a scatter plot overlaid.
        Used for the tumor field with immune cells, and for mutations.

        Parameters:
        h (int): the height of the plot
        title (string): title of the plot
        cmap (string): Vega color scheme for the heatmap
        ctitle (string): title of the color legend
        vmin, vmax (int): domain for the colormap
        heatmap (2D array-like): array for the heatmap, indexed [y, x]
        scatter (2D array-like, optional): a point is drawn at every
            position where this array is > 0

        Returns:
        Altair.Chart: heatmap, layered with the scatter if given
        """

        # --- Heatmap data: one row per cell, row-major (y outer, x inner) ---
        grid = np.asarray(heatmap)
        ys, xs = np.indices(grid.shape)
        heat_df = pd.DataFrame({
            "x": xs.ravel(),
            "y": ys.ravel(),
            "value": grid.ravel(),
        })

        chart = alt.Chart(heat_df).mark_rect().encode(
            x=alt.X("x:O", title="X"),
            y=alt.Y("y:O", sort="descending", title="Y"),
            color=alt.Color("value:Q", title=ctitle,
                            scale=alt.Scale(scheme=cmap, domain=[vmin, vmax]))
        )

        # --- Optional scatter overlay ---
        if scatter is not None:
            scatter_df = pd.DataFrame(
                np.argwhere(np.asarray(scatter) > 0), columns=["y", "x"]
            )
            points = alt.Chart(scatter_df).mark_point(
                color="blue", size=h/20, filled=True, shape="circle"
            ).encode(
                # Same axis titles as the heatmap, so the shared axes are not
                # labelled with merged titles such as "X, x".
                x=alt.X("x:O", title="X"),
                y=alt.Y("y:O", sort="descending", title="Y")
            )
            chart = chart + points

        return chart.properties(
            title=title,
            width='container',
            height=h
        )

    # ---------------------------------------------------------------------
    def _create_line_chart(
        self, h, stc, rtc, wbc, stc_l=None, stc_u=None,
        rtc_l=None, rtc_u=None, wbc_l=None, wbc_u=None
    ):
        """
        Create an Altair line chart of the cell numbers, with optional
        shaded bands between lower and upper bounds.

        Parameters:
        h (int): the height of the plot
        stc, rtc, wbc (sequence): the cell and immune numbers (mean or raw),
            all of the same length
        stc_l, rtc_l, wbc_l (sequence of float, optional): Lower bounds (e.g., mean - SD) for cell counts.
        stc_u, rtc_u, wbc_u (sequence of float, optional): Upper bounds (e.g., mean + SD) for cell counts.
            The bands are only drawn when all six bounds are given and non-empty.

        Returns:
        Altair.Chart: represents the line chart of the cell numbers
        """

        # Convert to lists so that array inputs are concatenated, not summed.
        stc, rtc, wbc = list(stc), list(rtc), list(wbc)
        timepoints = list(range(len(stc)))
        df = pd.DataFrame({
            "Hour": timepoints * 3,
            "Cell Type": ["STC"] * len(stc) + ["RTC"] * len(rtc) + ["WBC"] * len(wbc),
            "Mean": stc + rtc + wbc
        })

        bounds = (stc_l, stc_u, rtc_l, rtc_u, wbc_l, wbc_u)
        has_bounds = all(b is not None and len(b) > 0 for b in bounds)

        line = alt.Chart(df).mark_line().encode(
            x="Hour:Q",
            y=alt.Y("Mean:Q", title="Mean"),
            color="Cell Type:N"
        )

        if has_bounds:
            df["Lower"] = list(stc_l) + list(rtc_l) + list(wbc_l)
            df["Upper"] = list(stc_u) + list(rtc_u) + list(wbc_u)

            area = alt.Chart(df).mark_area(opacity=0.3).encode(
                x=alt.X("Hour:Q", title="Time (hours)"),
                y=alt.Y("Lower:Q", title="Mean"),
                y2="Upper:Q",
                color="Cell Type:N"
            )
            chart = area + line
        else:
            chart = line

        return chart.properties(title="Cell Counts Over Time", height=h)

    # ---------------------------------------------------------------------
    def _create_bar_chart(self, h, pp, std=None):
        """
        Create an Altair bar chart of the proliferation potential distribution.

        Parameters:
        h (int): the height of the plot
        pp (sequence of float or int): Mean or raw counts of cells per proliferation potential class
            (class i is shown at position i + 1)
        std (sequence of float, optional): Standard deviation for each class,
            drawn as pre-computed error bars (mean +/- std)

        Returns:
        alt.Chart: An Altair chart representing the distribution of proliferation potentials
        """

        pp_df = pd.DataFrame({
            "Proliferation Potential": list(range(1, len(pp) + 1)),
            "Mean": list(pp)
        })
        if std is not None:
            pp_df["Std"] = list(std)

        chart = alt.Chart(pp_df).mark_bar().encode(
            x="Proliferation Potential:O",
            y="Mean:Q"
        )

        if std is not None:
            # The errors are pre-computed, so no `extent` aggregation is set:
            # with yError, Vega-Lite draws Mean +/- Std directly.
            error = alt.Chart(pp_df).mark_errorbar().encode(
                x="Proliferation Potential:O",
                y="Mean:Q",
                yError="Std:Q"
            )
            chart = chart + error

        return chart.properties(title="Proliferation Potential Distribution", height=h)

    # ---------------------------------------------------------------------
    def _show_statistics(self):
        """
        Show a table of the per-run statistics (with Mean and Std rows) and
        charts of the averaged cell counts and the averaged proliferation
        potential distribution across all runs.
        """

        if not self.model.stats:
            return

        self.print_title("All Simulations")
        plots_height = self.get_plot_height(3, 0.4)
        df = pd.DataFrame(self.model.stats)
        base_cols, stc_cols, rtc_cols, wbc_cols, pp_cols = self.model.separate_columns(df)
        df.index = df.index + 1  # number runs from 1 in the table

        # Display Statistics
        base = df[base_cols]
        mean_row = base.mean(numeric_only=True)
        std_row = base.std(numeric_only=True)
        mean_row.name = "Mean"
        std_row.name = "Std"
        full_stats = pd.concat([base, mean_row.to_frame().T, std_row.to_frame().T])
        st.dataframe(full_stats)

        def _mean_band(cols):
            # Column-wise mean across runs, plus the mean -/+ 1 SD bounds.
            means = df[cols].mean()
            stds = df[cols].std()
            return (
                list(means.values),
                list((means - stds).values),
                list((means + stds).values),
            )

        stc_mean, stc_low, stc_up = _mean_band(stc_cols)
        rtc_mean, rtc_low, rtc_up = _mean_band(rtc_cols)
        wbc_mean, wbc_low, wbc_up = _mean_band(wbc_cols)

        line_chart = self._create_line_chart(
            plots_height,
            stc_mean, rtc_mean, wbc_mean,
            stc_low, stc_up,
            rtc_low, rtc_up,
            wbc_low, wbc_up
        )

        pp_means = df[pp_cols].mean()
        pp_stds = df[pp_cols].std()
        bar_chart = self._create_bar_chart(plots_height, list(pp_means.values), list(pp_stds.values))

        col1, col2 = st.columns(2)
        with col1:
            st.altair_chart(line_chart, use_container_width=True)
        with col2:
            st.altair_chart(bar_chart, use_container_width=True)

    # ---------------------------------------------------------------------
    def _reset_save_stats(self):
        """
        Show the simulation options. They are:
        - reset all stored runs and statistics;
        - download the statistics as .xlsx;
        - pick a stored run and visualize it.
        """

        if "model_stats" not in st.session_state:
            return

        self.print_title("Simulation Options")
        col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
        selected_run = None
        visualize = False

        with col1:
            if st.button("Reset Model Executions and Data", use_container_width=True):
                del st.session_state.model_stats
                self.model.stats.clear()
                del st.session_state.model_runs
                self.model.runs.clear()

                st.success("Executions have been reset.")
        with col2:
            buffer = io.BytesIO()
            pd.DataFrame(self.model.stats).to_excel(buffer, index=False)
            buffer.seek(0)

            st.download_button(
                label="Download Statistics (xlsx)",
                data=buffer,
                file_name="simulation_statistics.xlsx",
                mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                use_container_width=True
            )
        with col3:
            # A non-empty label avoids Streamlit's empty-label warning; it
            # stays hidden because of label_visibility="collapsed".
            selected_run = st.selectbox(
                "Select simulation", list(range(1, len(self.model.runs) + 1)),
                placeholder="Select simulation",
                label_visibility="collapsed",
                index=None
            )
        with col4:
            if st.button("Visualize Selected Simulation", use_container_width=True):
                if selected_run:
                    visualize = True
                else:
                    st.warning('Please select a simulation!')
        if visualize:
            self._visualize_run("Selected Simulation", selected_run)

    # ---------------------------------------------------------------------
    def _simdata_generator(self):
        """
        The Machine Learning tab for dataset generation and download.
        Uses the TML class to generate simulation data.
        """

        self.print_title("Simulation Data Generator")

        # Initialize TML
        tml = TML(self.model)

        st.write("Select randomization ranges for each parameter:")

        # Build one min/max input pair per model parameter
        param_ranges = {}
        for param, default_val in tml.default_params.items():
            col1, col2 = st.columns(2)
            with col1:
                low = st.number_input(
                    f"{param} (min)",
                    value=float(default_val) * 0.8,
                    key=f"{param}_low"
                )
            with col2:
                high = st.number_input(
                    f"{param} (max)",
                    value=float(default_val) * 1.5,
                    key=f"{param}_high"
                )
            param_ranges[param] = (low, high)

        n = st.number_input("Number of simulations", 5, 500, 50, step=5)

        # Run simulation button
        if st.button("Generate Dataset", use_container_width=True):
            with st.spinner("Running simulations..."):
                df = tml.generate_dataset(n=n, random_params=param_ranges)
            st.success(f"Dataset generated successfully ({len(df)} rows).")

            # Allow CSV download
            csv = df.to_csv(index=False).encode("utf-8")
            st.download_button(
                label="Download Dataset (.csv)",
                data=csv,
                file_name="tumor_dataset.csv",
                mime="text/csv",
                use_container_width=True
            )

    # ---------------------------------------------------------------------
    def _train_and_predict(self):
        """
        Streamlit UI for model training and prediction using the TML class.

        A trained TML instance and its target name are stored in
        st.session_state ("trained_tml", "target") so prediction works
        across reruns.
        """

        self.print_title("Model Trainer and Predictor")

        tml = TML(self.model)

        uploaded_file = st.file_uploader("Upload CSV dataset", type=["csv"])
        df = None
        if uploaded_file is not None:
            # A malformed upload should show an error, not crash the whole page.
            try:
                df = pd.read_csv(uploaded_file)
            except (pd.errors.EmptyDataError, pd.errors.ParserError) as err:
                st.error(f"Could not read the uploaded CSV: {err}")
        else:
            st.info("Please upload a dataset to train a model.")

        if df is not None:
            st.write("Uploaded dataset preview:")
            st.dataframe(df.head())

            # Choose target column
            target = st.selectbox("Select target attribute", df.columns, index=len(df.columns) - 1)

            # Three sliders for model parameters
            test_size = st.slider("Test Size", 0.1, 0.5, 0.2, step=0.05)
            random_state = st.slider("Random Seed", 0, 100, 42, step=1)
            n_estimators = st.slider("Number of Trees (n_estimators)", 50, 500, 200, step=50)

            if st.button("Train Model", use_container_width=True):
                metrics = None
                with st.spinner("Training model..."):
                    try:
                        _, metrics = tml.train_predictor(
                            file=df,
                            target=target,
                            test_size=test_size,
                            random_state=random_state,
                            n_estimators=n_estimators
                        )
                    except ValueError as err:
                        # e.g. non-numeric columns or too few rows to split
                        st.error(f"Training failed: {err}")

                if metrics is not None:
                    st.success("Model trained successfully!")
                    st.write(f"**R^2:** {metrics['R2']:.3f}")
                    st.write(f"**MAE:** {metrics['MAE']:.3f}")

                    # Store trained model for later prediction
                    st.session_state["trained_tml"] = tml
                    st.session_state["target"] = target

        if "trained_tml" in st.session_state:
            trained_tml = st.session_state["trained_tml"]
            target = st.session_state.get("target", "Target")
            feature_cols = trained_tml.feature_columns

            # Numeric inputs for each feature. Columns that are model
            # parameters start at the model's current value; others start at 1.0.
            new_params = []
            self.print_title("Predict for new instance")
            for col in feature_cols:
                default = trained_tml.default_params.get(col, 1.0)
                val = st.number_input(f"{col}", value=float(default), key=f"pred_{col}")
                new_params.append(val)

            # Target label (single-target regression; shown for clarity only)
            st.markdown("#### Select Target to Predict:")
            st.text(f"Predicting: {target}")

            if st.button("🔮 Predict New", use_container_width=True):
                try:
                    prediction = trained_tml.predict_new(new_params)
                    st.success(f"Predicted {target}: **{prediction:.3f}**")
                except Exception as e:
                    st.error(f"Prediction failed: {e}")
        else:
            st.info("Train a model first to enable prediction.")
|
|
1363
|
+
|
|
1364
|
+
|
|
1365
|
+
class TML:
    """
    Class for handling Machine Learning tasks related to the tumor model.
    Allows dataset generation, parameter exploration, and result export.
    Allows predicting the size/confluence for a new set of parameters.

    Parameters:
    model (TModel): a created instance of the TModel class.
    """

    def __init__(self, model):
        self.model = model

        # Set by train_predictor(); predict_new() refuses to run until then.
        self.trained_model = None
        self.feature_columns = None

        # Parameter values used for every simulation unless randomized.
        self.default_params = {
            "cycles": self.model.cycles,
            "side": self.model.side,
            "pmax": self.model.pmax,
            "PA": self.model.PA,
            "CCT": self.model.CCT,
            "Dt": self.model.Dt,
            "PS": self.model.PS,
            "mu": self.model.mu,
            "I": self.model.I,
            "M": self.model.M,
        }

    # ---------------------------------------------------------------------
    def generate_dataset(
        self, n=50, random_params=None,
        output_file="tumor_dataset.csv"
    ):
        """
        Generate a dataset of tumor simulations by randomizing given parameters.

        Parameters:
        n (int): Number of simulations to run.
        random_params (dict, optional): Parameters to randomize, mapped to
            (min, max) ranges, e.g. {"PA": (1, 20), "PS": (10, 40)}.
            Integer parameters are drawn with random.randint, the others
            with random.uniform. Reversed ranges are accepted. None or {}
            runs every simulation with the default parameters.
        output_file (str or None): CSV filename to save the dataset to;
            None skips writing to disk.

        Returns:
        pd.DataFrame: One row per simulation. The columns are the 10 model
            parameters followed by "Tumor size" and "Confluence". The frame
            is empty when n <= 0.
        """

        random_params = random_params or {}
        stats = []

        for _ in tqdm(range(n), desc="Generating simulations"):
            params = self.default_params.copy()
            for key, (low, high) in random_params.items():
                # random.randint/uniform expect low <= high.
                low, high = min(low, high), max(low, high)
                value = params[key]
                if isinstance(value, (int, np.integer)) and not isinstance(value, bool):
                    params[key] = random.randint(int(low), int(high))
                else:
                    params[key] = random.uniform(float(low), float(high))

            # Run simulation
            model = TModel(**params)
            model.run_model(plot=False, animate=False, stats=False)

            occupied = np.count_nonzero(model.field)
            run_stats = dict(params)
            run_stats["Tumor size"] = occupied
            run_stats["Confluence"] = occupied / model.field.size * 100
            stats.append(run_stats)

        df = pd.DataFrame(stats)
        if stats and output_file is not None:
            df.to_csv(output_file, index=False)
            print(f"Dataset saved to {output_file} ({len(df)} runs)")
        return df

    # ---------------------------------------------------------------------
    def train_predictor(
        self, file, target, test_size=0.2,
        random_state=42, n_estimators=200, n_features=10
    ):
        """
        Train a random forest regressor that predicts a target column from
        the simulation parameters.

        Parameters:
        file (str or pd.DataFrame): CSV file (or an already-loaded DataFrame) containing the dataset
        target (str): Column name of the target attribute
        test_size (float): Fraction of dataset to use for testing
        random_state (int): Random seed for reproducibility
        n_estimators (int): Number of trees in the random forest
        n_features (int): Number of leading columns that hold the parameters
            (generate_dataset writes its 10 parameters first). The target is
            always excluded from the features.

        Returns:
        model (RandomForestRegressor): Trained model
        metrics (dict): R^2 ("R2") and MAE ("MAE") metrics on the test set

        Raises:
        ValueError: if no feature column is left after excluding the target.
        """

        df = file if isinstance(file, pd.DataFrame) else pd.read_csv(file)

        # Never feed the target to the model as one of its own features.
        feature_cols = [c for c in df.columns[:n_features] if c != target]
        if not feature_cols:
            raise ValueError("No feature columns left after excluding the target.")
        x = df[feature_cols]
        y = df[target]

        self.feature_columns = list(feature_cols)

        # Split into train/test
        x_train, x_test, y_train, y_test = train_test_split(
            x, y, test_size=test_size, random_state=random_state
        )

        # Train model
        model = RandomForestRegressor(
            n_estimators=n_estimators,
            random_state=random_state,
            n_jobs=-1
        )
        model.fit(x_train, y_train)

        # Predict & evaluate
        y_pred = model.predict(x_test)
        metrics = {
            "R2": r2_score(y_test, y_pred),
            "MAE": mean_absolute_error(y_test, y_pred)
        }
        self.trained_model = model

        print(f"Model trained on {len(x_train)} samples, tested on {len(x_test)}")
        print(f"R^2: {metrics['R2']:.3f}, MAE: {metrics['MAE']:.3f}")

        # Feature importance summary
        importance = pd.Series(model.feature_importances_, index=x.columns).sort_values(ascending=False)
        print("\n Top influencing parameters:")
        print(importance.head())

        return model, metrics

    # ---------------------------------------------------------------------
    def predict_new(self, params):
        """
        Predict the target value for a set of parameters using the model
        trained by train_predictor().

        Parameters:
        params (sequence): One value per feature, in the order of
            self.feature_columns, e.g.
            [500, 50, 10, 1, 24, 1/24, 15, 4, 4, 10]

        Returns:
        float: the predicted target value

        Raises:
        RuntimeError: if no model has been trained yet.
        ValueError: if the number of values does not match the features.
        """

        # getattr also covers instances created without __init__.
        if getattr(self, "trained_model", None) is None:
            raise RuntimeError(
                "No trained model found. Train one with train_predictor() first."
            )
        if getattr(self, "feature_columns", None) is None:
            raise RuntimeError(
                "Feature column list not found. Did you train the model?"
            )

        if len(params) != len(self.feature_columns):
            raise ValueError(
                f"Parameter list must have {len(self.feature_columns)} values "
                f"(got {len(params)}). Expected order: {self.feature_columns}"
            )

        # Column names must match the ones the model was fitted with.
        df = pd.DataFrame([list(params)], columns=self.feature_columns)
        return float(self.trained_model.predict(df)[0])
|