spacr 0.3.72__py3-none-any.whl → 0.3.81__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/gui_core.py +169 -37
- spacr/gui_elements.py +2 -3
- spacr/ml.py +151 -56
- spacr/settings.py +42 -7
- spacr/utils.py +24 -14
- {spacr-0.3.72.dist-info → spacr-0.3.81.dist-info}/METADATA +1 -1
- {spacr-0.3.72.dist-info → spacr-0.3.81.dist-info}/RECORD +11 -11
- {spacr-0.3.72.dist-info → spacr-0.3.81.dist-info}/LICENSE +0 -0
- {spacr-0.3.72.dist-info → spacr-0.3.81.dist-info}/WHEEL +0 -0
- {spacr-0.3.72.dist-info → spacr-0.3.81.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.72.dist-info → spacr-0.3.81.dist-info}/top_level.txt +0 -0
spacr/gui_core.py
CHANGED
@@ -4,6 +4,7 @@ from tkinter import ttk
|
|
4
4
|
from tkinter import filedialog
|
5
5
|
from multiprocessing import Process, Value, Queue, set_start_method
|
6
6
|
from tkinter import ttk
|
7
|
+
import matplotlib
|
7
8
|
from matplotlib.figure import Figure
|
8
9
|
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
|
9
10
|
import numpy as np
|
@@ -169,7 +170,7 @@ def display_figure(fig):
|
|
169
170
|
#flash_feedback("right")
|
170
171
|
show_next_figure()
|
171
172
|
|
172
|
-
def
|
173
|
+
def zoom_v1(event):
|
173
174
|
nonlocal scale_factor
|
174
175
|
|
175
176
|
zoom_speed = 0.1 # Adjust the zoom speed for smoother experience
|
@@ -197,6 +198,70 @@ def display_figure(fig):
|
|
197
198
|
|
198
199
|
# Redraw the figure efficiently
|
199
200
|
canvas.draw_idle()
|
201
|
+
|
202
|
+
def zoom_test(event):
    """Print which zoom direction a scroll-button event corresponds to.

    Debug helper: only the button number on ``event.num`` is inspected
    (4 -> "zoom in", 5 -> "zoom out"); every other event prints nothing.
    """
    direction_by_button = {4: "zoom in", 5: "zoom out"}
    label = direction_by_button.get(event.num)
    if label is not None:
        print(label)
|
207
|
+
|
208
|
+
def zoom_2(event):
    """Rescale every axes of the canvas figure about its midpoint for one scroll step.

    Scrolling up (``event.num == 4`` or a positive ``event.delta``) shrinks the
    visible x/y ranges by 10 %; scrolling down (``event.num == 5`` or a negative
    ``event.delta``) grows them by 10 %. Any other event is ignored and nothing
    is redrawn.
    """
    step = 0.1  # Change this to control how fast you zoom

    # Work out the scroll direction; unrecognised events leave the view untouched
    wheel_up = event.num == 4 or (hasattr(event, 'delta') and event.delta > 0)
    if wheel_up:
        scale = 1 - step
    elif event.num == 5 or (hasattr(event, 'delta') and event.delta < 0):
        scale = 1 + step
    else:
        return

    for axes in canvas.figure.get_axes():
        x_lo, x_hi = axes.get_xlim()
        y_lo, y_hi = axes.get_ylim()

        x_mid = (x_hi + x_lo) / 2
        y_mid = (y_hi + y_lo) / 2

        x_span = (x_hi - x_lo) * scale
        y_span = (y_hi - y_lo) * scale

        # Keep the midpoint fixed and apply the rescaled span around it
        axes.set_xlim([x_mid - x_span / 2, x_mid + x_span / 2])
        axes.set_ylim([y_mid - y_span / 2, y_mid + y_span / 2])

    # Ask the canvas for a deferred redraw
    canvas.draw_idle()
|
235
|
+
|
236
|
+
def zoom(event):
    """Zoom all axes of the canvas figure about their centres on a scroll event.

    A scroll-up event (``event.num == 4`` or ``event.delta > 0``) multiplies the
    visible x/y ranges by 0.9; a scroll-down event (``event.num == 5`` or
    ``event.delta < 0``) multiplies them by 1.1. Events matching neither
    direction return immediately without touching the canvas.
    """
    # Fixed zoom factors (adjust these if you want faster or slower zoom)
    shrink = 0.9  # zoom in: ranges shrink by 10%
    grow = 1.1  # zoom out: ranges increase by 10%

    has_delta = hasattr(event, 'delta')
    if event.num == 4 or (has_delta and event.delta > 0):
        factor = shrink
    elif event.num == 5 or (has_delta and event.delta < 0):
        factor = grow
    else:
        return  # No recognized scroll direction

    for axes in canvas.figure.get_axes():
        left, right = axes.get_xlim()
        bottom, top = axes.get_ylim()

        centre_x = (right + left) / 2
        centre_y = (top + bottom) / 2

        width = (right - left) * factor
        height = (top - bottom) * factor

        # New limits are centred on the old midpoint
        axes.set_xlim([centre_x - width / 2, centre_x + width / 2])
        axes.set_ylim([centre_y - height / 2, centre_y + height / 2])

    # Schedule a redraw once all axes have been updated
    canvas.draw_idle()
|
200
265
|
|
201
266
|
|
202
267
|
# Bind events for hover, click interactions, and zoom
|
@@ -205,19 +270,20 @@ def display_figure(fig):
|
|
205
270
|
canvas_widget.bind("<Button-1>", on_click)
|
206
271
|
canvas_widget.bind("<Button-3>", on_right_click)
|
207
272
|
|
208
|
-
|
209
273
|
# Detect the operating system and bind the appropriate mouse wheel events
|
210
274
|
current_os = platform.system()
|
211
275
|
|
212
276
|
if current_os == "Windows":
|
213
277
|
canvas_widget.bind("<MouseWheel>", zoom) # Windows
|
214
|
-
elif current_os == "Darwin":
|
278
|
+
elif current_os == "Darwin":
|
215
279
|
canvas_widget.bind("<MouseWheel>", zoom)
|
216
280
|
canvas_widget.bind("<Button-4>", zoom) # Scroll up
|
217
281
|
canvas_widget.bind("<Button-5>", zoom) # Scroll down
|
218
282
|
elif current_os == "Linux":
|
219
283
|
canvas_widget.bind("<Button-4>", zoom) # Linux Scroll up
|
220
284
|
canvas_widget.bind("<Button-5>", zoom) # Linux Scroll down
|
285
|
+
|
286
|
+
process_fig_queue()
|
221
287
|
|
222
288
|
def clear_unused_figures():
|
223
289
|
global figures, figure_index
|
@@ -230,71 +296,127 @@ def clear_unused_figures():
|
|
230
296
|
figure_index = min(max(figure_index, 0), len(figures) - 1)
|
231
297
|
|
232
298
|
def show_previous_figure():
    """Step back one figure, if there is an earlier one, and display it.

    Decrements the global ``figure_index``, mirrors it on ``index_control``,
    re-standardizes the stored figure and shows it. Does nothing when the
    index is ``None`` or already at the first figure.
    """
    from .gui_elements import standardize_figure
    global figure_index, figures, fig_queue, index_control

    # Nothing earlier to show
    if figure_index is None or figure_index <= 0:
        return

    figure_index -= 1
    index_control.set(figure_index)
    restyled = standardize_figure(figures[figure_index])
    figures[figure_index] = restyled
    display_figure(figures[figure_index])
    #clear_unused_figures()
|
239
308
|
|
240
309
|
def show_next_figure():
    """Advance to the next figure, pulling one from ``fig_queue`` if needed.

    If a later figure is already stored, the global ``figure_index`` is
    incremented, the slider value and maximum are refreshed, and the
    re-standardized figure is displayed. Otherwise, when the last stored
    figure is current and the queue is not empty, one figure is taken from
    the queue, appended, and displayed as received.
    """
    from .gui_elements import standardize_figure
    global figure_index, figures, fig_queue, index_control

    # Case 1: a later figure is already held in memory
    if figure_index is not None and figure_index < len(figures) - 1:
        figure_index += 1
        index_control.set(figure_index)
        index_control.set_to(len(figures) - 1)
        restyled = standardize_figure(figures[figure_index])
        figures[figure_index] = restyled
        display_figure(figures[figure_index])
        #clear_unused_figures()
        return

    # Case 2: at the end of the stored figures; only proceed if the queue has one waiting
    if figure_index != len(figures) - 1 or fig_queue.empty():
        return

    queued = fig_queue.get_nowait()
    figures.append(queued)
    figure_index += 1
    index_control.set(figure_index)
    index_control.set_to(len(figures) - 1)
    display_figure(queued)
|
251
|
-
|
252
|
-
def process_fig_queue():
|
253
|
-
global canvas, fig_queue, canvas_widget, parent_frame, uppdate_frequency, figures, figure_index, index_control
|
254
327
|
|
328
|
+
def process_fig_queue_v1():
    """Earlier version of the figure-queue poller, kept alongside ``process_fig_queue``.

    Drains ``fig_queue``, standardizes and stores each figure, raises the
    slider maximum, and displays the first figure if none is shown yet
    (``figure_index == -1``). Errors are printed with a traceback rather than
    propagated. Note that the follow-up run it schedules on ``canvas_widget``
    is ``process_fig_queue``, not this function.
    """
    global canvas, fig_queue, canvas_widget, parent_frame, uppdate_frequency, figures, figure_index, index_control
    from .gui_elements import standardize_figure

    #print("process_fig_queue called", flush=True)
    try:
        while True:
            if fig_queue.empty():
                break

            incoming = fig_queue.get_nowait()
            if incoming is None:
                print("Warning: Retrieved a None figure from fig_queue.", flush=True)
                continue

            # Standardize the figure appearance before storing it
            figures.append(standardize_figure(incoming))

            # Slider maximum follows the number of stored figures
            index_control.set_to(len(figures) - 1)

            # Only the very first figure is displayed automatically
            if figure_index != -1:
                continue
            figure_index = 0
            display_figure(figures[figure_index])
            index_control.set(figure_index)

    except Exception as err:
        print("Exception in process_fig_queue:", err, flush=True)
        traceback.print_exc()

    finally:
        # Always re-arm the poll and record its id on the parent frame
        task_id = canvas_widget.after(uppdate_frequency, process_fig_queue)
        parent_frame.after_tasks.append(task_id)
|
361
|
+
|
362
|
+
def process_fig_queue():
    """Poll ``fig_queue`` and move any waiting figures into the ``figures`` deque.

    For every queued figure: ``None`` entries are reported and skipped; real
    figures are standardized and appended. The deque is capped at
    ``MAX_FIGURES`` entries by discarding (and closing) the oldest figures,
    and ``figure_index`` is shifted down with each discard so that it keeps
    referring to the same figure. The slider maximum is refreshed after each
    append, and the first figure is displayed automatically when nothing has
    been shown yet (``figure_index == -1``).

    Any exception is printed with a traceback instead of being raised, and
    the function always re-schedules itself on ``canvas_widget`` after
    ``uppdate_frequency``, recording the ``after`` id in
    ``parent_frame.after_tasks``.
    """
    global canvas, fig_queue, canvas_widget, parent_frame, uppdate_frequency, figures, figure_index, index_control
    from .gui_elements import standardize_figure

    # Upper bound on the number of figures kept in memory
    MAX_FIGURES = 100

    try:
        # Import pyplot explicitly: a bare ``import matplotlib`` does not
        # guarantee that the ``matplotlib.pyplot`` attribute exists.
        import matplotlib.pyplot as plt

        while not fig_queue.empty():
            fig = fig_queue.get_nowait()
            if fig is None:
                print("Warning: Retrieved a None figure from fig_queue.")
                continue

            # Standardize the figure appearance before adding it
            fig = standardize_figure(fig)
            figures.append(fig)

            # Cap the size of the figures deque
            while len(figures) > MAX_FIGURES:
                # Discard the oldest figure and close it to free memory
                old_fig = figures.popleft()
                plt.close(old_fig)

                # Every remaining figure moved one slot to the left, so the
                # current index must follow or navigation would skip a figure.
                if figure_index is not None and figure_index > 0:
                    figure_index -= 1
                    index_control.set(figure_index)

            # Update slider maximum
            index_control.set_to(len(figures) - 1)

            # If no figure has been displayed yet
            if figure_index == -1:
                figure_index = 0
                display_figure(figures[figure_index])
                index_control.set(figure_index)

    except Exception as e:
        print("Exception in process_fig_queue:", e)
        traceback.print_exc()

    finally:
        # Schedule process_fig_queue() to run again
        after_id = canvas_widget.after(uppdate_frequency, process_fig_queue)
        parent_frame.after_tasks.append(after_id)
|
282
402
|
|
283
|
-
def update_figure(value):
|
284
|
-
global figure_index, figures
|
285
403
|
|
404
|
+
def update_figure(value):
    """Display the figure at the position given by ``value``.

    ``value`` is converted with ``int``; if the result is a valid position in
    ``figures`` it becomes the global ``figure_index``, the figure is
    re-standardized and displayed, and the slider value and maximum are
    refreshed. Out-of-range values are ignored.
    """
    from .gui_elements import standardize_figure
    global figure_index, figures, index_control

    requested = int(value)

    # Ignore positions that do not correspond to a stored figure
    if not 0 <= requested < len(figures):
        return

    figure_index = requested
    restyled = standardize_figure(figures[figure_index])
    figures[figure_index] = restyled
    display_figure(figures[figure_index])
    index_control.set(figure_index)
    print("update_figure called with value:", figure_index)
    index_control.set_to(len(figures) - 1)
|
419
|
+
|
298
420
|
def setup_plot_section(vertical_container, settings_type):
|
299
421
|
global canvas, canvas_widget, figures, figure_index, index_control
|
300
422
|
from .gui_utils import display_media_in_plot_frame
|
@@ -305,29 +427,29 @@ def setup_plot_section(vertical_container, settings_type):
|
|
305
427
|
|
306
428
|
# Initialize deque for storing figures and the current index
|
307
429
|
figures = deque()
|
430
|
+
figure_index = -1 # Start with no figure displayed
|
308
431
|
|
309
432
|
# Create a frame for the plot section
|
310
433
|
plot_frame = tk.Frame(vertical_container)
|
311
434
|
plot_frame.configure(bg=bg)
|
312
435
|
vertical_container.add(plot_frame, stretch="always")
|
313
436
|
|
314
|
-
# Clear the plot_frame (optional
|
437
|
+
# Clear the plot_frame (optional)
|
315
438
|
for widget in plot_frame.winfo_children():
|
316
439
|
widget.destroy()
|
317
440
|
|
318
|
-
# Create a figure and plot
|
441
|
+
# Create a figure and plot (initial figure)
|
319
442
|
figure = Figure(figsize=(30, 4), dpi=100)
|
320
443
|
plot = figure.add_subplot(111)
|
321
444
|
plot.plot([], [])
|
322
445
|
plot.axis('off')
|
323
446
|
|
324
447
|
if settings_type == 'map_barcodes':
|
325
|
-
# Load and display GIF
|
326
448
|
current_dir = os.path.dirname(__file__)
|
327
449
|
resources_path = os.path.join(current_dir, 'resources', 'icons')
|
328
450
|
gif_path = os.path.join(resources_path, 'dna_matrix.mp4')
|
329
|
-
|
330
451
|
display_media_in_plot_frame(gif_path, plot_frame)
|
452
|
+
|
331
453
|
canvas = FigureCanvasTkAgg(figure, master=plot_frame)
|
332
454
|
canvas.get_tk_widget().configure(cursor='arrow', highlightthickness=0)
|
333
455
|
canvas_widget = canvas.get_tk_widget()
|
@@ -348,10 +470,11 @@ def setup_plot_section(vertical_container, settings_type):
|
|
348
470
|
# Create slider
|
349
471
|
control_frame = tk.Frame(plot_frame, height=15*2, bg=bg)
|
350
472
|
control_frame.grid(row=1, column=0, sticky="ew", padx=10, pady=5)
|
351
|
-
control_frame.grid_propagate(False)
|
473
|
+
control_frame.grid_propagate(False)
|
352
474
|
|
353
|
-
|
354
|
-
|
475
|
+
index_control = spacrSlider(control_frame, from_=0, to=0, value=0, thickness=2, knob_radius=10,
|
476
|
+
position="center", show_index=True, command=update_figure)
|
477
|
+
|
355
478
|
index_control.grid(row=0, column=0, sticky="ew")
|
356
479
|
control_frame.grid_columnconfigure(0, weight=1)
|
357
480
|
|
@@ -359,10 +482,17 @@ def setup_plot_section(vertical_container, settings_type):
|
|
359
482
|
style = ttk.Style(vertical_container)
|
360
483
|
_ = set_dark_style(style, containers=containers, widgets=widgets)
|
361
484
|
|
485
|
+
# Now ensure the first figure is displayed and recognized:
|
486
|
+
figures.append(figure)
|
487
|
+
figure_index = 0
|
488
|
+
display_figure(figures[figure_index])
|
489
|
+
index_control.set_to(len(figures) - 1) # Slider max = 0 in this case, since there's only one figure
|
490
|
+
index_control.set(figure_index) # Set slider to 0 to indicate the first figure
|
491
|
+
|
362
492
|
return canvas, canvas_widget
|
363
493
|
|
364
|
-
def set_globals(thread_control_var, q_var, console_output_var, parent_frame_var, vars_dict_var, canvas_var, canvas_widget_var, scrollable_frame_var, fig_queue_var,
|
365
|
-
global thread_control, q, console_output, parent_frame, vars_dict, canvas, canvas_widget, scrollable_frame, fig_queue,
|
494
|
+
def set_globals(thread_control_var, q_var, console_output_var, parent_frame_var, vars_dict_var, canvas_var, canvas_widget_var, scrollable_frame_var, fig_queue_var, progress_bar_var, usage_bars_var):
|
495
|
+
global thread_control, q, console_output, parent_frame, vars_dict, canvas, canvas_widget, scrollable_frame, fig_queue, progress_bar, usage_bars
|
366
496
|
thread_control = thread_control_var
|
367
497
|
q = q_var
|
368
498
|
console_output = console_output_var
|
@@ -372,11 +502,11 @@ def set_globals(thread_control_var, q_var, console_output_var, parent_frame_var,
|
|
372
502
|
canvas_widget = canvas_widget_var
|
373
503
|
scrollable_frame = scrollable_frame_var
|
374
504
|
fig_queue = fig_queue_var
|
375
|
-
figures = figures_var
|
376
|
-
figure_index = figure_index_var
|
505
|
+
#figures = figures_var
|
506
|
+
#figure_index = figure_index_var
|
507
|
+
#index_control = index_control_var
|
377
508
|
progress_bar = progress_bar_var
|
378
509
|
usage_bars = usage_bars_var
|
379
|
-
index_control = index_control_var
|
380
510
|
|
381
511
|
def import_settings(settings_type='mask'):
|
382
512
|
global vars_dict, scrollable_frame, button_scrollable_frame
|
@@ -606,6 +736,7 @@ def setup_button_section(horizontal_container, settings_type='mask', run=True, a
|
|
606
736
|
widgets.append(import_button)
|
607
737
|
btn_row += 1
|
608
738
|
|
739
|
+
btn_row += 1
|
609
740
|
# Add the batch progress bar
|
610
741
|
progress_bar = spacrProgressBar(button_scrollable_frame.scrollable_frame, orient='horizontal', mode='determinate')
|
611
742
|
progress_bar.grid(row=btn_row, column=0, columnspan=7, pady=5, padx=5, sticky='ew')
|
@@ -853,7 +984,8 @@ def process_console_queue():
|
|
853
984
|
if progress_bar:
|
854
985
|
progress_bar['maximum'] = total_progress
|
855
986
|
progress_bar['value'] = unique_progress_count
|
856
|
-
|
987
|
+
#print("Current progress bar value:", progress_bar['value']) # Debugg
|
988
|
+
|
857
989
|
# Store operation type and additional info
|
858
990
|
if operation_type:
|
859
991
|
progress_bar.operation_type = operation_type
|
@@ -955,7 +1087,7 @@ def initiate_root(parent, settings_type='mask'):
|
|
955
1087
|
else:
|
956
1088
|
usage_bars = []
|
957
1089
|
|
958
|
-
set_globals(thread_control, q, console_output, parent_frame, vars_dict, canvas, canvas_widget, scrollable_frame, fig_queue,
|
1090
|
+
set_globals(thread_control, q, console_output, parent_frame, vars_dict, canvas, canvas_widget, scrollable_frame, fig_queue, progress_bar, usage_bars)
|
959
1091
|
description_text = descriptions.get(settings_type, "No description available for this module.")
|
960
1092
|
|
961
1093
|
q.put(f"Console")
|
spacr/gui_elements.py
CHANGED
@@ -17,8 +17,7 @@ from skimage.draw import polygon, line
|
|
17
17
|
from skimage.transform import resize
|
18
18
|
from scipy.ndimage import binary_fill_holes, label
|
19
19
|
from tkinter import ttk, scrolledtext
|
20
|
-
from skimage.color import rgb2gray
|
21
|
-
|
20
|
+
from skimage.color import rgb2gray
|
22
21
|
fig = None
|
23
22
|
|
24
23
|
def create_menu_bar(root):
|
@@ -667,7 +666,7 @@ class spacrProgressBar(ttk.Progressbar):
|
|
667
666
|
# Remove any borders and ensure the active color fills the entire space
|
668
667
|
self.style.configure(
|
669
668
|
"spacr.Horizontal.TProgressbar",
|
670
|
-
troughcolor=self.inactive_color,
|
669
|
+
troughcolor=self.inactive_color, # Set the trough to bg color
|
671
670
|
background=self.active_color, # Active part is the active color
|
672
671
|
borderwidth=0, # Remove border width
|
673
672
|
pbarrelief="flat", # Flat relief for the progress bar
|
spacr/ml.py
CHANGED
@@ -27,6 +27,9 @@ from sklearn.linear_model import Lasso, Ridge
|
|
27
27
|
from sklearn.preprocessing import FunctionTransformer
|
28
28
|
from patsy import dmatrices
|
29
29
|
|
30
|
+
from sklearn.metrics import classification_report, accuracy_score
|
31
|
+
from sklearn.model_selection import StratifiedKFold, cross_val_score
|
32
|
+
from sklearn.feature_selection import SelectKBest, f_classif
|
30
33
|
from sklearn.model_selection import train_test_split
|
31
34
|
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
|
32
35
|
from sklearn.linear_model import LogisticRegression
|
@@ -1165,21 +1168,29 @@ def generate_ml_scores(settings):
|
|
1165
1168
|
|
1166
1169
|
settings = set_default_analyze_screen(settings)
|
1167
1170
|
|
1168
|
-
|
1171
|
+
srcs = settings['src']
|
1169
1172
|
|
1170
1173
|
settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
|
1171
1174
|
display(settings_df)
|
1172
|
-
|
1173
|
-
db_loc = [src+'/measurements/measurements.db']
|
1174
|
-
tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
|
1175
|
-
|
1176
|
-
nuclei_limit, pathogen_limit = settings['nuclei_limit'], settings['pathogen_limit']
|
1177
1175
|
|
1178
|
-
|
1179
|
-
|
1180
|
-
|
1181
|
-
|
1182
|
-
|
1176
|
+
if isinstance(srcs, str):
|
1177
|
+
srcs = [srcs]
|
1178
|
+
|
1179
|
+
df = pd.DataFrame()
|
1180
|
+
for idx, src in enumerate(srcs):
|
1181
|
+
|
1182
|
+
if idx == 0:
|
1183
|
+
src1 = src
|
1184
|
+
|
1185
|
+
db_loc = [src+'/measurements/measurements.db']
|
1186
|
+
tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
|
1187
|
+
|
1188
|
+
dft, _ = _read_and_merge_data(db_loc,
|
1189
|
+
tables,
|
1190
|
+
settings['verbose'],
|
1191
|
+
nuclei_limit=settings['nuclei_limit'],
|
1192
|
+
pathogen_limit=settings['pathogen_limit'])
|
1193
|
+
df = pd.concat([df, dft])
|
1183
1194
|
|
1184
1195
|
if settings['annotation_column'] is not None:
|
1185
1196
|
|
@@ -1191,6 +1202,7 @@ def generate_ml_scores(settings):
|
|
1191
1202
|
annotated_df = png_list_df[['prcfo', settings['annotation_column']]].set_index('prcfo')
|
1192
1203
|
df = annotated_df.merge(df, left_index=True, right_index=True)
|
1193
1204
|
unique_values = df[settings['annotation_column']].dropna().unique()
|
1205
|
+
|
1194
1206
|
if len(unique_values) == 1:
|
1195
1207
|
unannotated_rows = df[df[settings['annotation_column']].isna()].index
|
1196
1208
|
existing_value = unique_values[0]
|
@@ -1213,8 +1225,8 @@ def generate_ml_scores(settings):
|
|
1213
1225
|
df[settings['annotation_column']] = df[settings['annotation_column']].apply(str)
|
1214
1226
|
|
1215
1227
|
if settings['channel_of_interest'] in [0,1,2,3]:
|
1216
|
-
|
1217
|
-
|
1228
|
+
if f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity" and f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity" in df.columns:
|
1229
|
+
df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"]/df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
|
1218
1230
|
|
1219
1231
|
output, figs = ml_analysis(df,
|
1220
1232
|
settings['channel_of_interest'],
|
@@ -1224,18 +1236,24 @@ def generate_ml_scores(settings):
|
|
1224
1236
|
settings['exclude'],
|
1225
1237
|
settings['n_repeats'],
|
1226
1238
|
settings['top_features'],
|
1239
|
+
settings['reg_alpha'],
|
1240
|
+
settings['reg_lambda'],
|
1241
|
+
settings['learning_rate'],
|
1227
1242
|
settings['n_estimators'],
|
1228
1243
|
settings['test_size'],
|
1229
1244
|
settings['model_type_ml'],
|
1230
1245
|
settings['n_jobs'],
|
1231
1246
|
settings['remove_low_variance_features'],
|
1232
1247
|
settings['remove_highly_correlated_features'],
|
1248
|
+
settings['prune_features'],
|
1249
|
+
settings['cross_validation'],
|
1233
1250
|
settings['verbose'])
|
1234
1251
|
|
1235
1252
|
shap_fig = shap_analysis(output[3], output[4], output[5])
|
1236
1253
|
|
1237
1254
|
features = output[0].select_dtypes(include=[np.number]).columns.tolist()
|
1238
|
-
|
1255
|
+
train_features_df = pd.DataFrame(output[9], columns=['feature'])
|
1256
|
+
|
1239
1257
|
if not settings['heatmap_feature'] in features:
|
1240
1258
|
raise ValueError(f"Variable {settings['heatmap_feature']} not found in the dataframe. Please choose one of the following: {features}")
|
1241
1259
|
|
@@ -1247,15 +1265,16 @@ def generate_ml_scores(settings):
|
|
1247
1265
|
min_count=settings['minimum_cell_count'],
|
1248
1266
|
verbose=settings['verbose'])
|
1249
1267
|
|
1250
|
-
data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(
|
1251
|
-
df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df = output
|
1268
|
+
data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv, ml_features = get_ml_results_paths(src1, settings['model_type_ml'], settings['channel_of_interest'])
|
1269
|
+
df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df, _ = output
|
1252
1270
|
|
1253
1271
|
settings_df.to_csv(settings_csv, index=False)
|
1254
1272
|
df.to_csv(data_path, mode='w', encoding='utf-8')
|
1255
1273
|
permutation_df.to_csv(permutation_path, mode='w', encoding='utf-8')
|
1256
1274
|
feature_importance_df.to_csv(feature_importance_path, mode='w', encoding='utf-8')
|
1275
|
+
train_features_df.to_csv(ml_features, mode='w', encoding='utf-8')
|
1257
1276
|
metrics_df.to_csv(model_metricks_path, mode='w', encoding='utf-8')
|
1258
|
-
|
1277
|
+
|
1259
1278
|
plate_heatmap.savefig(plate_heatmap_path, format='pdf')
|
1260
1279
|
figs[0].savefig(permutation_fig_path, format='pdf')
|
1261
1280
|
figs[1].savefig(feature_importance_fig_path, format='pdf')
|
@@ -1263,7 +1282,7 @@ def generate_ml_scores(settings):
|
|
1263
1282
|
|
1264
1283
|
if settings['save_to_db']:
|
1265
1284
|
settings['csv_path'] = data_path
|
1266
|
-
settings['db_path'] = os.path.join(
|
1285
|
+
settings['db_path'] = os.path.join(src1, 'measurements', 'measurements.db')
|
1267
1286
|
settings['table_name'] = 'png_list'
|
1268
1287
|
settings['update_column'] = 'predictions'
|
1269
1288
|
settings['match_column'] = 'prcfo'
|
@@ -1271,7 +1290,7 @@ def generate_ml_scores(settings):
|
|
1271
1290
|
|
1272
1291
|
return [output, plate_heatmap]
|
1273
1292
|
|
1274
|
-
def ml_analysis(df, channel_of_interest=3, location_column='column_name', positive_control='c2', negative_control='c1', exclude=None, n_repeats=10, top_features=30, n_estimators=
|
1293
|
+
def ml_analysis(df, channel_of_interest=3, location_column='column_name', positive_control='c2', negative_control='c1', exclude=None, n_repeats=10, top_features=30, reg_alpha=0.1, reg_lambda=1.0, learning_rate=0.00001, n_estimators=1000, test_size=0.2, model_type='xgboost', n_jobs=-1, remove_low_variance_features=True, remove_highly_correlated_features=True, prune_features=False, cross_validation=False, verbose=False):
|
1275
1294
|
|
1276
1295
|
"""
|
1277
1296
|
Calculates permutation importance for numerical features in the dataframe,
|
@@ -1313,7 +1332,8 @@ def ml_analysis(df, channel_of_interest=3, location_column='column_name', positi
|
|
1313
1332
|
if verbose:
|
1314
1333
|
print(f'Found {len(features)} numerical features in the dataframe')
|
1315
1334
|
print(f'Features used in training: {features}')
|
1316
|
-
|
1335
|
+
print(f'Features: {features}')
|
1336
|
+
|
1317
1337
|
df = pd.concat([df, df_metadata[location_column]], axis=1)
|
1318
1338
|
|
1319
1339
|
# Subset the dataframe based on specified column values
|
@@ -1327,14 +1347,26 @@ def ml_analysis(df, channel_of_interest=3, location_column='column_name', positi
|
|
1327
1347
|
# Combine the subsets for analysis
|
1328
1348
|
combined_df = pd.concat([df1, df2])
|
1329
1349
|
combined_df = combined_df.drop(columns=[location_column])
|
1350
|
+
|
1330
1351
|
if verbose:
|
1331
1352
|
print(f'Found {len(df1)} samples for {negative_control} and {len(df2)} samples for {positive_control}. Total: {len(combined_df)}')
|
1332
|
-
|
1353
|
+
|
1333
1354
|
X = combined_df[features]
|
1334
1355
|
y = combined_df['target']
|
1335
|
-
|
1336
|
-
|
1337
|
-
|
1356
|
+
|
1357
|
+
if prune_features:
|
1358
|
+
before_pruning = len(X.columns)
|
1359
|
+
selector = SelectKBest(score_func=f_classif, k=top_features)
|
1360
|
+
X_selected = selector.fit_transform(X, y)
|
1361
|
+
|
1362
|
+
# Get the selected feature names
|
1363
|
+
selected_features = X.columns[selector.get_support()]
|
1364
|
+
X = pd.DataFrame(X_selected, columns=selected_features, index=X.index)
|
1365
|
+
|
1366
|
+
features = selected_features.tolist()
|
1367
|
+
|
1368
|
+
after_pruning = len(X.columns)
|
1369
|
+
print(f"Removed {before_pruning - after_pruning} features using SelectKBest")
|
1338
1370
|
|
1339
1371
|
# Split the data into training and testing sets
|
1340
1372
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
|
@@ -1353,12 +1385,102 @@ def ml_analysis(df, channel_of_interest=3, location_column='column_name', positi
|
|
1353
1385
|
elif model_type == 'gradient_boosting':
|
1354
1386
|
model = HistGradientBoostingClassifier(max_iter=n_estimators, random_state=random_state) # Supports n_jobs internally
|
1355
1387
|
elif model_type == 'xgboost':
|
1356
|
-
model = XGBClassifier(n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
|
1388
|
+
model = XGBClassifier(reg_alpha=reg_alpha, reg_lambda=reg_lambda, learning_rate=learning_rate, n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
|
1389
|
+
|
1357
1390
|
else:
|
1358
1391
|
raise ValueError(f"Unsupported model_type: {model_type}")
|
1359
1392
|
|
1360
|
-
|
1393
|
+
# Perform k-fold cross-validation
|
1394
|
+
if cross_validation:
|
1395
|
+
|
1396
|
+
# Cross-validation setup
|
1397
|
+
kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
|
1398
|
+
fold_metrics = []
|
1399
|
+
|
1400
|
+
for fold_idx, (train_index, test_index) in enumerate(kfold.split(X, y), start=1):
|
1401
|
+
X_train, X_test = X.iloc[train_index], X.iloc[test_index]
|
1402
|
+
y_train, y_test = y.iloc[train_index], y.iloc[test_index]
|
1403
|
+
|
1404
|
+
# Train the model
|
1405
|
+
model.fit(X_train, y_train)
|
1406
|
+
|
1407
|
+
# Predict for the current test set
|
1408
|
+
predictions_test = model.predict(X_test)
|
1409
|
+
combined_df.loc[X_test.index, 'predictions'] = predictions_test
|
1410
|
+
|
1411
|
+
# Get prediction probabilities for the test set
|
1412
|
+
prediction_probabilities_test = model.predict_proba(X_test)
|
1361
1413
|
|
1414
|
+
# Find the optimal threshold
|
1415
|
+
optimal_threshold = find_optimal_threshold(y_test, prediction_probabilities_test[:, 1])
|
1416
|
+
if verbose:
|
1417
|
+
print(f'Fold {fold_idx} - Optimal threshold: {optimal_threshold}')
|
1418
|
+
|
1419
|
+
# Assign predictions and probabilities to the test set in the DataFrame
|
1420
|
+
df.loc[X_test.index, 'predictions'] = predictions_test
|
1421
|
+
for i in range(prediction_probabilities_test.shape[1]):
|
1422
|
+
df.loc[X_test.index, f'prediction_probability_class_{i}'] = prediction_probabilities_test[:, i]
|
1423
|
+
|
1424
|
+
# Evaluate performance for the current fold
|
1425
|
+
fold_report = classification_report(y_test, predictions_test, output_dict=True)
|
1426
|
+
fold_metrics.append(pd.DataFrame(fold_report).transpose())
|
1427
|
+
|
1428
|
+
if verbose:
|
1429
|
+
print(f"Fold {fold_idx} Classification Report:")
|
1430
|
+
print(classification_report(y_test, predictions_test))
|
1431
|
+
|
1432
|
+
# Aggregate metrics across all folds
|
1433
|
+
metrics_df = pd.concat(fold_metrics).groupby(level=0).mean()
|
1434
|
+
|
1435
|
+
# Re-train on full data (X, y) and then apply to entire df
|
1436
|
+
model.fit(X, y)
|
1437
|
+
all_predictions = model.predict(df[features]) # Predict on entire df
|
1438
|
+
df['predictions'] = all_predictions
|
1439
|
+
|
1440
|
+
# Get prediction probabilities for all rows in df
|
1441
|
+
prediction_probabilities = model.predict_proba(df[features])
|
1442
|
+
for i in range(prediction_probabilities.shape[1]):
|
1443
|
+
df[f'prediction_probability_class_{i}'] = prediction_probabilities[:, i]
|
1444
|
+
|
1445
|
+
if verbose:
|
1446
|
+
print("\nFinal Classification Report on Full Dataset:")
|
1447
|
+
print(classification_report(y, all_predictions))
|
1448
|
+
|
1449
|
+
# Generate metrics DataFrame
|
1450
|
+
final_report_dict = classification_report(y, all_predictions, output_dict=True)
|
1451
|
+
metrics_df = pd.DataFrame(final_report_dict).transpose()
|
1452
|
+
|
1453
|
+
else:
|
1454
|
+
model.fit(X_train, y_train)
|
1455
|
+
# Predicting the target variable for the test set
|
1456
|
+
predictions_test = model.predict(X_test)
|
1457
|
+
combined_df.loc[X_test.index, 'predictions'] = predictions_test
|
1458
|
+
|
1459
|
+
# Get prediction probabilities for the test set
|
1460
|
+
prediction_probabilities_test = model.predict_proba(X_test)
|
1461
|
+
|
1462
|
+
# Find the optimal threshold
|
1463
|
+
optimal_threshold = find_optimal_threshold(y_test, prediction_probabilities_test[:, 1])
|
1464
|
+
if verbose:
|
1465
|
+
print(f'Optimal threshold: {optimal_threshold}')
|
1466
|
+
|
1467
|
+
# Predicting the target variable for all other rows in the dataframe
|
1468
|
+
X_all = df[features]
|
1469
|
+
all_predictions = model.predict(X_all)
|
1470
|
+
df['predictions'] = all_predictions
|
1471
|
+
|
1472
|
+
# Get prediction probabilities for all rows in the dataframe
|
1473
|
+
prediction_probabilities = model.predict_proba(X_all)
|
1474
|
+
for i in range(prediction_probabilities.shape[1]):
|
1475
|
+
df[f'prediction_probability_class_{i}'] = prediction_probabilities[:, i]
|
1476
|
+
|
1477
|
+
if verbose:
|
1478
|
+
print("\nClassification Report:")
|
1479
|
+
print(classification_report(y_test, predictions_test))
|
1480
|
+
|
1481
|
+
report_dict = classification_report(y_test, predictions_test, output_dict=True)
|
1482
|
+
metrics_df = pd.DataFrame(report_dict).transpose()
|
1483
|
+
|
1362
1484
|
perm_importance = permutation_importance(model, X_train, y_train, n_repeats=n_repeats, random_state=random_state, n_jobs=n_jobs)
|
1363
1485
|
|
1364
1486
|
# Create a DataFrame for permutation importances
|
@@ -1387,40 +1509,13 @@ def ml_analysis(df, channel_of_interest=3, location_column='column_name', positi
|
|
1387
1509
|
else:
|
1388
1510
|
feature_importance_df = pd.DataFrame()
|
1389
1511
|
|
1390
|
-
# Predicting the target variable for the test set
|
1391
|
-
predictions_test = model.predict(X_test)
|
1392
|
-
combined_df.loc[X_test.index, 'predictions'] = predictions_test
|
1393
|
-
|
1394
|
-
# Get prediction probabilities for the test set
|
1395
|
-
prediction_probabilities_test = model.predict_proba(X_test)
|
1396
|
-
|
1397
|
-
# Find the optimal threshold
|
1398
|
-
optimal_threshold = find_optimal_threshold(y_test, prediction_probabilities_test[:, 1])
|
1399
|
-
if verbose:
|
1400
|
-
print(f'Optimal threshold: {optimal_threshold}')
|
1401
|
-
|
1402
|
-
# Predicting the target variable for all other rows in the dataframe
|
1403
|
-
X_all = df[features]
|
1404
|
-
all_predictions = model.predict(X_all)
|
1405
|
-
df['predictions'] = all_predictions
|
1406
|
-
|
1407
|
-
# Get prediction probabilities for all rows in the dataframe
|
1408
|
-
prediction_probabilities = model.predict_proba(X_all)
|
1409
|
-
for i in range(prediction_probabilities.shape[1]):
|
1410
|
-
df[f'prediction_probability_class_{i}'] = prediction_probabilities[:, i]
|
1411
|
-
if verbose:
|
1412
|
-
print("\nClassification Report:")
|
1413
|
-
print(classification_report(y_test, predictions_test))
|
1414
|
-
report_dict = classification_report(y_test, predictions_test, output_dict=True)
|
1415
|
-
metrics_df = pd.DataFrame(report_dict).transpose()
|
1416
|
-
|
1417
1512
|
df = _calculate_similarity(df, features, location_column, positive_control, negative_control)
|
1418
1513
|
|
1419
1514
|
df['prcfo'] = df.index.astype(str)
|
1420
1515
|
df[['plate', 'row_name', 'column_name', 'field', 'object']] = df['prcfo'].str.split('_', expand=True)
|
1421
1516
|
df['prc'] = df['plate'] + '_' + df['row_name'] + '_' + df['column_name']
|
1422
1517
|
|
1423
|
-
return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test, metrics_df], [permutation_fig, feature_importance_fig]
|
1518
|
+
return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test, metrics_df, features], [permutation_fig, feature_importance_fig]
|
1424
1519
|
|
1425
1520
|
def shap_analysis(model, X_train, X_test):
|
1426
1521
|
|
@@ -1495,9 +1590,9 @@ def _calculate_similarity(df, features, col_to_compare, val1, val2):
|
|
1495
1590
|
inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(cov_matrix.shape[0]) * epsilon)
|
1496
1591
|
|
1497
1592
|
# Calculate similarity scores
|
1498
|
-
def safe_similarity(func, row, control):
|
1593
|
+
def safe_similarity(func, row, control, *args, **kwargs):
|
1499
1594
|
try:
|
1500
|
-
return func(row, control)
|
1595
|
+
return func(row, control, *args, **kwargs)
|
1501
1596
|
except Exception:
|
1502
1597
|
return np.nan
|
1503
1598
|
|
spacr/settings.py
CHANGED
@@ -283,7 +283,10 @@ def set_default_analyze_screen(settings):
|
|
283
283
|
settings.setdefault('cmap','viridis')
|
284
284
|
settings.setdefault('channel_of_interest',3)
|
285
285
|
settings.setdefault('minimum_cell_count',25)
|
286
|
-
settings.setdefault('
|
286
|
+
settings.setdefault('reg_alpha',0.1)
|
287
|
+
settings.setdefault('reg_lambda',1.0)
|
288
|
+
settings.setdefault('learning_rate',0.001)
|
289
|
+
settings.setdefault('n_estimators',1000)
|
287
290
|
settings.setdefault('test_size',0.2)
|
288
291
|
settings.setdefault('location_column','column_name')
|
289
292
|
settings.setdefault('positive_control','c2')
|
@@ -296,6 +299,8 @@ def set_default_analyze_screen(settings):
|
|
296
299
|
settings.setdefault('remove_low_variance_features',True)
|
297
300
|
settings.setdefault('remove_highly_correlated_features',True)
|
298
301
|
settings.setdefault('n_jobs',-1)
|
302
|
+
settings.setdefault('prune_features',False)
|
303
|
+
settings.setdefault('cross_validation',True)
|
299
304
|
settings.setdefault('verbose',True)
|
300
305
|
return settings
|
301
306
|
|
@@ -872,10 +877,40 @@ expected_types = {
|
|
872
877
|
"target_layer":str,
|
873
878
|
"save_to_db":bool,
|
874
879
|
"test_mode":bool,
|
880
|
+
"test_images":int,
|
881
|
+
"remove_background_cell":bool,
|
882
|
+
"remove_background_nucleus":bool,
|
883
|
+
"remove_background_pathogen":bool,
|
884
|
+
"figuresize":int,
|
885
|
+
"cmap":str,
|
886
|
+
"pathogen_model":str,
|
875
887
|
"normalize_input":bool,
|
888
|
+
"filter_column":str,
|
889
|
+
"target_unique_count":int,
|
890
|
+
"threshold_multiplier":int,
|
891
|
+
"threshold_method":str,
|
892
|
+
"count_data":list,
|
893
|
+
"score_data":list,
|
894
|
+
"min_n":int,
|
895
|
+
"controls":list,
|
896
|
+
"toxo":bool,
|
897
|
+
"volcano":str,
|
898
|
+
"metadata_files":list,
|
899
|
+
"filter_value":list,
|
900
|
+
"split_axis_lims":str,
|
901
|
+
"x_lim":(list,None),
|
902
|
+
"log_x":bool,
|
903
|
+
"log_y":bool,
|
904
|
+
"reg_alpha":(int,float),
|
905
|
+
"reg_lambda":(int,float),
|
906
|
+
"prune_features":bool,
|
907
|
+
"cross_validation":bool,
|
908
|
+
"offset_start":int,
|
909
|
+
"chunk_size":int,
|
910
|
+
"single_direction":str,
|
876
911
|
}
|
877
912
|
|
878
|
-
categories = {"Paths":[ "src", "grna", "barcodes", "custom_model_path", "dataset","model_path","grna_csv","row_csv","column_csv"],
|
913
|
+
categories = {"Paths":[ "src", "grna", "barcodes", "custom_model_path", "dataset","model_path","grna_csv","row_csv","column_csv", "metadata_files", "score_data","count_data"],
|
879
914
|
"General": ["metadata_type", "custom_regex", "experiment", "channels", "magnification", "channel_dims", "apply_model_to_dataset", "generate_training_dataset", "train_DL_model", "segmentation_mode"],
|
880
915
|
"Cellpose":["fill_in","from_scratch", "n_epochs", "width_height", "model_name", "custom_model", "resample", "rescale", "CP_prob", "flow_threshold", "percentiles", "invert", "diameter", "grayscale", "Signal_to_noise", "resize", "target_height", "target_width"],
|
881
916
|
"Cell": ["cell_intensity_range", "cell_size_range", "cell_chann_dim", "cell_channel", "cell_background", "cell_Signal_to_noise", "cell_CP_prob", "cell_FT", "remove_background_cell", "cell_min_size", "cell_mask_dim", "cytoplasm", "cytoplasm_min_size", "uninfected", "merge_edge_pathogen_cells", "adjust_cells", "cells", "cell_loc"],
|
@@ -883,18 +918,18 @@ categories = {"Paths":[ "src", "grna", "barcodes", "custom_model_path", "dataset
|
|
883
918
|
"Pathogen": ["pathogen_intensity_range", "pathogen_size_range", "pathogen_chann_dim", "pathogen_channel", "pathogen_background", "pathogen_Signal_to_noise", "pathogen_CP_prob", "pathogen_FT", "pathogen_model", "remove_background_pathogen", "pathogen_min_size", "pathogen_mask_dim", "pathogens", "pathogen_loc", "pathogen_types", "pathogen_plate_metadata", ],
|
884
919
|
"Measurements": ["remove_image_canvas", "remove_highly_correlated", "homogeneity", "homogeneity_distances", "radial_dist", "calculate_correlation", "manders_thresholds", "save_measurements", "tables", "image_nr", "dot_size", "filter_by", "remove_highly_correlated_features", "remove_low_variance_features", "channel_of_interest"],
|
885
920
|
"Object Image": ["save_png", "dialate_pngs", "dialate_png_ratios", "png_size", "png_dims", "save_arrays", "normalize_by", "crop_mode", "normalize", "use_bounding_box"],
|
886
|
-
"Sequencing": ["signal_direction","mode","comp_level","comp_type","save_h5","expected_end","offset","target_sequence","regex", "highlight"],
|
921
|
+
"Sequencing": ["offset_start","chunk_size","single_direction", "signal_direction","mode","comp_level","comp_type","save_h5","expected_end","offset","target_sequence","regex", "highlight"],
|
887
922
|
"Generate Dataset":["save_to_db","file_metadata","class_metadata", "annotation_column","annotated_classes", "dataset_mode", "metadata_type_by","custom_measurement", "sample", "size"],
|
888
923
|
"Hyperparamiters (Training)": ["png_type", "score_threshold","file_type", "train_channels", "epochs", "loss_type", "optimizer_type","image_size","val_split","learning_rate","weight_decay","dropout_rate", "init_weights", "train", "classes", "augment", "amsgrad","use_checkpoint","gradient_accumulation","gradient_accumulation_steps","intermedeate_save","pin_memory"],
|
889
924
|
"Hyperparamiters (Embedding)": ["visualize","n_neighbors","min_dist","metric","resnet_features","reduction_method","embedding_by_controls","col_to_compare","log_data"],
|
890
925
|
"Hyperparamiters (Clustering)": ["eps","min_samples","analyze_clusters","clustering","remove_cluster_noise"],
|
891
|
-
"Hyperparamiters (Regression)":["cov_type", "class_1_threshold", "plate", "other", "fraction_threshold", "alpha", "random_row_column_effects", "regression_type", "min_cell_count", "agg_type", "transform", "dependent_variable"],
|
926
|
+
"Hyperparamiters (Regression)":["cross_validation","prune_features","reg_lambda","reg_alpha","cov_type", "class_1_threshold", "plate", "other", "fraction_threshold", "alpha", "random_row_column_effects", "regression_type", "min_cell_count", "agg_type", "transform", "dependent_variable"],
|
892
927
|
"Hyperparamiters (Activation)":["cam_type", "overlay", "correlation", "target_layer", "normalize_input"],
|
893
|
-
"Annotation": ["nc_loc", "pc_loc", "nc", "pc", "cell_plate_metadata","treatment_plate_metadata", "metadata_types", "cell_types", "target","positive_control","negative_control", "location_column", "treatment_loc", "channel_of_interest", "measurement", "treatments", "um_per_pixel", "nr_imgs", "exclude", "exclude_conditions", "mix", "pos", "neg"],
|
894
|
-
"Plot": ["plot", "plot_control", "plot_nr", "examples_to_plot", "normalize_plots", "cmap", "figuresize", "plot_cluster_grids", "img_zoom", "row_limit", "color_by", "plot_images", "smooth_lines", "plot_points", "plot_outlines", "black_background", "plot_by_cluster", "heatmap_feature","grouping","min_max","cmap","save_figure"],
|
928
|
+
"Annotation": ["filter_column", "filter_value","volcano", "toxo", "controls", "nc_loc", "pc_loc", "nc", "pc", "cell_plate_metadata","treatment_plate_metadata", "metadata_types", "cell_types", "target","positive_control","negative_control", "location_column", "treatment_loc", "channel_of_interest", "measurement", "treatments", "um_per_pixel", "nr_imgs", "exclude", "exclude_conditions", "mix", "pos", "neg"],
|
929
|
+
"Plot": ["plot", "split_axis_lims", "x_lim","log_x","log_y", "plot_control", "plot_nr", "examples_to_plot", "normalize_plots", "cmap", "figuresize", "plot_cluster_grids", "img_zoom", "row_limit", "color_by", "plot_images", "smooth_lines", "plot_points", "plot_outlines", "black_background", "plot_by_cluster", "heatmap_feature","grouping","min_max","cmap","save_figure"],
|
895
930
|
"Test": ["test_mode", "test_images", "random_test", "test_nr", "test", "test_split"],
|
896
931
|
"Timelapse": ["timelapse", "fps", "timelapse_displacement", "timelapse_memory", "timelapse_frame_limits", "timelapse_remove_transient", "timelapse_mode", "timelapse_objects", "compartments"],
|
897
|
-
"Advanced": ["shuffle", "target_intensity_min", "cells_per_well", "nuclei_limit", "pathogen_limit", "background", "backgrounds", "schedule", "test_size","exclude","n_repeats","top_features", "model_type_ml", "model_type","minimum_cell_count","n_estimators","preprocess", "remove_background", "normalize", "lower_percentile", "merge_pathogens", "batch_size", "filter", "save", "masks", "verbose", "randomize", "n_jobs"],
|
932
|
+
"Advanced": ["target_unique_count","threshold_multiplier", "threshold_method", "min_n","shuffle", "target_intensity_min", "cells_per_well", "nuclei_limit", "pathogen_limit", "background", "backgrounds", "schedule", "test_size","exclude","n_repeats","top_features", "model_type_ml", "model_type","minimum_cell_count","n_estimators","preprocess", "remove_background", "normalize", "lower_percentile", "merge_pathogens", "batch_size", "filter", "save", "masks", "verbose", "randomize", "n_jobs"],
|
898
933
|
"Miscellaneous": ["all_to_mip", "pick_slice", "skip_mode", "upscale", "upscale_factor"]
|
899
934
|
}
|
900
935
|
|
spacr/utils.py
CHANGED
@@ -4277,6 +4277,12 @@ def filter_dataframe_features(df, channel_of_interest, exclude=None, remove_low_
|
|
4277
4277
|
|
4278
4278
|
if remove_highly_correlated_features:
|
4279
4279
|
df = remove_highly_correlated_columns(df, threshold=0.95, verbose=verbose)
|
4280
|
+
|
4281
|
+
# Remove columns with NaN values
|
4282
|
+
before_drop_NaN = len(df.columns)
|
4283
|
+
df = df.dropna(axis=1)
|
4284
|
+
after_drop_NaN = len(df.columns)
|
4285
|
+
print(f"Dropped {before_drop_NaN - after_drop_NaN} columns with NaN values")
|
4280
4286
|
|
4281
4287
|
# Select numerical features
|
4282
4288
|
features = df.select_dtypes(include=[np.number]).columns.tolist()
|
@@ -4759,7 +4765,8 @@ def get_ml_results_paths(src, model_type='xgboost', channel_of_interest=1):
|
|
4759
4765
|
shap_fig_path = os.path.join(res_fldr, 'shap.pdf')
|
4760
4766
|
plate_heatmap_path = os.path.join(res_fldr, 'plate_heatmap.pdf')
|
4761
4767
|
settings_csv = os.path.join(res_fldr, 'ml_settings.csv')
|
4762
|
-
|
4768
|
+
ml_features = os.path.join(res_fldr, 'ml_features.csv')
|
4769
|
+
return data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv, ml_features
|
4763
4770
|
|
4764
4771
|
def augment_image(image):
|
4765
4772
|
"""
|
@@ -5038,19 +5045,22 @@ def generate_cytoplasm_mask(nucleus_mask, cell_mask):
|
|
5038
5045
|
return cytoplasm_mask
|
5039
5046
|
|
5040
5047
|
def add_column_to_database(settings):
|
5041
|
-
"""
|
5042
|
-
Adds a new column to the database table by matching on a common column from the DataFrame.
|
5043
|
-
If the column already exists in the database, it adds the column with a suffix.
|
5044
|
-
NaN values will remain as NULL in the database.
|
5045
|
-
|
5046
|
-
Parameters:
|
5047
|
-
|
5048
|
-
|
5049
|
-
|
5050
|
-
|
5051
|
-
|
5052
|
-
|
5053
|
-
|
5048
|
+
#"""
|
5049
|
+
#Adds a new column to the database table by matching on a common column from the DataFrame.
|
5050
|
+
#If the column already exists in the database, it adds the column with a suffix.
|
5051
|
+
#NaN values will remain as NULL in the database.
|
5052
|
+
|
5053
|
+
#Parameters:
|
5054
|
+
# settings (dict): A dictionary containing the following keys:
|
5055
|
+
# csv_path (str): Path to the CSV file with the data to be added.
|
5056
|
+
# db_path (str): Path to the SQLite database (or connection string for other databases).
|
5057
|
+
# table_name (str): The name of the table in the database.
|
5058
|
+
# update_column (str): The name of the new column in the DataFrame to add to the database.
|
5059
|
+
# match_column (str): The common column used to match rows.
|
5060
|
+
|
5061
|
+
#Returns:
|
5062
|
+
# None
|
5063
|
+
#"""
|
5054
5064
|
|
5055
5065
|
# Read the DataFrame from the provided CSV path
|
5056
5066
|
df = pd.read_csv(settings['csv_path'])
|
@@ -12,24 +12,24 @@ spacr/chat_bot.py,sha256=n3Fhqg3qofVXHmh3H9sUcmfYy9MmgRnr48663MVdY9E,1244
|
|
12
12
|
spacr/core.py,sha256=3u2qKmPmTlswvE1uKTF4gi7KQ3sJBHV9No_ysgk7JCU,48487
|
13
13
|
spacr/deep_spacr.py,sha256=V3diLyxX-0_F5UxhX_b94ROOvL9eoLvnoUmF3nMBqPQ,43250
|
14
14
|
spacr/gui.py,sha256=ARyn9Q_g8HoP-cXh1nzMLVFCKqthY4v2u9yORyaQqQE,8230
|
15
|
-
spacr/gui_core.py,sha256=
|
16
|
-
spacr/gui_elements.py,sha256=
|
15
|
+
spacr/gui_core.py,sha256=3S1S7rOtGfm5Z-Fb-jk1HxjsLZo6IboBxtXtsDMw1ME,46578
|
16
|
+
spacr/gui_elements.py,sha256=fc9z-GP_PmtJy2lIsfi3fx5r5puTJ-zz8ntB9pC53Gc,138246
|
17
17
|
spacr/gui_utils.py,sha256=u9RoIOWpAXFEOnUlLpMQZrc1pWSg6omZsJMIhJdRv_g,41211
|
18
18
|
spacr/io.py,sha256=LF6lpphw7GSeuoHQijPykjKNF56wNTFEWFZuDQp3O6Q,145739
|
19
19
|
spacr/logger.py,sha256=lJhTqt-_wfAunCPl93xE65Wr9Y1oIHJWaZMjunHUeIw,1538
|
20
20
|
spacr/measure.py,sha256=2lK-ZcTxLM-MpXV1oZnucRD9iz5aprwahRKw9IEqshg,55085
|
21
21
|
spacr/mediar.py,sha256=FwLvbLQW5LQzPgvJZG8Lw7GniA2vbZx6Jv6vIKu7I5c,14743
|
22
|
-
spacr/ml.py,sha256=
|
22
|
+
spacr/ml.py,sha256=x19S8OsR5omb8e6MU9I99Nz95J_QvM5siyk-zaAU3p8,82866
|
23
23
|
spacr/openai.py,sha256=5vBZ3Jl2llYcW3oaTEXgdyCB2aJujMUIO5K038z7w_A,1246
|
24
24
|
spacr/plot.py,sha256=gXC7y3uT4sx8KRODeSFWQG_A1CylsuJ5B7HYe_un6so,165177
|
25
25
|
spacr/sequencing.py,sha256=ClUfwPPK6rNUbUuiEkzcwakzVyDKKUMv9ricrxT8qQY,25227
|
26
|
-
spacr/settings.py,sha256=
|
26
|
+
spacr/settings.py,sha256=3bgBWBotIO7TZx2nh6JoEaqHNmwbChAiz1gW4xmURQs,81788
|
27
27
|
spacr/sim.py,sha256=1xKhXimNU3ukzIw-3l9cF3Znc_brW8h20yv8fSTzvss,71173
|
28
28
|
spacr/stats.py,sha256=mbhwsyIqt5upsSD346qGjdCw7CFBa0tIS7zHU9e0jNI,9536
|
29
29
|
spacr/submodules.py,sha256=SK8YEs850LAx30YAiwap7ecLpp1_p-bci6H-Or0GLoA,55500
|
30
30
|
spacr/timelapse.py,sha256=KGfG4L4-QnFfgbF7L6C5wL_3gd_rqr05Foje6RsoTBg,39603
|
31
31
|
spacr/toxo.py,sha256=z2nT5aAze3NUIlwnBQcnkARihDwoPfqOgQIVoUluyK0,25087
|
32
|
-
spacr/utils.py,sha256=
|
32
|
+
spacr/utils.py,sha256=jjZIqzJKl9nnWgj2eJiLFT27gED8hO6rAEsJLlimm-E,222298
|
33
33
|
spacr/version.py,sha256=axH5tnGwtgSnJHb5IDhiu4Zjk5GhLyAEDRe-rnaoFOA,409
|
34
34
|
spacr/resources/MEDIAR/.gitignore,sha256=Ff1q9Nme14JUd-4Q3jZ65aeQ5X4uttptssVDgBVHYo8,152
|
35
35
|
spacr/resources/MEDIAR/LICENSE,sha256=yEj_TRDLUfDpHDNM0StALXIt6mLqSgaV2hcCwa6_TcY,1065
|
@@ -152,9 +152,9 @@ spacr/resources/icons/umap.png,sha256=dOLF3DeLYy9k0nkUybiZMe1wzHQwLJFRmgccppw-8b
|
|
152
152
|
spacr/resources/images/plate1_E01_T0001F001L01A01Z01C02.tif,sha256=Tl0ZUfZ_AYAbu0up_nO0tPRtF1BxXhWQ3T3pURBCCRo,7958528
|
153
153
|
spacr/resources/images/plate1_E01_T0001F001L01A02Z01C01.tif,sha256=m8N-V71rA1TT4dFlENNg8s0Q0YEXXs8slIn7yObmZJQ,7958528
|
154
154
|
spacr/resources/images/plate1_E01_T0001F001L01A03Z01C03.tif,sha256=Pbhk7xn-KUP6RSIhJsxQcrHFImBm3GEpLkzx7WOc-5M,7958528
|
155
|
-
spacr-0.3.
|
156
|
-
spacr-0.3.
|
157
|
-
spacr-0.3.
|
158
|
-
spacr-0.3.
|
159
|
-
spacr-0.3.
|
160
|
-
spacr-0.3.
|
155
|
+
spacr-0.3.81.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
|
156
|
+
spacr-0.3.81.dist-info/METADATA,sha256=TneYaiWWfNHUuqpSQpRAEzWzYia6KIksYDzNVDcsimw,6032
|
157
|
+
spacr-0.3.81.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
|
158
|
+
spacr-0.3.81.dist-info/entry_points.txt,sha256=BMC0ql9aNNpv8lUZ8sgDLQMsqaVnX5L535gEhKUP5ho,296
|
159
|
+
spacr-0.3.81.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
|
160
|
+
spacr-0.3.81.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|