TCAMpy 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
TCAMpy.py ADDED
@@ -0,0 +1,1534 @@
1
+ import io
2
+ import time
3
+ import random
4
+ import hashlib
5
+ import numpy as np
6
+ import pandas as pd
7
+ import altair as alt
8
+ import streamlit as st
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib.animation as animation
11
+
12
+ from sklearn.metrics import r2_score, mean_absolute_error
13
+ from sklearn.model_selection import train_test_split
14
+ from sklearn.ensemble import RandomForestRegressor
15
+ from streamlit_javascript import st_javascript
16
+ from scipy.ndimage import gaussian_filter
17
+ from scipy.stats import skew, kurtosis
18
+ from functools import wraps
19
+ from tqdm import tqdm
20
+
21
+ class TModel:
22
+ """
23
+ Class for a cellular automata, modeling tumor growth.
24
+
25
+ Parameters:
26
+ cycles (int): duration of the model given in hours
27
+ side (int): the length of the side of field (10um)
28
+ pmax (int): maximum proliferation potential of RTC
29
+ PA (int): chance for apoptosis of RTC (in percent)
30
+ CCT (int): cell cycle time of cells given in hours
31
+ Dt (float): time step of the model given in days
32
+ PS (int): STC-STC division chance (in percent)
33
+ mu (int): migration capacity of cancer cells
34
+ I (int): strength of the immune cells (1-5)
35
+ M (int): tumor mutation chance (in percent)
36
+ """
37
+
38
+ def __init__(self, cycles, side, pmax, PA, CCT, Dt, PS, mu, I, M):
39
+ # Parameters
40
+ self.cycles = cycles
41
+ self.side = side
42
+ self.pmax = pmax
43
+ self.CCT = CCT
44
+ self.Dt = Dt
45
+ self.mu = mu
46
+ self.I = I
47
+ self.M = M
48
+
49
+ # Single model data
50
+ self.stc_number = []
51
+ self.rtc_number = []
52
+ self.wbc_number = []
53
+ self.immune = []
54
+ self.mutate = []
55
+ self.mutmap = []
56
+ self.field = []
57
+ self.images = []
58
+
59
+ # Multiple models data
60
+ self.stats = []
61
+ self.runs = []
62
+
63
+ # Chances
64
+ self.PP = CCT*Dt/24*100
65
+ self.PM = 100*mu/24
66
+ self.PA = PA
67
+ self.PS = PS
68
+
69
+ # Immune Data
70
+ self.it_ratio = []
71
+ self.kill_day = []
72
+
73
+ # ---------------------------------------------------------------------
74
+ def init_state(self):
75
+ """
76
+ Creates the initial state with one STC in the middle.
77
+ Creates the field for immune cells and mutations too.
78
+ """
79
+
80
+ self.field = np.zeros((self.side, self.side))
81
+ self.immune = np.zeros((self.side, self.side))
82
+ self.mutate = np.zeros((self.side, self.side))
83
+ self.mutmap = np.zeros((self.side, self.side))
84
+
85
+ self.mod_cell(self.side//2, self.side//2, self.pmax+1)
86
+
87
+ # ---------------------------------------------------------------------
88
+ def find_tumor_cells(self):
89
+ """
90
+ Saves the coordinates of tumor cells to self.tumor_cells.
91
+ """
92
+
93
+ # Where are tumor cells?
94
+ coords = np.nonzero(self.field)
95
+ coords = np.transpose(coords)
96
+
97
+ # Shuffle to randomize direction
98
+ np.random.shuffle(coords)
99
+ self.tumor_cells = coords
100
+
101
+ # ---------------------------------------------------------------------
102
+ def count_tumor_cells(self):
103
+ """
104
+ Saves the number of STCs/RTCs to self.stc_number/self.rtc_number.
105
+ """
106
+
107
+ # Count RTC and STC
108
+ stc_count = np.count_nonzero(self.field == self.pmax + 1)
109
+ rtc_count = len(self.tumor_cells) - stc_count
110
+
111
+ # Save the current number
112
+ self.stc_number.append(stc_count)
113
+ self.rtc_number.append(rtc_count)
114
+
115
+ # ---------------------------------------------------------------------
116
+ def get_neighbours(self, x, y, neighbour_type):
117
+ """
118
+ Returns the neighboring coordinates of a given cell in a 2D NumPy matrix.
119
+
120
+ Parameters:
121
+ x, y (int): representing the coordinates of the cell
122
+ neighbour_type (int): type of neighboring cells (1-5)
123
+
124
+ Returns:
125
+ list: a list with the coords of the neighbouring cells
126
+ """
127
+
128
+ directions = [
129
+ (-1, 0), (1, 0),
130
+ (0, -1), (0, 1),
131
+ (-1,-1), (-1,1),
132
+ (1, -1), (1, 1)]
133
+
134
+ coords = []
135
+ for dx, dy in directions:
136
+ nx, ny = x + dx, y + dy
137
+ if 0 < nx < self.side-1 and 0 < ny < self.side-1:
138
+ coords.append([nx, ny])
139
+
140
+ neighbours = []
141
+ for n in coords:
142
+ match neighbour_type:
143
+ case 1:
144
+ # Return list of empty cells
145
+ if (self.field[n[0],n[1]] == 0 and self.immune[n[0],n[1]] == 0):
146
+ neighbours.append(n)
147
+ case 2:
148
+ # Return list of tumor cells
149
+ if self.field[n[0],n[1]] != 0:
150
+ neighbours.append(n)
151
+ case 3:
152
+ # Return list of immune cells
153
+ if self.immune[n[0],n[1]] != 0:
154
+ neighbours.append(n)
155
+ case 4:
156
+ # Return list of any cells
157
+ if (self.field[n[0],n[1]] != 0 or self.immune[n[0],n[1]] != 0):
158
+ neighbours.append(n)
159
+ case 5:
160
+ # # Return list of free/immune cells
161
+ if self.immune[n[0],n[1]] == 0:
162
+ neighbours.append(n)
163
+
164
+ return neighbours
165
+
166
+ # ---------------------------------------------------------------------
167
+ def cell_step(self, x, y, step_type):
168
+ """
169
+ The function that makes a single cell do one of the following actions:
170
+ prolif STC - STC, prolif STC - RTC, prolif RTC - RTC, migration (1-4).
171
+ New mutations can appear every time a cell proliferates with M chance.
172
+
173
+ Parameters:
174
+ x, y (int): representing the coordinates of the cell
175
+ step_type (int): type of division or migration (1-4)
176
+ """
177
+
178
+ # Choose random target position
179
+ free_nb = self.get_neighbours(x, y, 1)
180
+ nx, ny = free_nb[random.randint(1,len(free_nb)) - 1]
181
+
182
+ match step_type:
183
+ case 1:
184
+ # Proliferation STC -> STC + STC
185
+ self.field[nx, ny] = self.pmax+1
186
+ case 2:
187
+ # Proliferation STC -> STC + RTC
188
+ self.field[nx, ny] = self.pmax
189
+ case 3:
190
+ # Proliferation RTC -> RTC + RTC
191
+ self.field[x, y] -= 1
192
+ self.field[nx, ny] = self.field[x, y]
193
+ case 4:
194
+ # Migration
195
+ self.field[nx, ny] = self.field[x, y]
196
+ self.field[x, y] = 0
197
+
198
+ if step_type < 4 and self.field[x, y] == 0:
199
+ self.mutate[x, y] = 0
200
+
201
+ elif step_type < 4:
202
+ # Inherit mother's mutation
203
+ self.mutate[nx, ny] = self.mutate[x, y]
204
+
205
+ # Chance of a new mutation
206
+ if self.M >= random.randint(1, 100):
207
+ mut = random.choice([-1,1])
208
+ self.mutate[nx, ny] = np.clip(self.mutate[nx, ny]+mut, -3, 3)
209
+
210
+ # Mutation influences pp value
211
+ if step_type != 1:
212
+ self.field[nx, ny] = np.clip(self.field[nx, ny]+mut, 1, self.pmax)
213
+
214
+ self.mutmap[nx, ny] = self.mutate[nx, ny]
215
+ else:
216
+ self.mutate[nx, ny] = self.mutate[x, y]
217
+ self.mutmap[nx, ny] = self.mutate[x, y]
218
+ self.mutate[x, y] = 0
219
+
220
+ # ---------------------------------------------------------------------
221
+ def tumor_action(self):
222
+ """
223
+ This is the function that decides what action a cell will do.
224
+ Either kills the cell or calls the 'cell_step' function.
225
+ This function goes through every single cell in the field.
226
+ """
227
+
228
+ for cell in self.tumor_cells:
229
+ x, y = cell
230
+ is_stc = (self.field[x, y] == self.pmax + 1)
231
+
232
+ # Probabilities
233
+ probs = np.array([self.PA, self.PP, self.PM, 0], dtype=float)
234
+ if is_stc:
235
+ probs[0] = 0
236
+ if not self.get_neighbours(x, y, 1):
237
+ probs[1:3] = 0
238
+ probs = self.mutate_probs(probs, x, y)
239
+ probs /= probs.sum()
240
+
241
+ # Choose action
242
+ choice = np.random.choice(4, p=probs)
243
+
244
+ if choice == 0: # apoptosis
245
+ self.field[x, y] = 0
246
+ self.mutate[x, y] = 0
247
+
248
+ elif choice == 1: # proliferation
249
+ if is_stc and np.random.rand() < self.PS/100:
250
+ self.cell_step(x, y, 1) # STC-STC division
251
+ elif is_stc:
252
+ self.cell_step(x, y, 2) # STC-RTC division
253
+ else:
254
+ self.cell_step(x, y, 3) # RTC-RTC division
255
+
256
+ elif choice == 2: # migration
257
+ self.cell_step(x, y, 4)
258
+
259
+ # ---------------------------------------------------------------------
260
+ def mutate_probs(self, chances, x, y):
261
+ """
262
+ The function that changes the cell action chances
263
+ based on the current mutation status of the cell.
264
+
265
+ Parameters:
266
+ chances (list of float): the base action chances
267
+ x, y (int): representing coordinates of the cell
268
+ """
269
+
270
+ mut_state = self.mutate[x, y]/2
271
+
272
+ if mut_state > 0:
273
+ mut_state += 1
274
+ chances[0] = chances[0]/mut_state # Decreased chance for apoptosis
275
+ chances[1] = chances[1]*mut_state # Increased proliferation chance
276
+ elif mut_state < 0:
277
+ mut_state -= 1
278
+ chances[0] = chances[0]*abs(mut_state) # Increased chance for apoptosis
279
+ chances[1] = chances[1]/abs(mut_state) # Decreased proliferation chance
280
+
281
+ if chances.sum() <= 100:
282
+ chances[3] = 100 - chances.sum()
283
+ return chances
284
+
285
+ # ---------------------------------------------------------------------
286
+ def immune_response(self, offset = 10, alpha = 0.002, it_targ = 0.1, infil = 0.3):
287
+ """
288
+ The function that simulates immune cells.
289
+ Spawns, moves and activates immune cells.
290
+
291
+ Parameters:
292
+ offset (int): distance of spawnpoints ("frame") from the tumor
293
+ alpha (float): controls strength (slope) of immune exhaustion
294
+ it_targ (float): desired mean immune/tumor ratio during simulation
295
+ infil (float): "searching/infiltrating" threshold for wbcs (0-1)
296
+ """
297
+
298
+ # Current tumor cell locations
299
+ self.find_tumor_cells()
300
+ tumor_size = len(self.tumor_cells)
301
+ if tumor_size == 0:
302
+ # Count existing immune cells
303
+ coords = np.argwhere(self.immune > 0)
304
+ self.immune_cells = coords
305
+ immune_size = len(coords)
306
+ for x, y in coords:
307
+ self.immune[x, y] -= 1
308
+ self.wbc_number.append(immune_size)
309
+ return
310
+
311
+ tumor_x = [x for x, _ in self.tumor_cells]
312
+ tumor_y = [y for _, y in self.tumor_cells]
313
+
314
+ # Find spawnpoints (a "frame" around tumor)
315
+ frame = []
316
+ left = max(1, min(tumor_x) - offset)
317
+ right = min(self.side - 2, max(tumor_x) + offset)
318
+ top = max(1, min(tumor_y) - offset)
319
+ bottom = min(self.side - 2, max(tumor_y) + offset)
320
+
321
+ for j in range(left, right + 1):
322
+ frame.append([top, j])
323
+ frame.append([bottom, j])
324
+ for i in range(top, bottom + 1):
325
+ frame.append([i, left])
326
+ frame.append([i, right])
327
+ self.spawnpoints = np.array(frame)
328
+
329
+ # Immune exhaustion = time-dependent decline
330
+ IE = 1.0 / (1.0 + alpha * self.cycles)
331
+ IE = max(IE, 0.2)
332
+
333
+ # Saturating spawn (sigmoid-like), delayed onse
334
+ spawn = self.I * (tumor_size / (tumor_size + self.I * 100)) * IE
335
+
336
+ coords = np.nonzero(self.immune)
337
+ it_ratio = len(np.transpose(coords)) / tumor_size
338
+ for cell in self.spawnpoints:
339
+ x, y = cell
340
+ if np.random.rand() < spawn / 50:
341
+ if self.immune[x, y] == 0 and self.field[x, y] == 0 and it_ratio <= it_targ:
342
+ # Immune cell lifespan: I-1 weeks to I+1 weeks
343
+ min_life = min(24, (self.I-1)*168)
344
+ max_life = (self.I+1)*168
345
+ self.immune[x, y] = np.random.randint(min_life, max_life)
346
+
347
+ # Chemoattractant map for tumor density
348
+ self.chemo = (self.field > 0).astype(float)
349
+ self.chemo = gaussian_filter(self.chemo, sigma=5)
350
+ self.chemo = self.chemo / np.max(self.chemo)
351
+
352
+ # Immune action
353
+ coords = np.nonzero(self.immune)
354
+ kills_per_hour = 0
355
+ self.immune_cells = np.transpose(coords)
356
+
357
+ # Temporary immune grid
358
+ new_immune = np.zeros_like(self.immune)
359
+
360
+ for (x, y) in self.immune_cells:
361
+ strength = self.immune[x, y]
362
+ if strength <= 0:
363
+ continue
364
+
365
+ # Kill prob on contact: (0.15 - 0.3, if I=5, IE = 0)
366
+ tumor_nb = self.get_neighbours(x, y, 2)
367
+ if tumor_nb:
368
+ tx, ty = random.choice(tumor_nb)
369
+ kill = (0.05*self.I) * np.exp(-0.25*self.mutate[tx,ty]) * IE
370
+ kill = min(kill, 0.3)
371
+ if np.random.rand() < kill:
372
+ self.field[tx, ty] = 0
373
+ self.mutate[tx, ty] = 0
374
+ kills_per_hour += 1
375
+
376
+ # # Multiple moves/cycle as immune cells are faster
377
+ moves = int(1 + self.I * (1 - self.chemo[x, y]))
378
+ for _ in range(moves):
379
+ free_nb = self.get_neighbours(x, y, 1)
380
+ if not free_nb:
381
+ strength -= 1
382
+ break
383
+
384
+ # Biased movement towards tumor density (chemotaxis)
385
+ t_dens = [self.chemo[i, j] for (i, j) in free_nb]
386
+ if sum(t_dens) > 0:
387
+ weights = np.array(t_dens) / sum(t_dens)
388
+ tx, ty = free_nb[np.random.choice(len(free_nb), p=weights)]
389
+ else:
390
+ tx, ty = random.choice(free_nb)
391
+ x, y = tx, ty
392
+ strength -= 1
393
+
394
+ if strength > 0:
395
+ new_immune[x, y] = strength
396
+ self.immune = new_immune
397
+
398
+ # Save number of immune cells
399
+ immune_size = len(self.immune_cells)
400
+ self.wbc_number.append(immune_size)
401
+ self.it_ratio.append(immune_size / tumor_size)
402
+
403
+ # Infiltrating immune cells
404
+ wbc_infil = sum(1 for (x,y) in self.immune_cells if self.chemo[x,y] >= infil)
405
+ if immune_size > 0:
406
+ self.kill_day.append(kills_per_hour / max(1, wbc_infil) * 24)
407
+
408
+ # ---------------------------------------------------------------------
409
+ def animate(self, mode):
410
+ """
411
+ Creates and returns animation of the growth.
412
+
413
+ Parameters:
414
+ mode (int): create figure, save frame or display animation. (1-3)
415
+
416
+ Returns:
417
+ ArtistAnimation: the animation of the growth (optional)
418
+ """
419
+
420
+ if mode == 1:
421
+ # Create the figure
422
+ self.fig, self.ax = plt.subplots()
423
+ self.ax.imshow(self.field)
424
+ self.ax.set_title(str(self.cycles)+ " hour cell growth")
425
+ self.ax.set_xlabel(str(self.side*10) + " micrometers")
426
+ self.ax.set_ylabel(str(self.side*10) + " micrometers")
427
+ elif mode == 2:
428
+ # Save the current frame
429
+ growth = self.ax.imshow(self.field, animated=True)
430
+ immune_coords = np.argwhere(self.immune > 0)
431
+ immune = self.ax.scatter(immune_coords[:,1], immune_coords[:,0], c='blue', s=10)
432
+ self.images.append([growth, immune])
433
+ elif mode == 3:
434
+ # Display the animation
435
+ return animation.ArtistAnimation(self.fig, self.images, interval=50, blit=True)
436
+
437
+ # ---------------------------------------------------------------------
438
+ def save_field_to_excel(self, file_name):
439
+ """
440
+ Saves the current state of self.field to an excel file.
441
+
442
+ Parameters:
443
+ file_name (str): name of the excel file
444
+ """
445
+
446
+ pd.DataFrame(self.field).to_excel(file_name, index=False)
447
+
448
+ # ---------------------------------------------------------------------
449
+ def mod_cell(self, x, y, value):
450
+ """
451
+ Modifies cell value. (Create initial state before this!)
452
+
453
+ Parameters:
454
+ x, y (int): representing coordinates of the cell
455
+ value (int): the new value at the given position
456
+ """
457
+
458
+ self.field[y][x] = value
459
+
460
+ # ---------------------------------------------------------------------
461
+ def get_prolif_potentials(self):
462
+ """
463
+ Returns a dictionary of proliferation potential numbers.
464
+
465
+ Returns:
466
+ dict: a dictionary of the proliferation potentials
467
+ """
468
+
469
+ nonzero_field = np.array(self.field)[np.array(self.field) > 0]
470
+ unique, counts = np.unique(nonzero_field, return_counts=True)
471
+ prolif_potents = {}
472
+
473
+ for i in range(1, self.pmax + 2):
474
+ prolif_potents[i] = 0
475
+ for val, count in zip(unique, counts):
476
+ prolif_potents[int(val)] = count
477
+
478
+ return prolif_potents
479
+
480
+ # ---------------------------------------------------------------------
481
+ def get_statistics(self):
482
+ """
483
+ Returns various statistical properties of the model.
484
+
485
+ Returns:
486
+ dict: a dictionary of the statistical properties
487
+ """
488
+
489
+ nonzero_field = self.field[self.field > 0]
490
+
491
+ # Statistics
492
+ if nonzero_field.size != 0:
493
+ stats = {
494
+ "Min": nonzero_field.min(),
495
+ "Max": nonzero_field.max(),
496
+ "Mean": nonzero_field.mean(),
497
+ "Std": nonzero_field.std(),
498
+ "Median": np.median(nonzero_field),
499
+ "Skew": skew(nonzero_field.ravel()),
500
+ "Kurtosis": kurtosis(nonzero_field.ravel()),
501
+ "Final STC": self.stc_number[self.cycles-1],
502
+ "Final RTC": self.rtc_number[self.cycles-1],
503
+ "Final WBC": self.wbc_number[self.cycles-1],
504
+ "Tumor Size": nonzero_field.size,
505
+ "Confluence": nonzero_field.size/self.field.size*100,
506
+ }
507
+
508
+ if self.I > 0:
509
+ stats.update({
510
+ "Mean I/T" : sum(self.it_ratio)/len(self.it_ratio),
511
+ "Mean k/d" : sum(self.kill_day)/len(self.kill_day)
512
+ })
513
+
514
+ # Proliferation potentials
515
+ stats.update(self.get_prolif_potentials())
516
+
517
+ # Cell Numbers
518
+ checkpoints = np.linspace(0, self.cycles - 1, int(self.cycles/10) + 1, dtype=int)
519
+ for idx in checkpoints:
520
+ hour = (idx + 1)
521
+ stats[f"{hour}h_STC"] = self.stc_number[idx]
522
+ stats[f"{hour}h_RTC"] = self.rtc_number[idx]
523
+ stats[f"{hour}h_WBC"] = self.wbc_number[idx]
524
+
525
+ else: stats = {"Status": "Extinct"}
526
+ return stats
527
+
528
+ # ---------------------------------------------------------------------
529
+ def save_statistics(self, file_name):
530
+ """
531
+ Saves various statistical properties of the model to an excel file.
532
+
533
+ Parameters:
534
+ file_name (str): name of the excel file
535
+ """
536
+
537
+ stats_dict = self.get_statistics()
538
+ df = pd.DataFrame([stats_dict])
539
+ df.to_excel(file_name, index=False)
540
+
541
+ # ---------------------------------------------------------------------
542
+ def measure_runtime(func):
543
+ # Decorator to measure completion time
544
+ @wraps(func)
545
+ def wrapper(*args, **kwargs):
546
+ start_time = time.time()
547
+ result = func(*args, **kwargs)
548
+ end_time = time.time()
549
+ runtime = end_time - start_time
550
+ print("Model completion time (s): " + str(runtime))
551
+ return result
552
+ return wrapper
553
+
554
+ # ---------------------------------------------------------------------
555
+ @measure_runtime
556
+ def run_model(self, plot, animate, stats):
557
+ """
558
+ The function that runs a single entire simulation.
559
+ For animation matplotlib backend cannot be inline!
560
+
561
+ Parameters:
562
+ plot (bool): set to true to display the plots of the model
563
+ animate (bool): set to true to enable matplotlib animation
564
+ stats (bool): set to true to print statistics of the model
565
+ """
566
+
567
+ # Create initial state
568
+ if len(self.field) == 0: self.init_state()
569
+ self.find_tumor_cells()
570
+ if len(self.immune) == 0:
571
+ self.immune = np.zeros((self.side, self.side))
572
+ if len(self.mutate) == 0:
573
+ self.mutate = np.zeros((self.side, self.side))
574
+ self.mutmap = np.zeros((self.side, self.side))
575
+
576
+ self.stc_number = []
577
+ self.rtc_number = []
578
+ self.wbc_number = []
579
+
580
+ if animate: self.animate(1)
581
+
582
+ # Growth loop
583
+ for c in tqdm(range(self.cycles), desc="Running simulation..."):
584
+ self.tumor_action()
585
+ self.immune_response()
586
+ self.find_tumor_cells()
587
+ self.count_tumor_cells()
588
+ if animate: self.animate(2)
589
+
590
+ # Store the results
591
+ self.store_model()
592
+
593
+ # Output settings
594
+ if plot: self.plot_run(len(self.runs))
595
+ if animate: self.ani = self.animate(3)
596
+ if stats:
597
+ df = pd.DataFrame(self.stats)
598
+ base_cols = self.separate_columns(df)[0]
599
+ print(df[base_cols])
600
+
601
+ # ---------------------------------------------------------------------
602
+ @measure_runtime
603
+ def run_multimodel(self, count, init_field, plot, stats):
604
+ """
605
+ Runs the model multiple times and returns a DataFrame of statistics.
606
+
607
+ Parameters:
608
+ count (int): number of times to run the simulation
609
+ init_field (np.array): custom initial state of field/run
610
+ plot (bool): set to true to display the plots of the model
611
+ stats (bool): set to true to print statistics of the model
612
+
613
+ Returns:
614
+ pd.DataFrame: collected statistics from each run
615
+ """
616
+
617
+ stats = []
618
+
619
+ for i in range(count):
620
+ self.field = init_field.copy()
621
+ self.immune = []
622
+ self.mutate = []
623
+ self.mutmap = []
624
+ self.run_model(plot = False, animate = False, stats = False)
625
+ stats.append(self.get_statistics())
626
+ all_stats = pd.DataFrame(stats)
627
+
628
+ if plot:
629
+ self.plot_averages(all_stats)
630
+ if stats:
631
+ df = pd.DataFrame(self.stats)
632
+ base_cols = self.separate_columns(df)[0]
633
+ print(df[base_cols])
634
+
635
+ return all_stats
636
+
637
+ # ---------------------------------------------------------------------
638
+ def store_model(self):
639
+ """
640
+ Stores the results of the previous model executions.
641
+ """
642
+
643
+ result = {}
644
+
645
+ result["immune"] = self.immune
646
+ result["mutate"] = self.mutate
647
+ result["mutmap"] = self.mutmap
648
+ result["field"] = self.field
649
+ result["stc"] = self.stc_number
650
+ result["rtc"] = self.rtc_number
651
+ result["wbc"] = self.wbc_number
652
+ result["pp"] = self.get_prolif_potentials().values()
653
+
654
+ # Stores data for plotting
655
+ self.runs.append(result)
656
+
657
+ # Stores data for statistics
658
+ self.stats.append(self.get_statistics())
659
+
660
+ # ---------------------------------------------------------------------
661
+ def separate_columns(self, data):
662
+ """
663
+ Separates the statistics DataFrame columns into logical groups:
664
+ base stats, STC, RTC, WBC counts, and proliferation potentials.
665
+
666
+ Parameters:
667
+ data (pd.DataFrame): Your data in a pandas dataframe format
668
+
669
+ Returns:
670
+ tuple of list[str]: A tuple containing 5 lists of column names:
671
+ - base: Columns with general statistical properties
672
+ - stc: Columns with STC counts at each time point
673
+ - rtc: Columns with RTC counts at each time point
674
+ - wbc: Columns with WBC counts at each time point
675
+ - pp: Columns for proliferation potential values
676
+ """
677
+
678
+ base = [col for col in data.columns if not str(col).isdigit()
679
+ and "_STC" not in str(col)
680
+ and "_RTC" not in str(col)
681
+ and "_WBC" not in str(col)]
682
+ stc = sorted([col for col in data.columns if "_STC" in str(col)],
683
+ key=lambda x: int(str(x).split("h")[0]))
684
+ rtc = sorted([col for col in data.columns if "_RTC" in str(col)],
685
+ key=lambda x: int(str(x).split("h")[0]))
686
+ wbc = sorted([col for col in data.columns if "_WBC" in str(col)],
687
+ key=lambda x: int(str(x).split("h")[0]))
688
+ pp = sorted([col for col in data.columns if isinstance(col, int)])
689
+
690
+ return base, stc, rtc, wbc, pp
691
+
692
+ # ---------------------------------------------------------------------
693
+ def plot_run(self, run):
694
+ """
695
+ Creates growth and cell number plots, proliferation potential histograms.
696
+
697
+ Paramteres:
698
+ run (int): which model execution to plot
699
+
700
+ Returns:
701
+ matplotlib.figure.Figure: the generated plots of the specific run
702
+ """
703
+
704
+ # Create the figue and axis
705
+ fig, axs = plt.subplots(2, 2, figsize=(14,14))
706
+
707
+ tumor = axs[0, 0].imshow(self.runs[run-1]["field"], vmin=0, vmax=self.pmax+1)
708
+ fig.colorbar(tumor, ax=axs[0, 0])
709
+
710
+ immune_coords = np.argwhere(self.runs[run-1]["immune"] > 0)
711
+ axs[0, 0].scatter(immune_coords[:,1], immune_coords[:,0],
712
+ c='blue', marker='v', s=10)
713
+
714
+ axs[0, 1].plot(self.runs[run-1]["stc"], 'C1', label='STC')
715
+ axs[0, 1].plot(self.runs[run-1]["rtc"], 'C2', label='RTC')
716
+ axs[0, 1].plot(self.runs[run-1]["wbc"], 'C3', label='WBC')
717
+ axs[0, 1].legend()
718
+
719
+ mutmap = axs[1, 0].imshow(self.runs[run-1]["mutmap"],
720
+ cmap="RdBu_r", vmin=-3, vmax=3, interpolation="bicubic")
721
+ fig.colorbar(mutmap, ax=axs[1, 0])
722
+
723
+ axs[1, 1].bar(range(1, self.pmax + 2), self.runs[run-1]["pp"], edgecolor='black')
724
+
725
+ # Titles/labels of the plots
726
+ titles = [str(self.cycles)+ "h cell growth", "Cell count",
727
+ "Mutation history", "Final PP values"]
728
+ labs_x = [str(self.side*10) + " um", "Time (h)",
729
+ str(self.side*10) + " um", "Proliferation potentials"]
730
+ labs_y = [str(self.side*10) + " um", "Cell numbers",
731
+ str(self.side*10) + " um", "Number of appearance"]
732
+
733
+ fig.suptitle("Simulation " + str(run) + " Results", fontsize = 16)
734
+ for i, ax in enumerate(axs.flat):
735
+ ax.set_title(titles[i])
736
+ ax.set_xlabel(labs_x[i])
737
+ ax.set_ylabel(labs_y[i])
738
+
739
+ # ---------------------------------------------------------------------
740
+ def plot_averages(self, data):
741
+ """
742
+ The function that plots the averages of multiple model results.
743
+ Works with the results of the 'run_multimodel' function.
744
+
745
+ Parameters:
746
+ data (pd.DataFrame): Your data in a pandas dataframe format
747
+
748
+ Returns:
749
+ matplotlib.figure.Figure: The plots of the averages with SD values
750
+ """
751
+
752
+ base_cols, stc_cols, rtc_cols, wbc_cols, pp_cols = self.separate_columns(data)
753
+
754
+ avg_stc = data[stc_cols].mean()
755
+ std_stc = data[stc_cols].std()
756
+ avg_rtc = data[rtc_cols].mean()
757
+ std_rtc = data[rtc_cols].std()
758
+ avg_wbc = data[wbc_cols].mean()
759
+ std_wbc = data[wbc_cols].std()
760
+ avg_pp = data[pp_cols].mean()
761
+ std_pp = data[pp_cols].std()
762
+
763
+ fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(14, 5))
764
+ timepoints = np.linspace(0, self.cycles - 1, int(self.cycles/10) + 1)
765
+
766
+ ax1.plot(timepoints, avg_stc, label='STC', color='C1')
767
+ ax1.fill_between(timepoints, avg_stc - std_stc, avg_stc + std_stc,
768
+ color='C1', alpha=0.3)
769
+ ax1.plot(timepoints, avg_rtc, label='RTC', color='C2')
770
+ ax1.fill_between(timepoints, avg_rtc - std_rtc, avg_rtc + std_rtc,
771
+ color='C2', alpha=0.3)
772
+ ax1.plot(timepoints, avg_wbc, label='WBC', color='C3')
773
+ ax1.fill_between(timepoints, avg_wbc - std_wbc, avg_wbc + std_wbc,
774
+ color='C3', alpha=0.3)
775
+
776
+ ax1.set_title("Average Tumor Cell Count")
777
+ ax1.set_xlabel("Model Time (hours)")
778
+ ax1.set_ylabel("Number of Cells")
779
+ ax1.legend()
780
+
781
+ ax2.bar(pp_cols, avg_pp, yerr=std_pp, capsize=5, edgecolor='black')
782
+ ax2.set_title("Average Proliferation Potential Distribution")
783
+ ax2.set_xlabel("Proliferation Potential")
784
+ ax2.set_ylabel("Average Count")
785
+
786
+ fig.suptitle("Averages of " + str(len(self.stats)) + " Models", fontsize = 16)
787
+ plt.tight_layout()
788
+
789
+
790
+ class TDashboard:
791
+ """
792
+ Class for a Streamlit dashboard providing a GUI for the model.
793
+
794
+ Parameters:
795
+ model (TModel): The created model you want a dashboard for
796
+ """
797
+
798
def __init__(self, model):
    """
    Attach the dashboard to the model it will drive.

    Parameters:
        model (TModel): the model whose parameters and results are shown
    """

    self.model = model
800
+
801
+ # ---------------------------------------------------------------------
802
def run_dashboard(self):
    """
    Builds the whole Streamlit page for the model.

    The page has two tabs: "SIMULATION" (parameters, initial state and
    execution on the left, results on the right) and "MACHINE LEARNING"
    (data generation next to training/prediction).
    """

    # Page frame
    st.set_page_config(layout="wide")
    st.markdown("<h1 style='text-align: center;'>TCAMpy</h1>", unsafe_allow_html=True)

    # Browser width, stored under the "screen_width" key for plot sizing
    self.screen_width = st_javascript("window.innerWidth", key="screen_width")

    simulation_tab, learning_tab = st.tabs(["SIMULATION", "MACHINE LEARNING"])

    with simulation_tab:
        # Relative widths: controls, spacer, results
        self.columns = [4, 1, 12]
        self.col1, _, self.col3 = st.columns(self.columns)

        with self.col1:
            self._initialize()
            self._modify_cell()
            self._execute_model()

        with self.col3:
            self._visualize_run("Last Simulation", len(self.model.runs))
            self._show_statistics()
            self._reset_save_stats()

    with learning_tab:
        generator_col, training_col = st.columns(2)

        with generator_col:
            self._simdata_generator()

        with training_col:
            self._train_and_predict()
832
+
833
+ # ---------------------------------------------------------------------
834
def print_title(self, title):
    """
    Writes a centered second-level heading onto the dashboard.

    Parameters:
        title (string): The text to print
    """

    heading = f"<h2 style='text-align: center;'>{title}</h2>"
    st.markdown(heading, unsafe_allow_html=True)
846
+
847
+ # ---------------------------------------------------------------------
848
def get_plot_height(self, col, scaler):
    """
    Derives a plot height from the screen width, the relative width of
    one of the main columns and a scaling factor.

    Parameters:
        col (int): main column number (1-based index into self.columns)
        scaler (float): scaler for column width

    Returns:
        int: the plot height in pixels
    """

    # Value stored under the "screen_width" session key
    width_px = st.session_state.get("screen_width")

    # Fraction of the page the requested column occupies
    share = self.columns[col-1] / sum(self.columns)

    return int(width_px * share * scaler)
861
+
862
+ # ---------------------------------------------------------------------
863
def _initialize(self):
    """
    Shows the parameter sliders, copies their values into the model and
    re-creates the initial state whenever the configuration has changed.
    """

    model = self.model
    state = st.session_state

    self.print_title("Model Parameters")

    # One slider per model parameter; the model's current value is the default
    model.cycles = st.slider("Model Duration (hours)", 50, 5000, value=model.cycles)
    model.side = st.slider("Field Side Length (10um)", 10, 200, value=model.side)
    model.pmax = st.slider("Max Proliferation Potential", 1, 20, value=model.pmax)
    model.PA = st.slider("Apoptosis Chance (RTC) (%)", 0, 100, value=model.PA)
    model.CCT = st.slider("Cell Cycle Time (hours)", 1, 48, value=model.CCT)
    model.Dt = st.slider("Time Step (days)", 0.01, 1.0, value=model.Dt, step=0.01)
    model.PS = st.slider("STC-STC Division Chance (%)", 0, 100, value=model.PS)
    model.mu = st.slider("Migration Capacity", 0, 10, value=model.mu)
    model.I = st.slider("Immune Strength", 0, 10, value=model.I)
    model.M = st.slider("Mutation Chance", 0, 50, value=model.M)

    # Derived chances follow the freshly selected parameters
    model.PP = int(model.CCT * model.Dt / 24 * 100)
    model.PM = 100 * model.mu / 24

    # Fingerprint of the configuration, used to detect parameter changes
    init_config = (
        model.side, model.cycles, model.pmax,
        model.PA, model.CCT, model.Dt, model.PS,
        model.mu, model.I, model.M
    )
    config_hash = hashlib.md5(str(init_config).encode()).hexdigest()

    # Bring back results kept in the session
    if "model_runs" in state:
        model.runs = state.model_runs
    if "model_stats" in state:
        model.stats = state.model_stats

    # Nothing more to do while the stored state matches the configuration
    up_to_date = (
        "initialized" in state
        and "init_config_hash" in state
        and state.init_config_hash == config_hash
    )
    if up_to_date:
        return

    model.init_state()
    state.field = model.field.copy()
    state.immune = model.immune.copy()
    state.mutate = model.mutate.copy()
    state.mutmap = model.mutmap.copy()
    state.initialized = True
    state.init_config_hash = config_hash
911
+
912
+ # ---------------------------------------------------------------------
913
def _modify_cell(self):
    """
    Lets the user edit single cells of the initial state and shows the
    resulting initial field as a heatmap.
    """

    model = self.model
    last_index = model.side - 1
    centre = model.side // 2
    stc_value = model.pmax + 1

    self.print_title("Initial State")

    x_coord = st.number_input("X Coordinate", 0, last_index, value=centre)
    y_coord = st.number_input("Y Coordinate", 0, last_index, value=centre)
    cell_value = st.number_input("Cell Value", 0, stc_value, value=stc_value)
    plots_height = self.get_plot_height(1, 0.9)

    if st.button("Modify Cell"):
        # Edit a copy of the stored initial field and store it again
        model.field = st.session_state.field.copy()
        model.mod_cell(x_coord, y_coord, cell_value)
        st.session_state.field = model.field.copy()
        st.success(f"Cell modified at ({x_coord}, {y_coord}) to {cell_value}")

    heatmap = self._create_heatmap(
        plots_height, "Initial state", "viridis",
        "PP", 0, model.pmax+1, st.session_state.field
    )
    st.altair_chart(heatmap, use_container_width=True)
938
+
939
+ # ---------------------------------------------------------------------
940
def _execute_model(self):
    """
    Renders the "Execution" section and runs the requested simulations.

    Every repetition restarts from copies of the initial arrays kept in
    the session state (field, immune, mutate, mutmap). Once all
    repetitions are done, the model's accumulated runs and statistics
    are stored in the session state.
    """

    self.print_title("Execution")

    rep = st.number_input("How many simulations?", 1)

    if not st.button("Run Model"):
        return

    session = st.session_state
    with st.spinner("Running simulations..."):
        for _ in range(rep):
            # Reset the model to the stored initial state before each run
            self.model.field = session.field.copy()
            self.model.immune = session.immune.copy()
            self.model.mutate = session.mutate.copy()
            self.model.mutmap = session.mutmap.copy()
            self.model.run_model(plot=False, animate=False, stats=False)

    # Keep the results across Streamlit reruns
    session.model_runs = self.model.runs
    session.model_stats = self.model.stats
960
+
961
+ # ---------------------------------------------------------------------
962
def _visualize_run(self, title, run):
    """
    Draws the four charts describing one stored model execution.

    Parameters:
        title (string): title of the visualization
        run (int): which model execution to plot; the entry at
            position run - 1 of self.model.runs is used

    If no runs are stored in the session state yet, only a
    placeholder warning is shown.
    """

    if "model_runs" not in st.session_state:
        st.warning("Simulation results will appear here...")
        return

    self.print_title(title)

    # --- Fetch the stored results of the requested run ---
    record = self.model.runs[run - 1]
    immune, mutmap, field, stc, rtc, wbc, pp = (
        record[key]
        for key in ("immune", "mutmap", "field", "stc", "rtc", "wbc", "pp")
    )

    # --- Build the charts ---
    plots_height = self.get_plot_height(3, 0.4)

    tumor_heatmap = self._create_heatmap(
        plots_height, "Tumor growth", "viridis",
        "PP", 0, self.model.pmax+1, field, immune
    )
    mutation_map = self._create_heatmap(
        plots_height, "Mutation history", "redblue",
        "M", -3, 3, mutmap
    )
    bar_chart = self._create_bar_chart(plots_height, list(pp))
    line_chart = self._create_line_chart(plots_height, stc, rtc, wbc)

    # --- Layout: heatmaps on the left, bar/line charts on the right ---
    left, right = st.columns([4, 5])
    placement = (
        (left, (tumor_heatmap, mutation_map)),
        (right, (bar_chart, line_chart)),
    )
    for column, charts in placement:
        with column:
            for chart in charts:
                st.altair_chart(chart, use_container_width=True)
1009
+
1010
+ # ---------------------------------------------------------------------
1011
def _create_heatmap(
    self, h, title, cmap, ctitle,
    vmin, vmax, heatmap, scatter=None
):
    """
    Creates an Altair heatmap, optionally with a scatter plot overlaid.
    Used for tumor field with immune cells, and mutations.

    Parameters:
        h (int): the height of the plot
        title (string): title of the plot
        cmap (string): colormap (Vega scheme name) for the heatmap
        ctitle (string): title of the color legend
        vmin, vmax (int): domain for the colormap
        heatmap (2D array-like): array for the heatmap
        scatter (2D array-like, optional): array for the scatter plot;
            a point is drawn at every position whose value is > 0

    Returns:
        Altair.Chart: heatmap, layered with the scatter overlay if given
    """

    # --- Heatmap data ---
    # Vectorised long-format table: one row per cell, in row-major
    # order (y outer, x inner), instead of a Python loop over cells.
    grid = np.asarray(heatmap)
    row_idx, col_idx = np.indices(grid.shape)
    heat_df = pd.DataFrame({
        "x": col_idx.ravel(),
        "y": row_idx.ravel(),
        "value": grid.ravel()
    })

    heat_chart = alt.Chart(heat_df).mark_rect().encode(
        x=alt.X("x:O", title="X"),
        y=alt.Y("y:O", sort="descending", title="Y"),
        color=alt.Color("value:Q", title=ctitle,
            scale=alt.Scale(scheme=cmap, domain=[vmin, vmax]))
    ).properties(
        title=title,
        width='container',
        height=h
    )

    if scatter is None:
        return heat_chart

    # --- Scatter plot data ---
    # np.argwhere yields (row, column) pairs, i.e. (y, x)
    scatter_coords = np.argwhere(np.asarray(scatter) > 0)
    scatter_df = pd.DataFrame(scatter_coords, columns=["y", "x"])

    overlay = alt.Chart(scatter_df).mark_point(
        color="blue", size=h/20, filled=True, shape="circle"
    ).encode(
        x=alt.X("x:O"),
        y=alt.Y("y:O", sort="descending")
    )

    # --- Combine layers ---
    return (heat_chart + overlay).properties(
        title=title,
        width='container',
        height=h
    )
1069
+
1070
+ # ---------------------------------------------------------------------
1071
def _create_line_chart(
    self, h, stc, rtc, wbc, stc_l=None, stc_u=None,
    rtc_l=None, rtc_u=None, wbc_l=None, wbc_u=None
):
    """
    The function that creates an Altair line chart of the cell numbers.

    Parameters:
        h (int): the height of the plot
        stc, rtc, wbc (sequence): the cell and immune numbers (mean or raw),
            one value per time point
        stc_l, rtc_l, wbc_l (sequence of float, optional): Lower bounds (e.g., mean - SD) for cell counts.
        stc_u, rtc_u, wbc_u (sequence of float, optional): Upper bounds (e.g., mean + SD) for cell counts.

    Returns:
        Altair.Chart: the line chart of the cell numbers; a shaded band is
        layered underneath only when all six bounds are given and non-empty
    """

    # Convert to plain lists so that "+" concatenates the series even when
    # the caller passes NumPy arrays or pandas objects (which would add
    # element-wise instead).
    series = {"STC": list(stc), "RTC": list(rtc), "WBC": list(wbc)}

    hours, labels, values = [], [], []
    for label, counts in series.items():
        hours.extend(range(len(counts)))
        labels.extend([label] * len(counts))
        values.extend(counts)

    df = pd.DataFrame({
        "Hour": hours,
        "Cell Type": labels,
        "Mean": values
    })

    line = alt.Chart(df).mark_line().encode(
        x="Hour:Q",
        y=alt.Y("Mean:Q", title="Mean"),
        color="Cell Type:N"
    )

    # The band needs every lower AND upper bound; len() is used instead of
    # truthiness so array inputs do not raise "ambiguous truth value".
    lower_bounds = (stc_l, rtc_l, wbc_l)
    upper_bounds = (stc_u, rtc_u, wbc_u)
    has_band = all(
        bound is not None and len(bound) > 0
        for bound in lower_bounds + upper_bounds
    )

    if has_band:
        df["Lower"] = [v for bound in lower_bounds for v in bound]
        df["Upper"] = [v for bound in upper_bounds for v in bound]

        area = alt.Chart(df).mark_area(opacity=0.3).encode(
            x=alt.X("Hour:Q", title="Time (hours)"),
            y=alt.Y("Lower:Q", title="Mean"),
            y2="Upper:Q",
            color="Cell Type:N"
        )
        chart = area + line
    else:
        chart = line

    return chart.properties(title="Cell Counts Over Time", height=h)
1116
+
1117
+ # ---------------------------------------------------------------------
1118
def _create_bar_chart(self, h, pp, std=None):
    """
    Builds an Altair bar chart of the proliferation potential distribution.

    Parameters:
        h (int): the height of the plot
        pp (list of float or int): Mean or raw counts of cells per proliferation potential class
        std (list of float, optional): Standard deviation for each class;
            when given, error bars are layered on top of the bars

    Returns:
        alt.Chart: An Altair chart representing the distribution of proliferation potentials
    """

    # Classes are numbered from 1 in the order the counts are given
    classes = list(range(1, len(pp) + 1))
    pp_df = pd.DataFrame({
        "Proliferation Potential": classes,
        "Mean": pp
    })

    bars = alt.Chart(pp_df).mark_bar().encode(
        x="Proliferation Potential:O",
        y="Mean:Q"
    )

    if std is None:
        return bars.properties(title="Proliferation Potential Distribution", height=h)

    # Error bars read the per-class deviation from the same table
    pp_df["Std"] = std
    whiskers = alt.Chart(pp_df).mark_errorbar(extent="stdev").encode(
        x="Proliferation Potential:O",
        y="Mean:Q",
        yError="Std:Q"
    )

    layered = bars + whiskers
    return layered.properties(title="Proliferation Potential Distribution", height=h)
1150
+
1151
+ # ---------------------------------------------------------------------
1152
def _show_statistics(self):
    """
    Renders the "All Simulations" section.

    Shows a table of the base statistics of every stored run followed by
    a "Mean" and a "Std" row, then a line chart (mean +/- standard
    deviation of the cell counts) and a bar chart (mean and standard
    deviation per proliferation potential class). Does nothing when the
    model holds no statistics.
    """

    if not self.model.stats:
        return

    self.print_title("All Simulations")
    plots_height = self.get_plot_height(3, 0.4)

    df = pd.DataFrame(self.model.stats)
    base_cols, stc_cols, rtc_cols, wbc_cols, pp_cols = self.model.separate_columns(df)
    df.index = df.index + 1  # number the runs from 1 in the table

    # --- Statistics table with summary rows appended ---
    base = df[base_cols]
    mean_row = base.mean(numeric_only=True)
    std_row = base.std(numeric_only=True)
    mean_row.name = "Mean"
    std_row.name = "Std"
    st.dataframe(pd.concat([base, mean_row.to_frame().T, std_row.to_frame().T]))

    # --- Averages across runs ---
    def _band(columns):
        """Return (mean, mean - std, mean + std) lists for the given columns."""
        centre = df[columns].mean()
        spread = df[columns].std()
        return (
            list(centre.values),
            list((centre - spread).values),
            list((centre + spread).values),
        )

    stc_mean, stc_low, stc_high = _band(stc_cols)
    rtc_mean, rtc_low, rtc_high = _band(rtc_cols)
    wbc_mean, wbc_low, wbc_high = _band(wbc_cols)

    line_chart = self._create_line_chart(
        plots_height,
        stc_mean, rtc_mean, wbc_mean,
        stc_low, stc_high,
        rtc_low, rtc_high,
        wbc_low, wbc_high
    )

    pp_means = df[pp_cols].mean()
    pp_stds = df[pp_cols].std()
    bar_chart = self._create_bar_chart(
        plots_height, list(pp_means.values), list(pp_stds.values)
    )

    # --- Layout: two equally wide columns ---
    left, right = st.columns(2)
    with left:
        st.altair_chart(line_chart, use_container_width=True)
    with right:
        st.altair_chart(bar_chart, use_container_width=True)
1197
+
1198
+ # ---------------------------------------------------------------------
1199
def _reset_save_stats(self):
    """
    The function for the reset/download statistics logic.

    Rendered only when statistics exist in the session state. Offers
    four controls side by side: reset all stored executions, download
    the statistics as an xlsx file, select a stored simulation, and
    visualize the selected simulation.
    """

    if "model_stats" not in st.session_state:
        return

    self.print_title("Simulation Options")
    col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
    selected_run = None
    visualize = False

    with col1:
        if st.button("Reset Model Executions and Data", use_container_width=True):
            # Guarded deletion: a missing key must not abort the reset
            for key in ("model_stats", "model_runs"):
                if key in st.session_state:
                    del st.session_state[key]
            self.model.stats.clear()
            self.model.runs.clear()

            st.success("Executions have been reset.")
    with col2:
        # Serialize the statistics into an in-memory xlsx file
        buffer = io.BytesIO()
        pd.DataFrame(self.model.stats).to_excel(buffer, index=False)
        buffer.seek(0)

        st.download_button(
            label="Download Statistics (xlsx)",
            data=buffer,
            file_name="simulation_statistics.xlsx",
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            use_container_width=True
        )
    with col3:
        # Streamlit discourages empty widget labels; the label stays
        # hidden through label_visibility="collapsed".
        selected_run = st.selectbox(
            "Select simulation", list(range(1, len(self.model.runs) + 1)),
            placeholder="Select simulation",
            label_visibility="collapsed",
            index=None
        )
    with col4:
        if st.button("Visualize Selected Simulation", use_container_width=True):
            if selected_run:
                visualize = True
            else:
                st.warning('Please select a simulation!')

    # Drawn outside the columns so the charts use the full page width
    if visualize:
        self._visualize_run("Selected Simulation", selected_run)
1244
+
1245
+ # ---------------------------------------------------------------------
1246
def _simdata_generator(self):
    """
    The Machine Learning tab for dataset generation and download.

    For every default parameter of a TML instance built on the current
    model, a (min, max) pair of numeric inputs is shown, pre-filled with
    0.8x and 1.5x of the default value. Pressing the button generates
    the dataset through TML.generate_dataset and offers it as a CSV
    download.
    """

    self.print_title("Simulation Data Generator")

    # TML supplies the default parameters and runs the simulations
    tml = TML(self.model)

    st.write("Select randomization ranges for each parameter:")

    # One row of two inputs (min / max) per parameter
    param_ranges = {}
    for param, default_val in tml.default_params.items():
        min_col, max_col = st.columns(2)
        with min_col:
            lower = st.number_input(
                f"{param} (min)",
                value=float(default_val) * 0.8,
                key=f"{param}_low"
            )
        with max_col:
            upper = st.number_input(
                f"{param} (max)",
                value=float(default_val) * 1.5,
                key=f"{param}_high"
            )
        param_ranges[param] = (lower, upper)

    n = st.number_input("Number of simulations", 5, 500, 50, step=5)

    if not st.button("Generate Dataset", use_container_width=True):
        return

    with st.spinner("Running simulations..."):
        df = tml.generate_dataset(n=n, random_params=param_ranges)
        st.success(f"Dataset generated successfully ({len(df)} rows).")

    # Offer the generated dataset as a CSV download
    payload = df.to_csv(index=False).encode("utf-8")
    st.download_button(
        label="Download Dataset (.csv)",
        data=payload,
        file_name="tumor_dataset.csv",
        mime="text/csv",
        use_container_width=True
    )
1294
+
1295
+ # ---------------------------------------------------------------------
1296
def _train_and_predict(self):
    """
    Streamlit UI for model training and prediction using the TML class.

    The upper part lets the user upload a CSV dataset, pick the target
    column and the training settings, and train a predictor through
    TML.train_predictor. The trained TML instance and the chosen target
    are kept in the session state, so the lower part (prediction for a
    new parameter set) remains available on later Streamlit reruns.
    """

    self.print_title("Model Trainer and Predictor")

    # A fresh TML wrapper is created on every rerun; it only survives
    # reruns if it is stored in the session state after training.
    tml = TML(self.model)

    uploaded_file = st.file_uploader("Upload CSV dataset", type=["csv"])
    if uploaded_file is not None:
        df = pd.read_csv(uploaded_file)
        st.write("Uploaded dataset preview:")
        st.dataframe(df.head())

        # Choose target column (the last column is preselected)
        target = st.selectbox("Select target attribute", df.columns, index=len(df.columns) - 1)

        # Three sliders for model parameters
        test_size = st.slider("Test Size", 0.1, 0.5, 0.2, step=0.05)
        random_state = st.slider("Random Seed", 0, 100, 42, step=1)
        n_estimators = st.slider("Number of Trees (n_estimators)", 50, 500, 200, step=50)

        if st.button("Train Model", use_container_width=True):
            with st.spinner("Training model..."):
                # The already parsed DataFrame is passed instead of a path
                model, metrics = tml.train_predictor(
                    file=df,
                    target=target,
                    test_size=test_size,
                    random_state=random_state,
                    n_estimators=n_estimators
                )

            st.success("Model trained successfully!")
            st.write(f"**R^2:** {metrics['R2']:.3f}")
            st.write(f"**MAE:** {metrics['MAE']:.3f}")

            # Store trained model for later prediction
            st.session_state["trained_tml"] = tml
            st.session_state["target"] = target
    else:
        st.info("Please upload a dataset to train a model.")

    if "trained_tml" in st.session_state:
        trained_tml = st.session_state["trained_tml"]
        target = st.session_state.get("target", "Target")
        # Feature names recorded by TML.train_predictor during training
        feature_cols = trained_tml.feature_columns

        # Numeric inputs for each feature, collected in feature order
        new_params = []
        self.print_title("Predict for new instance")
        for col in feature_cols:
            val = st.number_input(f"{col}", value=1.0, key=f"pred_{col}")
            new_params.append(val)

        # The target is only displayed here; it was fixed at training time
        st.markdown("#### Select Target to Predict:")
        st.text(f"Predicting: {target}")

        if st.button("🔮 Predict New", use_container_width=True):
            # UI boundary: any failure is reported to the user instead of
            # crashing the page
            try:
                prediction = trained_tml.predict_new(new_params)
                st.success(f"Predicted {target}: **{prediction:.3f}**")
            except Exception as e:
                st.error(f"Prediction failed: {e}")
    else:
        st.info("Train a model first to enable prediction.")
1363
+
1364
+
1365
class TML:
    """
    Class for handling Machine Learning tasks related to the tumor model.
    Allows dataset generation, parameter exploration, and result export.
    Allows predicting the size/confluence for a new set of parameters.

    Parameters:
        model (TModel): a created instance of the TModel class.
    """

    def __init__(self, model):
        self.model = model

        # Filled in by train_predictor(); predict_new() checks them for
        # None, so they must exist even before any training happened.
        self.trained_model = None
        self.feature_columns = None

        # Constructor arguments of TModel, taken from the given instance
        self.default_params = {
            "cycles": self.model.cycles,
            "side": self.model.side,
            "pmax": self.model.pmax,
            "PA": self.model.PA,
            "CCT": self.model.CCT,
            "Dt": self.model.Dt,
            "PS": self.model.PS,
            "mu": self.model.mu,
            "I": self.model.I,
            "M": self.model.M,
        }

    # ---------------------------------------------------------------------
    def generate_dataset(
        self, n=50, random_params=None,
        output_file="tumor_dataset.csv"
    ):
        """
        Generate a dataset of tumor simulations by randomizing given parameters.

        Parameters:
            n (int): Number of simulations to run.
            random_params (dict, optional): Parameters to randomize, mapped
                to their (low, high) range, e.g.
                {
                    "PA": (1, 20), "PS": (10, 40)
                }
                Parameters whose default value is an int are drawn as
                integers (bounds truncated with int()), all others as
                floats. None or an empty dict keeps every default.
            output_file (str, optional): CSV filename to save dataset.
                Pass None or "" to skip writing a file.

        Returns:
            pd.DataFrame: one row per simulation holding the parameters
            used plus "Tumor size" (number of non-zero cells of the final
            field) and "Confluence" (that number as a percentage of the
            field size). Empty when no simulation was run.
        """

        random_params = random_params or {}
        stats = []

        for _ in tqdm(range(n), desc="Generating simulations"):
            # Randomize chosen parameters
            params = self.default_params.copy()
            for key, (low, high) in random_params.items():
                if isinstance(params[key], int):
                    # randint() rejects reversed bounds, uniform() does not;
                    # sort them so both branches tolerate a swapped range.
                    lo, hi = sorted((int(low), int(high)))
                    params[key] = random.randint(lo, hi)
                else:
                    params[key] = random.uniform(float(low), float(high))

            # Run simulation
            model = TModel(**params)
            model.run_model(plot=False, animate=False, stats=False)

            occupied = np.count_nonzero(model.field)
            run_stats = dict(params)
            run_stats["Tumor size"] = occupied
            run_stats["Confluence"] = occupied / model.field.size * 100
            stats.append(run_stats)

        # Always hand back a DataFrame so callers can use len(df) safely
        df = pd.DataFrame(stats)
        if stats and output_file:
            df.to_csv(output_file, index=False)
            print(f"Dataset saved to {output_file} ({len(df)} runs)")
        return df

    # ---------------------------------------------------------------------
    def train_predictor(
        self, file, target, test_size=0.2,
        random_state=42, n_estimators=200
    ):
        """
        Trains a random forest regressor to predict a target attribute
        based on simulation parameters.

        The features are the first ten columns of the dataset (the layout
        written by generate_dataset), excluding the target column itself
        if it happens to be among them.

        Parameters:
            file (str or pd.DataFrame): CSV file containing the dataset,
                or the already loaded dataset
            target (str): Column name of the target attribute
            test_size (float): Fraction of dataset to use for testing
            random_state (int): Random seed for reproducibility
            n_estimators (int): Number of trees in the random forest

        Returns:
            model (RandomForestRegressor): Trained model
            metrics (dict): R^2 ("R2") and MAE ("MAE") metrics on test set

        Raises:
            ValueError: if no feature column remains besides the target
        """

        df = file if isinstance(file, pd.DataFrame) else pd.read_csv(file)

        # Never train on the target itself: that would leak the answer
        feature_names = [col for col in df.columns[0:10] if col != target]
        if not feature_names:
            raise ValueError(
                "The dataset has no feature columns besides the target."
            )
        x = df[feature_names]
        y = df[target]

        # Split into train/test
        x_train, x_test, y_train, y_test = train_test_split(
            x, y, test_size=test_size, random_state=random_state
        )

        # Train model
        model = RandomForestRegressor(
            n_estimators=n_estimators,
            random_state=random_state,
            n_jobs=-1
        )
        model.fit(x_train, y_train)

        # Predict & evaluate
        y_pred = model.predict(x_test)
        metrics = {
            "R2": r2_score(y_test, y_pred),
            "MAE": mean_absolute_error(y_test, y_pred)
        }

        # Publish model and feature list together, only after success
        self.feature_columns = feature_names
        self.trained_model = model

        print(f"Model trained on {len(x_train)} samples, tested on {len(x_test)}")
        print(f"R^2: {metrics['R2']:.3f}, MAE: {metrics['MAE']:.3f}")

        # Feature importance summary
        importance = pd.Series(model.feature_importances_, index=x.columns).sort_values(ascending=False)
        print("\n Top influencing parameters:")
        print(importance.head())

        return model, metrics

    # ---------------------------------------------------------------------
    def predict_new(self, params):
        """
        Predicts an attribute value for a set of
        parameters using a previously trained model.

        Parameters:
            params (list): List of parameters in the order of
                self.feature_columns, e.g.
                [500, 50, 10, 1, 24, 1/24, 15, 4, 4, 10]

        Returns:
            float: Predicted value of the target the model was trained on

        Raises:
            RuntimeError: if no model has been trained yet
            ValueError: if the number of parameters does not match the
                number of feature columns
        """

        # getattr keeps this working for instances that were created
        # without running __init__ (e.g. restored from an older session)
        trained_model = getattr(self, "trained_model", None)
        feature_columns = getattr(self, "feature_columns", None)

        if trained_model is None:
            raise RuntimeError(
                "No trained model found. Train one with train_predictor() first."
            )
        if feature_columns is None:
            raise RuntimeError(
                "Feature column list not found. Did you train the model?"
            )

        if len(params) != len(feature_columns):
            raise ValueError(
                f"Parameter list must have {len(feature_columns)} values "
                f"(got {len(params)}). Expected order: {feature_columns}"
            )

        # Single-row frame with the training column names and order
        df = pd.DataFrame([params], columns=feature_columns)
        return trained_model.predict(df)[0]