spacr 0.2.46__py3-none-any.whl → 0.2.56__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. spacr/core.py +306 -21
  2. spacr/deep_spacr.py +101 -41
  3. spacr/gui.py +1 -3
  4. spacr/gui_core.py +78 -65
  5. spacr/gui_elements.py +437 -152
  6. spacr/gui_utils.py +84 -73
  7. spacr/io.py +14 -7
  8. spacr/measure.py +196 -145
  9. spacr/plot.py +2 -42
  10. spacr/resources/font/open_sans/OFL.txt +93 -0
  11. spacr/resources/font/open_sans/OpenSans-Italic-VariableFont_wdth,wght.ttf +0 -0
  12. spacr/resources/font/open_sans/OpenSans-VariableFont_wdth,wght.ttf +0 -0
  13. spacr/resources/font/open_sans/README.txt +100 -0
  14. spacr/resources/font/open_sans/static/OpenSans-Bold.ttf +0 -0
  15. spacr/resources/font/open_sans/static/OpenSans-BoldItalic.ttf +0 -0
  16. spacr/resources/font/open_sans/static/OpenSans-ExtraBold.ttf +0 -0
  17. spacr/resources/font/open_sans/static/OpenSans-ExtraBoldItalic.ttf +0 -0
  18. spacr/resources/font/open_sans/static/OpenSans-Italic.ttf +0 -0
  19. spacr/resources/font/open_sans/static/OpenSans-Light.ttf +0 -0
  20. spacr/resources/font/open_sans/static/OpenSans-LightItalic.ttf +0 -0
  21. spacr/resources/font/open_sans/static/OpenSans-Medium.ttf +0 -0
  22. spacr/resources/font/open_sans/static/OpenSans-MediumItalic.ttf +0 -0
  23. spacr/resources/font/open_sans/static/OpenSans-Regular.ttf +0 -0
  24. spacr/resources/font/open_sans/static/OpenSans-SemiBold.ttf +0 -0
  25. spacr/resources/font/open_sans/static/OpenSans-SemiBoldItalic.ttf +0 -0
  26. spacr/resources/font/open_sans/static/OpenSans_Condensed-Bold.ttf +0 -0
  27. spacr/resources/font/open_sans/static/OpenSans_Condensed-BoldItalic.ttf +0 -0
  28. spacr/resources/font/open_sans/static/OpenSans_Condensed-ExtraBold.ttf +0 -0
  29. spacr/resources/font/open_sans/static/OpenSans_Condensed-ExtraBoldItalic.ttf +0 -0
  30. spacr/resources/font/open_sans/static/OpenSans_Condensed-Italic.ttf +0 -0
  31. spacr/resources/font/open_sans/static/OpenSans_Condensed-Light.ttf +0 -0
  32. spacr/resources/font/open_sans/static/OpenSans_Condensed-LightItalic.ttf +0 -0
  33. spacr/resources/font/open_sans/static/OpenSans_Condensed-Medium.ttf +0 -0
  34. spacr/resources/font/open_sans/static/OpenSans_Condensed-MediumItalic.ttf +0 -0
  35. spacr/resources/font/open_sans/static/OpenSans_Condensed-Regular.ttf +0 -0
  36. spacr/resources/font/open_sans/static/OpenSans_Condensed-SemiBold.ttf +0 -0
  37. spacr/resources/font/open_sans/static/OpenSans_Condensed-SemiBoldItalic.ttf +0 -0
  38. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Bold.ttf +0 -0
  39. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-BoldItalic.ttf +0 -0
  40. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-ExtraBold.ttf +0 -0
  41. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-ExtraBoldItalic.ttf +0 -0
  42. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Italic.ttf +0 -0
  43. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Light.ttf +0 -0
  44. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-LightItalic.ttf +0 -0
  45. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Medium.ttf +0 -0
  46. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-MediumItalic.ttf +0 -0
  47. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Regular.ttf +0 -0
  48. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-SemiBold.ttf +0 -0
  49. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-SemiBoldItalic.ttf +0 -0
  50. spacr/sequencing.py +481 -587
  51. spacr/settings.py +197 -122
  52. spacr/utils.py +21 -13
  53. {spacr-0.2.46.dist-info → spacr-0.2.56.dist-info}/METADATA +7 -4
  54. spacr-0.2.56.dist-info/RECORD +100 -0
  55. spacr-0.2.46.dist-info/RECORD +0 -60
  56. {spacr-0.2.46.dist-info → spacr-0.2.56.dist-info}/LICENSE +0 -0
  57. {spacr-0.2.46.dist-info → spacr-0.2.56.dist-info}/WHEEL +0 -0
  58. {spacr-0.2.46.dist-info → spacr-0.2.56.dist-info}/entry_points.txt +0 -0
  59. {spacr-0.2.46.dist-info → spacr-0.2.56.dist-info}/top_level.txt +0 -0
spacr/gui_utils.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import os, io, sys, ast, ctypes, ast, sqlite3, requests, time, traceback
2
- import traceback
3
2
  import tkinter as tk
4
3
  from tkinter import ttk
5
4
  import matplotlib
@@ -7,8 +6,7 @@ import matplotlib.pyplot as plt
7
6
  matplotlib.use('Agg')
8
7
  from huggingface_hub import list_repo_files
9
8
 
10
- from . gui_core import initiate_root
11
- from .gui_elements import AnnotateApp, spacrEntry, spacrCheck, spacrCombo, set_default_font
9
+ from .gui_elements import AnnotateApp, spacrEntry, spacrCheck, spacrCombo
12
10
 
13
11
  try:
14
12
  ctypes.windll.shcore.SetProcessDpiAwareness(True)
@@ -28,6 +26,10 @@ def proceed_with_app(root, app_name, app_func):
28
26
  app_func(root.content_frame)
29
27
 
30
28
  def load_app(root, app_name, app_func):
29
+ # Clear the canvas if it exists
30
+ if root.canvas is not None:
31
+ root.clear_frame(root.canvas)
32
+
31
33
  # Cancel all scheduled after tasks
32
34
  if hasattr(root, 'after_tasks'):
33
35
  for task in root.after_tasks:
@@ -37,70 +39,90 @@ def load_app(root, app_name, app_func):
37
39
  # Exit functionality only for the annotation and make_masks apps
38
40
  if app_name not in ["Annotate", "make_masks"] and hasattr(root, 'current_app_exit_func'):
39
41
  root.next_app_func = proceed_with_app
40
- root.next_app_args = (app_name, app_func) # Ensure correct arguments
42
+ root.next_app_args = (app_name, app_func)
41
43
  root.current_app_exit_func()
42
44
  else:
43
45
  proceed_with_app(root, app_name, app_func)
44
-
46
+
45
47
  def parse_list_v1(value):
46
48
  try:
47
49
  parsed_value = ast.literal_eval(value)
48
50
  if isinstance(parsed_value, list):
49
51
  return parsed_value
50
52
  else:
51
- raise ValueError
52
- except (ValueError, SyntaxError):
53
- raise ValueError("Invalid format for list")
53
+ raise ValueError(f"Expected a list but got {type(parsed_value).__name__}")
54
+ except (ValueError, SyntaxError) as e:
55
+ raise ValueError(f"Invalid format for list: {value}. Error: {e}")
54
56
 
55
57
  def parse_list(value):
56
58
  try:
57
59
  parsed_value = ast.literal_eval(value)
58
60
  if isinstance(parsed_value, list):
59
- return parsed_value
61
+ # Check if the list elements are homogeneous (all int or all str)
62
+ if all(isinstance(item, int) for item in parsed_value):
63
+ return parsed_value
64
+ elif all(isinstance(item, str) for item in parsed_value):
65
+ return parsed_value
66
+ else:
67
+ raise ValueError("List contains mixed types or unsupported types")
60
68
  else:
61
69
  raise ValueError(f"Expected a list but got {type(parsed_value).__name__}")
62
70
  except (ValueError, SyntaxError) as e:
63
71
  raise ValueError(f"Invalid format for list: {value}. Error: {e}")
64
-
72
+
65
73
  # Usage example in your create_input_field function
66
74
  def create_input_field(frame, label_text, row, var_type='entry', options=None, default_value=None):
67
- from .gui_elements import set_dark_style
75
+ from .gui_elements import set_dark_style, set_element_size
68
76
  label_column = 0
69
- widget_column = 1
77
+ widget_column = 0 # Both label and widget will be in the same column
70
78
 
71
79
  style_out = set_dark_style(ttk.Style())
80
+ font_loader = style_out['font_loader']
81
+ font_size = style_out['font_size']
82
+ size_dict = set_element_size()
83
+ size_dict['settings_width'] = size_dict['settings_width'] - int(size_dict['settings_width']*0.1)
72
84
 
73
85
  # Replace underscores with spaces and capitalize the first letter
74
86
  label_text = label_text.replace('_', ' ').capitalize()
75
87
 
76
88
  # Configure the column widths
77
- frame.grid_columnconfigure(label_column, weight=0) # Allow the label column to expand
78
- frame.grid_columnconfigure(widget_column, weight=1) # Allow the widget column to expand
89
+ frame.grid_columnconfigure(label_column, weight=1) # Allow the column to expand
90
+
91
+ # Create a custom frame with a translucent background and rounded edges
92
+ custom_frame = tk.Frame(frame, bg=style_out['bg_color'], bd=2, relief='solid', width=size_dict['settings_width'])
93
+ custom_frame.grid(column=label_column, row=row, sticky=tk.EW, padx=(5, 5), pady=5)
94
+
95
+ # Apply styles to custom frame
96
+ custom_frame.update_idletasks()
97
+ custom_frame.config(highlightbackground=style_out['bg_color'], highlightthickness=1, bd=2)
79
98
 
80
99
  # Create and configure the label
81
- label = ttk.Label(frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=(style_out['font_family'], style_out['font_size']), anchor='e', justify='right')
82
- label.grid(column=label_column, row=row, sticky=tk.E, padx=(5, 2), pady=5)
100
+ if font_loader:
101
+ label = ttk.Label(custom_frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=font_loader.get_font(size=font_size), anchor='e', justify='right')
102
+ label = ttk.Label(custom_frame, text=label_text, background=style_out['bg_color'], foreground=style_out['fg_color'], font=(style_out['font_family'], style_out['font_size']), anchor='e', justify='right')
103
+ label.grid(column=label_column, row=0, sticky=tk.W, padx=(5, 2), pady=5) # Place the label in the first row
83
104
 
105
+ # Create and configure the input widget based on var_type
84
106
  if var_type == 'entry':
85
107
  var = tk.StringVar(value=default_value)
86
- entry = spacrEntry(frame, textvariable=var, outline=False)
87
- entry.grid(column=widget_column, row=row, sticky=tk.W, padx=(2, 5), pady=5) # Align widget to the left
88
- return (label, entry, var) # Return both the label and the entry, and the variable
108
+ entry = spacrEntry(custom_frame, textvariable=var, outline=False, width=size_dict['settings_width'])
109
+ entry.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the entry in the second row
110
+ return (label, entry, var, custom_frame) # Return both the label and the entry, and the variable
89
111
  elif var_type == 'check':
90
112
  var = tk.BooleanVar(value=default_value) # Set default value (True/False)
91
- check = spacrCheck(frame, text="", variable=var)
92
- check.grid(column=widget_column, row=row, sticky=tk.W, padx=(2, 5), pady=5) # Align widget to the left
93
- return (label, check, var) # Return both the label and the checkbutton, and the variable
113
+ check = spacrCheck(custom_frame, text="", variable=var)
114
+ check.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the checkbutton in the second row
115
+ return (label, check, var, custom_frame) # Return both the label and the checkbutton, and the variable
94
116
  elif var_type == 'combo':
95
117
  var = tk.StringVar(value=default_value) # Set default value
96
- combo = spacrCombo(frame, textvariable=var, values=options) # Apply TCombobox style
97
- combo.grid(column=widget_column, row=row, sticky=tk.W, padx=(2, 5), pady=5) # Align widget to the left
118
+ combo = spacrCombo(custom_frame, textvariable=var, values=options, width=size_dict['settings_width']) # Apply TCombobox style
119
+ combo.grid(column=widget_column, row=1, sticky=tk.W, padx=(2, 5), pady=5) # Place the combobox in the second row
98
120
  if default_value:
99
121
  combo.set(default_value)
100
- return (label, combo, var) # Return both the label and the combobox, and the variable
122
+ return (label, combo, var, custom_frame) # Return both the label and the combobox, and the variable
101
123
  else:
102
124
  var = None # Placeholder in case of an undefined var_type
103
- return (label, None, var)
125
+ return (label, None, var, custom_frame)
104
126
 
105
127
  def process_stdout_stderr(q):
106
128
  """
@@ -116,10 +138,9 @@ class WriteToQueue(io.TextIOBase):
116
138
  """
117
139
  def __init__(self, q):
118
140
  self.q = q
119
-
120
141
  def write(self, msg):
121
- self.q.put(msg)
122
-
142
+ if msg.strip(): # Avoid empty messages
143
+ self.q.put(msg)
123
144
  def flush(self):
124
145
  pass
125
146
 
@@ -309,23 +330,6 @@ def annotate_with_image_refs(settings, root, shutdown_callback):
309
330
  # Call load_images after setting up the root window
310
331
  app.load_images()
311
332
 
312
- def set_element_size(widget):
313
- screen_width = widget.winfo_screenwidth()
314
- screen_height = widget.winfo_screenheight()
315
- btn_size = screen_width // 40
316
- bar_size = screen_width // 50
317
- settings_width = screen_width // 5
318
- panel_height = screen_height // 12
319
- panel_width = settings_width
320
- size_dict = {
321
- 'btn_size': btn_size,
322
- 'bar_size': bar_size,
323
- 'settings_width': settings_width,
324
- 'panel_width': panel_width,
325
- 'panel_height': panel_height
326
- }
327
- return size_dict
328
-
329
333
  def convert_settings_dict_for_gui(settings):
330
334
  from torchvision import models as torch_models
331
335
  torchvision_models = [name for name, obj in torch_models.__dict__.items() if callable(obj)]
@@ -335,7 +339,9 @@ def convert_settings_dict_for_gui(settings):
335
339
  special_cases = {
336
340
  'metadata_type': ('combo', ['cellvoyager', 'cq1', 'nikon', 'zeis', 'custom'], 'cellvoyager'),
337
341
  'channels': ('combo', ['[0,1,2,3]', '[0,1,2]', '[0,1]', '[0]'], '[0,1,2,3]'),
342
+ 'train_channels': ('combo', ["['r','g','b']", "['r','g']", "['r','b']", "['g','b']", "['r']", "['g']", "['b']"], "['r','g','b']"),
338
343
  'channel_dims': ('combo', ['[0,1,2,3]', '[0,1,2]', '[0,1]', '[0]'], '[0,1,2,3]'),
344
+ 'dataset_mode': ('combo', ['annotation', 'metadata', 'recruitment'], 'annotate'),
339
345
  'cell_mask_dim': ('combo', chans, None),
340
346
  'cell_chann_dim': ('combo', chans, None),
341
347
  'nucleus_mask_dim': ('combo', chans, None),
@@ -381,8 +387,10 @@ def convert_settings_dict_for_gui(settings):
381
387
  variables[key] = ('entry', None, str(value))
382
388
  else:
383
389
  variables[key] = ('entry', None, str(value))
390
+
384
391
  return variables
385
392
 
393
+
386
394
  def spacrFigShow(fig_queue=None):
387
395
  """
388
396
  Replacement for plt.show() that queues figures instead of displaying them.
@@ -424,13 +432,14 @@ def function_gui_wrapper(function=None, settings={}, q=None, fig_queue=None, imp
424
432
  plt.show = original_show
425
433
 
426
434
  def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
435
+
427
436
  from .gui_utils import process_stdout_stderr
428
- from .core import preprocess_generate_masks, generate_ml_scores, identify_masks_finetune, check_cellpose_models, analyze_recruitment, train_cellpose, compare_cellpose_masks, analyze_plaques, generate_dataset, apply_model_to_tar
437
+ from .core import generate_image_umap, preprocess_generate_masks, generate_ml_scores, identify_masks_finetune, check_cellpose_models, analyze_recruitment, train_cellpose, compare_cellpose_masks, analyze_plaques, generate_dataset, apply_model_to_tar
429
438
  from .io import generate_cellpose_train_test
430
439
  from .measure import measure_crop
431
440
  from .sim import run_multiple_simulations
432
- from .deep_spacr import train_test_model
433
- from .sequencing import analyze_reads, map_barcodes_folder, perform_regression
441
+ from .deep_spacr import deep_spacr
442
+ from .sequencing import generate_barecode_mapping, perform_regression
434
443
  process_stdout_stderr(q)
435
444
 
436
445
  print(f'run_function_gui settings_type: {settings_type}')
@@ -444,12 +453,9 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
444
453
  elif settings_type == 'simulation':
445
454
  function = run_multiple_simulations
446
455
  imports = 1
447
- elif settings_type == 'sequencing':
448
- function = analyze_reads
449
- imports = 1
450
456
  elif settings_type == 'classify':
451
- function = train_test_model
452
- imports = 2
457
+ function = deep_spacr
458
+ imports = 1
453
459
  elif settings_type == 'train_cellpose':
454
460
  function = train_cellpose
455
461
  imports = 1
@@ -463,14 +469,17 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
463
469
  function = check_cellpose_models
464
470
  imports = 1
465
471
  elif settings_type == 'map_barcodes':
466
- function = map_barcodes_folder
467
- imports = 2
472
+ function = generate_barecode_mapping
473
+ imports = 1
468
474
  elif settings_type == 'regression':
469
475
  function = perform_regression
470
476
  imports = 2
471
477
  elif settings_type == 'recruitment':
472
478
  function = analyze_recruitment
473
479
  imports = 2
480
+ elif settings_type == 'umap':
481
+ function = generate_image_umap
482
+ imports = 1
474
483
  else:
475
484
  raise ValueError(f"Invalid settings type: {settings_type}")
476
485
  try:
@@ -481,6 +490,7 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
481
490
  finally:
482
491
  stop_requested.value = 1
483
492
 
493
+
484
494
  def hide_all_settings(vars_dict, categories):
485
495
  """
486
496
  Function to initially hide all settings in the GUI.
@@ -495,26 +505,27 @@ def hide_all_settings(vars_dict, categories):
495
505
 
496
506
  for category, settings in categories.items():
497
507
  if any(setting in vars_dict for setting in settings):
498
- vars_dict[category] = (None, None, tk.IntVar(value=0))
508
+ vars_dict[category] = (None, None, tk.IntVar(value=0), None)
499
509
 
500
510
  # Initially hide all settings
501
511
  for setting in settings:
502
512
  if setting in vars_dict:
503
- label, widget, _ = vars_dict[setting]
513
+ label, widget, _, frame = vars_dict[setting]
504
514
  label.grid_remove()
505
515
  widget.grid_remove()
516
+ frame.grid_remove()
506
517
  return vars_dict
507
518
 
508
519
  def setup_frame(parent_frame):
509
- from .gui_elements import set_dark_style, set_default_font
520
+ from .gui_elements import set_dark_style, set_element_size
510
521
 
511
522
  style = ttk.Style(parent_frame)
512
- size_dict = set_element_size(parent_frame)
523
+ size_dict = set_element_size()
513
524
  style_out = set_dark_style(style)
514
525
 
515
526
  settings_container = tk.PanedWindow(parent_frame, orient=tk.VERTICAL, width=size_dict['settings_width'], bg=style_out['bg_color'])
516
- vertical_container = tk.PanedWindow(parent_frame, orient=tk.VERTICAL, bg=style_out['bg_color'])
517
- horizontal_container = tk.PanedWindow(parent_frame, orient=tk.HORIZONTAL, height=size_dict['panel_height'], bg=style_out['bg_color'])
527
+ vertical_container = tk.PanedWindow(parent_frame, orient=tk.VERTICAL, width=size_dict['panel_width'], bg=style_out['bg_color'])
528
+ horizontal_container = tk.PanedWindow(parent_frame, orient=tk.HORIZONTAL, height=size_dict['panel_height'], width=size_dict['panel_width'], bg=style_out['bg_color'])
518
529
 
519
530
  parent_frame.grid_rowconfigure(0, weight=1)
520
531
  parent_frame.grid_rowconfigure(1, weight=0)
@@ -523,22 +534,18 @@ def setup_frame(parent_frame):
523
534
 
524
535
  settings_container.grid(row=0, column=0, rowspan=2, sticky="nsew")
525
536
  vertical_container.grid(row=0, column=1, sticky="nsew")
526
- horizontal_container.grid(row=1, column=1, sticky="nsew")
537
+ horizontal_container.grid(row=1, column=1, sticky="ew")
527
538
 
528
- # Lock the width of the horizontal_container
529
- horizontal_container.update_idletasks() # Ensure geometry manager calculates size
530
- fixed_width = horizontal_container.winfo_width()
531
- parent_frame.grid_columnconfigure(1, weight=0)
532
- horizontal_container.config(width=fixed_width)
539
+ # Ensure settings_container maintains its width
540
+ settings_container.grid_propagate(False)
541
+ settings_container.update_idletasks()
533
542
 
534
- tk.Label(settings_container, text="Settings Container", bg=style_out['bg_color']).pack(fill=tk.BOTH, expand=True)
535
- tk.Label(vertical_container, text="Vertical Container", bg=style_out['bg_color']).pack(fill=tk.BOTH, expand=True)
543
+ tk.Label(settings_container, text="Settings Container", bg=style_out['bg_color']).grid(row=0, column=0, sticky="ew")
536
544
 
537
545
  set_dark_style(style, parent_frame, [settings_container, vertical_container, horizontal_container])
538
546
 
539
- size = style_out['font_size'] - 2
540
-
541
- set_default_font(parent_frame, font_name=style_out['font_family'], size=size)
547
+ #size = style_out['font_size'] - 2
548
+ #set_default_font(parent_frame, font_name=style_out['font_family'], size=size)
542
549
 
543
550
  return parent_frame, vertical_container, horizontal_container, settings_container
544
551
 
@@ -622,3 +629,7 @@ def download_dataset(q, repo_id, subfolder, local_dir=None, retries=5, delay=5):
622
629
  time.sleep(delay)
623
630
 
624
631
  raise Exception("Failed to download files after multiple attempts.")
632
+
633
+ def ensure_after_tasks(frame):
634
+ if not hasattr(frame, 'after_tasks'):
635
+ frame.after_tasks = []
spacr/io.py CHANGED
@@ -597,7 +597,6 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
597
597
  for idx in range(0, len(all_filenames), batch_size):
598
598
  start = time.time()
599
599
  batch_filenames = all_filenames[idx:idx+batch_size]
600
- files_processed = 0
601
600
  for filename in batch_filenames:
602
601
  images_by_key = _extract_filename_metadata(batch_filenames, src, images_by_key, regular_expression, metadata_type, pick_slice, skip_mode)
603
602
 
@@ -974,7 +973,7 @@ def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_s
974
973
  time_ls.append(duration)
975
974
  files_processed = i+1
976
975
  files_to_process = time_stack_path_lists
977
- print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Concatinating")
976
+ print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type="Concatinating")
978
977
  stack = np.stack(stack_region)
979
978
  save_loc = os.path.join(channel_stack_loc, f'{name}.npz')
980
979
  np.savez(save_loc, data=stack, filenames=filenames_region)
@@ -1005,7 +1004,7 @@ def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_s
1005
1004
  time_ls.append(duration)
1006
1005
  files_processed = i+1
1007
1006
  files_to_process = nr_files
1008
- print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Concatinating")
1007
+ print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type="Concatinating")
1009
1008
  if (i+1) % batch_size == 0 or i+1 == nr_files:
1010
1009
  unique_shapes = {arr.shape[:-1] for arr in stack_ls}
1011
1010
  if len(unique_shapes) > 1:
@@ -2294,15 +2293,23 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
2294
2293
  def save_model_at_threshold(threshold, epoch, suffix=""):
2295
2294
  percentile = str(threshold * 100)
2296
2295
  print(f'\rfound: {percentile}% accurate model')#, end='\r', flush=True)
2297
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth')
2296
+ model_path = f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth'
2297
+ torch.save(model, model_path)
2298
+ return model_path
2298
2299
 
2299
2300
  if epoch % 100 == 0 or epoch == epochs:
2300
- torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_channels_{channels_str}.pth')
2301
+ model_path = f'{dst}/{model_type}_epoch_{str(epoch)}_channels_{channels_str}.pth'
2302
+ torch.save(model, model_path)
2303
+ return model_path
2301
2304
 
2302
2305
  for threshold in intermedeate_save:
2303
2306
  if results_df['neg_accuracy'].dropna().mean() >= threshold and results_df['pos_accuracy'].dropna().mean() >= threshold:
2304
- save_model_at_threshold(threshold, epoch)
2305
- break # Ensure we only save for the highest matching threshold
2307
+ model_path = save_model_at_threshold(threshold, epoch)
2308
+ break
2309
+ else:
2310
+ model_path = None
2311
+
2312
+ return model_path
2306
2313
 
2307
2314
  def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
2308
2315
  """