spacr 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. spacr/__init__.py +19 -3
  2. spacr/cellpose.py +311 -0
  3. spacr/core.py +140 -2493
  4. spacr/deep_spacr.py +151 -29
  5. spacr/gui.py +1 -0
  6. spacr/gui_core.py +74 -63
  7. spacr/gui_elements.py +110 -5
  8. spacr/gui_utils.py +346 -6
  9. spacr/io.py +624 -44
  10. spacr/logger.py +28 -9
  11. spacr/measure.py +107 -95
  12. spacr/mediar.py +0 -3
  13. spacr/ml.py +964 -0
  14. spacr/openai.py +37 -0
  15. spacr/plot.py +280 -15
  16. spacr/resources/data/lopit.csv +3833 -0
  17. spacr/resources/data/toxoplasma_metadata.csv +8843 -0
  18. spacr/resources/icons/convert.png +0 -0
  19. spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
  20. spacr/sequencing.py +241 -1311
  21. spacr/settings.py +129 -43
  22. spacr/sim.py +0 -2
  23. spacr/submodules.py +348 -0
  24. spacr/timelapse.py +0 -2
  25. spacr/toxo.py +233 -0
  26. spacr/utils.py +271 -171
  27. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/METADATA +7 -1
  28. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/RECORD +32 -33
  29. spacr/chris.py +0 -50
  30. spacr/graph_learning.py +0 -340
  31. spacr/resources/MEDIAR/.git +0 -1
  32. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  33. spacr/resources/icons/.DS_Store +0 -0
  34. spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  35. spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
  36. spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
  37. spacr/sim_app.py +0 -0
  38. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/LICENSE +0 -0
  39. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/WHEEL +0 -0
  40. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/entry_points.txt +0 -0
  41. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/top_level.txt +0 -0
spacr/deep_spacr.py CHANGED
@@ -1,5 +1,4 @@
 import os, torch, time, gc, datetime
-
 torch.backends.cudnn.benchmark = True
 
 import numpy as np
@@ -8,15 +7,138 @@ from torch.optim import Adagrad, AdamW
 from torch.autograd import grad
 from torch.optim.lr_scheduler import StepLR
 import torch.nn.functional as F
-from IPython.display import display, clear_output
 import matplotlib.pyplot as plt
 from PIL import Image
 from sklearn.metrics import auc, precision_recall_curve
-from multiprocessing import set_start_method
-#set_start_method('spawn', force=True)
 
-from .logger import log_function_call
-from .utils import close_multiprocessing_processes, reset_mp
+from torchvision import transforms
+from torch.utils.data import DataLoader
+
+def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, n_jobs=10):
+
+    from .io import NoClassDataset
+    from .utils import print_progress
+
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    if normalize:
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.CenterCrop(size=(image_size, image_size)),
+            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+    else:
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.CenterCrop(size=(image_size, image_size))])
+
+    model = torch.load(model_path)
+    print(model)
+
+    print(f'Loading dataset in {src} with {len(src)} images')
+    dataset = NoClassDataset(data_dir=src, transform=transform, shuffle=True, load_to_memory=False)
+    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_jobs)
+    print(f'Loaded {len(src)} images')
+
+    result_loc = os.path.splitext(model_path)[0]+datetime.date.today().strftime('%y%m%d')+'_'+os.path.splitext(model_path)[1]+'_test_result.csv'
+    print(f'Results wil be saved in: {result_loc}')
+
+    model.eval()
+    model = model.to(device)
+    prediction_pos_probs = []
+    filenames_list = []
+    time_ls = []
+    with torch.no_grad():
+        for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
+            start = time.time()
+            images = batch_images.to(torch.float).to(device)
+            outputs = model(images)
+            batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
+            prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
+            filenames_list.extend(filenames)
+            stop = time.time()
+            duration = stop - start
+            time_ls.append(duration)
+            files_processed = batch_idx*batch_size
+            files_to_process = len(data_loader)
+            print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=batch_size, operation_type="Generating predictions")
+
+    data = {'path':filenames_list, 'pred':prediction_pos_probs}
+    df = pd.DataFrame(data, index=None)
+    df.to_csv(result_loc, index=True, header=True, mode='w')
+    torch.cuda.empty_cache()
+    torch.cuda.memory.empty_cache()
+    return df
+
+def apply_model_to_tar(settings={}):
+
+    from .io import TarImageDataset
+    from .utils import process_vision_results, print_progress
+
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    if settings['normalize']:
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])),
+            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+    else:
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.CenterCrop(size=(settings['image_size'], settings['image_size']))])
+
+    if settings['verbose']:
+        print(f"Loading model from {settings['model_path']}")
+        print(f"Loading dataset from {settings['tar_path']}")
+
+    model = torch.load(settings['model_path'])
+
+    dataset = TarImageDataset(settings['tar_path'], transform=transform)
+    data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)
+
+    model_name = os.path.splitext(os.path.basename(settings['model_path']))[0]
+    dataset_name = os.path.splitext(os.path.basename(settings['tar_path']))[0]
+    date_name = datetime.date.today().strftime('%y%m%d')
+    dst = os.path.dirname(settings['tar_path'])
+    result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'
+
+    model.eval()
+    model = model.to(device)
+
+    if settings['verbose']:
+        print(model)
+        print(f'Generated dataset with {len(dataset)} images')
+        print(f'Generating loader from {len(data_loader)} batches')
+        print(f'Results wil be saved in: {result_loc}')
+        print(f'Model is in eval mode')
+        print(f'Model loaded to device')
+
+    prediction_pos_probs = []
+    filenames_list = []
+    time_ls = []
+    gc.collect()
+    with torch.no_grad():
+        for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
+            start = time.time()
+            images = batch_images.to(torch.float).to(device)
+            outputs = model(images)
+            batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
+            prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
+            filenames_list.extend(filenames)
+            stop = time.time()
+            duration = stop - start
+            time_ls.append(duration)
+            files_processed = batch_idx*settings['batch_size']
+            files_to_process = len(data_loader)*settings['batch_size']
+            print_progress(files_processed, files_to_process, n_jobs=settings['n_jobs'], time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Tar dataset")
+
+    data = {'path':filenames_list, 'pred':prediction_pos_probs}
+    df = pd.DataFrame(data, index=None)
+    df = process_vision_results(df, settings['score_threshold'])
+
+    df.to_csv(result_loc, index=True, header=True, mode='w')
+    print(f"Saved results to {result_loc}")
+    torch.cuda.empty_cache()
+    torch.cuda.memory.empty_cache()
+    return df
 
 def evaluate_model_performance(model, loader, epoch, loss_type):
     """
@@ -175,7 +297,7 @@ def test_model_core(model, loader, loader_name, epoch, loss_type):
                                    'class_1_probability':prediction_pos_probs})
 
     loss /= len(loader)
-    data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
+    data_df = classification_metrics(all_labels, prediction_pos_probs, loss, epoch)
     return data_df, prediction_pos_probs, all_labels, results_df
 
 def test_model_performance(loaders, model, loader_name_list, epoch, loss_type):
@@ -203,7 +325,10 @@ def train_test_model(settings):
 
     from .io import _save_settings, _copy_missclassified
     from .utils import pick_best_model
-    from .core import generate_loaders
+    from .io import generate_loaders
+    from .settings import get_train_test_model_settings
+
+    settings = get_train_test_model_settings(settings)
 
     torch.cuda.empty_cache()
     torch.cuda.memory.empty_cache()
@@ -235,7 +360,6 @@ def train_test_model(settings):
                                                                 normalize=settings['normalize'],
                                                                 channels=settings['train_channels'],
                                                                 augment=settings['augment'],
-                                                                preload_batches=settings['preload_batches'],
                                                                 verbose=settings['verbose'])
 
     #train_batch_1_figure = os.path.join(dst, 'batch_1.pdf')
@@ -264,19 +388,19 @@ def train_test_model(settings):
                                                                 channels=settings['train_channels'])
 
     if settings['test']:
-        test, _, plate_names_test, train_fig = generate_loaders(src,
-                                                                mode='test',
-                                                                image_size=settings['image_size'],
-                                                                batch_size=settings['batch_size'],
-                                                                classes=settings['classes'],
-                                                                n_jobs=settings['n_jobs'],
-                                                                validation_split=0.0,
-                                                                pin_memory=settings['pin_memory'],
-                                                                normalize=settings['normalize'],
-                                                                channels=settings['train_channels'],
-                                                                augment=False,
-                                                                preload_batches=settings['preload_batches'],
-                                                                verbose=settings['verbose'])
+        test, _, train_fig = generate_loaders(src,
+                                              mode='test',
+                                              image_size=settings['image_size'],
+                                              batch_size=settings['batch_size'],
+                                              classes=settings['classes'],
+                                              n_jobs=settings['n_jobs'],
+                                              validation_split=0.0,
+                                              pin_memory=settings['pin_memory'],
+                                              normalize=settings['normalize'],
+                                              channels=settings['train_channels'],
+                                              augment=False,
+                                              verbose=settings['verbose'])
+
     if model == None:
         model_path = pick_best_model(src+'/model')
         print(f'Best model: {model_path}')
@@ -308,7 +432,10 @@ def train_test_model(settings):
     torch.cuda.memory.empty_cache()
     gc.collect()
 
-    return model_path
+    if settings['train']:
+        return model_path
+    if settings['test']:
+        return result_loc
 
 def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, n_jobs=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b'], verbose=False):
     """
@@ -355,11 +482,6 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001
 
     kwargs = {'n_jobs': n_jobs, 'pin_memory': True} if use_cuda else {}
 
-    #for idx, (images, labels, filenames) in enumerate(train_loaders):
-    #    batch, chans, height, width = images.shape
-    #    break
-
-
     model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint, verbose=verbose)
 
 
@@ -732,7 +854,7 @@ def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, cha
 
 def deep_spacr(settings={}):
     from .settings import deep_spacr_defaults
-    from .core import generate_training_dataset, generate_dataset, apply_model_to_tar
+    from .io import generate_training_dataset, generate_dataset
    from .utils import save_settings
 
     settings = deep_spacr_defaults(settings)
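Note on the relocated helpers: apply_model and apply_model_to_tar (moved here from spacr.core, per the import change above) are driven by a plain settings dict. A minimal usage sketch, assuming only the keys the function body above reads; the paths and threshold here are placeholders, not values from the release:

# Sketch: score a tar archive of images with a saved vision model (hypothetical paths).
from spacr.deep_spacr import apply_model_to_tar

settings = {
    'model_path': '/path/to/model.pth',   # checkpoint consumed by torch.load
    'tar_path': '/path/to/images.tar',    # archive wrapped by TarImageDataset
    'image_size': 224,                    # CenterCrop size
    'batch_size': 64,
    'normalize': True,                    # adds the 0.5/0.5 Normalize transform
    'n_jobs': 10,                         # DataLoader worker count
    'verbose': True,
    'score_threshold': 0.5,               # forwarded to process_vision_results
}
df = apply_model_to_tar(settings)         # writes <date>_<tar>_<model>_result.csv next to the tar and returns the DataFrame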
spacr/gui.py CHANGED
@@ -48,6 +48,7 @@ class MainApp(tk.Tk):
         }
 
         self.additional_gui_apps = {
+            "Convert": (lambda frame: initiate_root(self, 'convert'), "Convert images to Grayscale TIFs."),
             "Umap": (lambda frame: initiate_root(self, 'umap'), "Generate UMAP embeddings with datapoints represented as images."),
             "Train Cellpose": (lambda frame: initiate_root(self, 'train_cellpose'), "Train custom Cellpose models."),
             "ML Analyze": (lambda frame: initiate_root(self, 'ml_analyze'), "Machine learning analysis of data."),
spacr/gui_core.py CHANGED
@@ -1,9 +1,9 @@
-import traceback, ctypes, csv, re, platform, time
+import os, traceback, ctypes, csv, re, platform
 import tkinter as tk
 from tkinter import ttk
 from tkinter import filedialog
 from multiprocessing import Process, Value, Queue, set_start_method
-from tkinter import ttk, scrolledtext
+from tkinter import ttk
 from matplotlib.figure import Figure
 from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
 import numpy as np
@@ -11,15 +11,13 @@ import psutil
 import GPUtil
 from collections import deque
 import tracemalloc
-from tkinter import Menu
-import io
 
 try:
     ctypes.windll.shcore.SetProcessDpiAwareness(True)
 except AttributeError:
     pass
 
-from .gui_elements import spacrProgressBar, spacrButton, spacrLabel, spacrFrame, spacrDropdownMenu , spacrSlider, set_dark_style, standardize_figure
+from .gui_elements import spacrProgressBar, spacrButton, spacrFrame, spacrDropdownMenu , spacrSlider, set_dark_style
 
 # Define global variables
 q = None
@@ -35,6 +33,7 @@ figures = None
 figure_index = None
 progress_bar = None
 usage_bars = None
+index_control = None
 
 thread_control = {"run_thread": None, "stop_requested": False}
 
@@ -170,39 +169,6 @@ def display_figure(fig):
         #flash_feedback("right")
         show_next_figure()
 
-    def zoom_v1(event):
-        nonlocal scale_factor
-
-        zoom_speed = 0.1 # Adjust the zoom speed for smoother experience
-
-        # Adjust zoom factor based on the operating system and mouse event
-        if event.num == 4 or event.delta > 0: # Scroll up
-            scale_factor *= (1 + zoom_speed)
-        elif event.num == 5 or event.delta < 0: # Scroll down
-            scale_factor /= (1 + zoom_speed)
-
-        # Get mouse position relative to the figure
-        x_mouse, y_mouse = event.x, event.y
-        x_ratio = x_mouse / canvas_widget.winfo_width()
-        y_ratio = y_mouse / canvas_widget.winfo_height()
-
-        for ax in fig.get_axes():
-            xlim = ax.get_xlim()
-            ylim = ax.get_ylim()
-
-            # Calculate the new limits
-            x_center = xlim[0] + x_ratio * (xlim[1] - xlim[0])
-            y_center = ylim[0] + (1 - y_ratio) * (ylim[1] - ylim[0])
-
-            x_range = (xlim[1] - xlim[0]) * scale_factor
-            y_range = (ylim[1] - ylim[0]) * scale_factor
-
-            ax.set_xlim([x_center - x_range * x_ratio, x_center + x_range * (1 - x_ratio)])
-            ax.set_ylim([y_center - y_range * (1 - y_ratio), y_center + y_range * y_ratio])
-
-        # Redraw the figure
-        fig.canvas.draw_idle()
-
     def zoom(event):
         nonlocal scale_factor
 
@@ -282,7 +248,7 @@ def show_next_figure():
         figures.append(fig)
         figure_index += 1
         display_figure(fig)
-
+        
 def process_fig_queue():
     global canvas, fig_queue, canvas_widget, parent_frame, uppdate_frequency, figures, figure_index, index_control
 
@@ -329,44 +295,60 @@ def update_figure(value):
         index_control.set_to(len(figures) - 1)
         index_control.set(figure_index)
 
-def setup_plot_section(vertical_container):
+def setup_plot_section(vertical_container, settings_type):
     global canvas, canvas_widget, figures, figure_index, index_control
+    from .gui_utils import display_media_in_plot_frame
+
+    style_out = set_dark_style(ttk.Style())
+    bg = style_out['bg_color']
+    fg = style_out['fg_color']
 
     # Initialize deque for storing figures and the current index
     figures = deque()
 
     # Create a frame for the plot section
     plot_frame = tk.Frame(vertical_container)
+    plot_frame.configure(bg=bg)
     vertical_container.add(plot_frame, stretch="always")
 
-    # Set up the plot
+    # Clear the plot_frame (optional, to handle cases where it may already have content)
+    for widget in plot_frame.winfo_children():
+        widget.destroy()
+
+    # Create a figure and plot
     figure = Figure(figsize=(30, 4), dpi=100)
     plot = figure.add_subplot(111)
     plot.plot([], [])
     plot.axis('off')
+
+    if settings_type == 'map_barcodes':
+        # Load and display GIF
+        current_dir = os.path.dirname(__file__)
+        resources_path = os.path.join(current_dir, 'resources', 'icons')
+        gif_path = os.path.join(resources_path, 'dna_matrix.mp4')
+
+        display_media_in_plot_frame(gif_path, plot_frame)
+        canvas = FigureCanvasTkAgg(figure, master=plot_frame)
+        canvas.get_tk_widget().configure(cursor='arrow', highlightthickness=0)
+        canvas_widget = canvas.get_tk_widget()
+        return canvas, canvas_widget
 
     canvas = FigureCanvasTkAgg(figure, master=plot_frame)
     canvas.get_tk_widget().configure(cursor='arrow', highlightthickness=0)
     canvas_widget = canvas.get_tk_widget()
     canvas_widget.grid(row=0, column=0, sticky="nsew")
-
     plot_frame.grid_rowconfigure(0, weight=1)
     plot_frame.grid_columnconfigure(0, weight=1)
-
     canvas.draw()
-    canvas.figure = figure # Ensure that the figure is linked to the canvas
-    style_out = set_dark_style(ttk.Style())
-    bg = style_out['bg_color']
-    fg = style_out['fg_color']
-
+    canvas.figure = figure
     figure.patch.set_facecolor(bg)
     plot.set_facecolor(bg)
     containers = [plot_frame]
 
     # Create slider
-    control_frame = tk.Frame(plot_frame, height=15*2, bg=bg) # Fixed height based on knob_radius
+    control_frame = tk.Frame(plot_frame, height=15*2, bg=bg)
     control_frame.grid(row=1, column=0, sticky="ew", padx=10, pady=5)
-    control_frame.grid_propagate(False) # Prevent the frame from resizing
+    control_frame.grid_propagate(False)
 
     # Pass the update_figure function as the command to spacrSlider
     index_control = spacrSlider(control_frame, from_=0, to=0, value=0, thickness=2, knob_radius=10, position="center", show_index=True, command=update_figure)
@@ -442,6 +424,8 @@ def import_settings(settings_type='mask'):
         settings = get_analyze_recruitment_default_settings(settings={})
     elif settings_type == 'analyze_plaques':
         settings = {}
+    elif settings_type == 'convert':
+        settings = {}
     else:
         raise ValueError(f"Invalid settings type: {settings_type}")
 
@@ -452,7 +436,8 @@ def import_settings(settings_type='mask'):
 
 def setup_settings_panel(vertical_container, settings_type='mask'):
     global vars_dict, scrollable_frame
-    from .settings import get_identify_masks_finetune_default_settings, set_default_analyze_screen, set_default_settings_preprocess_generate_masks, get_measure_crop_settings, deep_spacr_defaults, set_default_generate_barecode_mapping, set_default_umap_image_settings, generate_fields, get_perform_regression_default_settings, get_train_cellpose_default_settings, get_map_barcodes_default_settings, get_analyze_recruitment_default_settings, get_check_cellpose_models_default_settings
+    from .settings import get_identify_masks_finetune_default_settings, set_default_analyze_screen, set_default_settings_preprocess_generate_masks, get_measure_crop_settings, deep_spacr_defaults, set_default_generate_barecode_mapping, set_default_umap_image_settings
+    from .settings import get_map_barcodes_default_settings, get_analyze_recruitment_default_settings, get_check_cellpose_models_default_settings, generate_fields, get_perform_regression_default_settings, get_train_cellpose_default_settings
     from .gui_utils import convert_settings_dict_for_gui
     from .gui_elements import set_element_size
 
@@ -496,7 +481,9 @@ def setup_settings_panel(vertical_container, settings_type='mask'):
     elif settings_type == 'recruitment':
         settings = get_analyze_recruitment_default_settings(settings={})
     elif settings_type == 'analyze_plaques':
-        settings = {}
+        settings = {'src':'path to images'}
+    elif settings_type == 'convert':
+        settings = {'src':'path to images'}
     else:
         raise ValueError(f"Invalid settings type: {settings_type}")
 
@@ -515,7 +502,7 @@ def setup_settings_panel(vertical_container, settings_type='mask'):
 def setup_console(vertical_container):
     global console_output
     from .gui_elements import set_dark_style
-
+    
     # Apply dark style and get style output
     style = ttk.Style()
     style_out = set_dark_style(style)
@@ -546,9 +533,27 @@ def setup_console(vertical_container):
     def on_leave(event):
         top_border.config(bg=style_out['bg_color'])
 
+    #def on_enter_key(event):
+    #    user_input = console_output.get("1.0", "end-1c").strip() # Get the user input from the console
+    #    if user_input:
+    #        # Print the user input with the (user) tag
+    #        console_output.insert("end", f"\n(user): {user_input}\n")
+    #
+    #        # Get the AI response from the chatbot
+    #        response = chatbot.ask_question(user_input)
+    #
+    #        # Print the AI response with the (ai) tag
+    #        console_output.insert("end", f"(ai): {response}\n")
+    #
+    #        console_output.see("end") # Scroll to the end
+    #        #console_output.delete("1.0", "end") # Clear the input field
+    #        return "break" # Prevent the default behavior of inserting a new line
+
     console_output.bind("<Enter>", on_enter)
     console_output.bind("<Leave>", on_leave)
 
+    #console_output.bind("<Return>", on_enter_key)
+
     return console_output, console_frame
 
 def setup_button_section(horizontal_container, settings_type='mask', run=True, abort=True, download=True, import_btn=True):
@@ -755,7 +760,7 @@ def initiate_abort():
 def start_process(q=None, fig_queue=None, settings_type='mask'):
     global thread_control, vars_dict, parent_frame
     from .settings import check_settings, expected_types
-    from .gui_utils import run_function_gui, set_high_priority, set_cpu_affinity, initialize_cuda
+    from .gui_utils import run_function_gui, set_cpu_affinity, initialize_cuda, display_gif_in_plot_frame, print_widget_structure
 
     if q is None:
         q = Queue()
@@ -778,16 +783,14 @@ def start_process(q=None, fig_queue=None, settings_type='mask'):
 
     process_args = (settings_type, settings, q, fig_queue, stop_requested)
     if settings_type in ['mask', 'umap', 'measure', 'simulation', 'sequencing', 'classify', 'analyze_plaques',
-                         'cellpose_dataset', 'train_cellpose', 'ml_analyze', 'cellpose_masks', 'cellpose_all', 'map_barcodes',
-                         'regression', 'recruitment', 'cellpose_compare', 'vision_scores', 'vision_dataset']:
+                         'cellpose_dataset', 'train_cellpose', 'ml_analyze', 'cellpose_masks', 'cellpose_all',
+                         'map_barcodes', 'regression', 'recruitment', 'cellpose_compare', 'vision_scores',
+                         'vision_dataset', 'convert']:
 
         # Start the process
         process = Process(target=run_function_gui, args=process_args)
         process.start()
 
-        # Set high priority for the process
-        #set_high_priority(process)
-
         # Set CPU affinity if necessary
         set_cpu_affinity(process)
 
@@ -889,10 +892,14 @@ def initiate_root(parent, settings_type='mask'):
 
     global q, fig_queue, thread_control, parent_frame, scrollable_frame, button_frame, vars_dict, canvas, canvas_widget, button_scrollable_frame, progress_bar, uppdate_frequency, figures, figure_index, index_control, usage_bars
 
-    from .gui_utils import setup_frame
+    from .gui_utils import setup_frame, get_screen_dimensions
     from .settings import descriptions
+    #from .openai import Chatbot
 
     uppdate_frequency = 500
+    num_cores = os.cpu_count()
+
+    #chatbot = Chatbot(api_key="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")
 
     # Start tracemalloc and initialize global variables
     tracemalloc.start()
@@ -930,10 +937,14 @@ def initiate_root(parent, settings_type='mask'):
     else:
         scrollable_frame, vars_dict = setup_settings_panel(settings_container, settings_type)
         print('setup_settings_panel')
-    canvas, canvas_widget = setup_plot_section(vertical_container)
-    console_output, _ = setup_console(vertical_container)
+    canvas, canvas_widget = setup_plot_section(vertical_container, settings_type)
+    console_output, _ = setup_console(vertical_container) #, chatbot)
     button_scrollable_frame, btn_col = setup_button_section(horizontal_container, settings_type)
-    _, usage_bars, btn_col = setup_usage_panel(horizontal_container, btn_col, uppdate_frequency)
+
+    if num_cores > 12:
+        _, usage_bars, btn_col = setup_usage_panel(horizontal_container, btn_col, uppdate_frequency)
+    else:
+        usage_bars = []
 
     set_globals(thread_control, q, console_output, parent_frame, vars_dict, canvas, canvas_widget, scrollable_frame, fig_queue, figures, figure_index, index_control, progress_bar, usage_bars)
     description_text = descriptions.get(settings_type, "No description available for this module.")
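With these changes, initiate_root assembles the settings panel, the plot section (which now swaps in the dna_matrix.mp4 media for map_barcodes), the console, and, only on machines with more than 12 cores, the usage panel. A minimal launch sketch, assuming only that initiate_root takes a Tk parent and one of the settings_type values listed in start_process:

# Sketch: open a single spacr tool window directly (the 'convert' type is new in 0.3.2).
import tkinter as tk
from spacr.gui_core import initiate_root

root = tk.Tk()
initiate_root(root, settings_type='convert')
root.mainloop()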