spacr 0.0.82__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those package versions.
- spacr/__init__.py +4 -0
- spacr/annotate_app.py +4 -0
- spacr/annotate_app_v2.py +511 -0
- spacr/core.py +254 -172
- spacr/deep_spacr.py +137 -50
- spacr/graph_learning.py +28 -8
- spacr/gui.py +5 -5
- spacr/gui_2.py +106 -36
- spacr/gui_classify_app.py +3 -3
- spacr/gui_mask_app.py +34 -11
- spacr/gui_measure_app.py +32 -17
- spacr/gui_utils.py +96 -29
- spacr/io.py +227 -144
- spacr/measure.py +2 -1
- spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model +0 -0
- spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +23 -0
- spacr/plot.py +102 -6
- spacr/sequencing.py +140 -91
- spacr/settings.py +477 -0
- spacr/timelapse.py +0 -3
- spacr/utils.py +312 -275
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/METADATA +1 -1
- spacr-0.1.1.dist-info/RECORD +40 -0
- spacr-0.0.82.dist-info/RECORD +0 -36
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/LICENSE +0 -0
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/WHEEL +0 -0
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/top_level.txt +0 -0
spacr/deep_spacr.py
CHANGED
```diff
@@ -6,7 +6,6 @@ from torch.autograd import grad
 from torch.optim.lr_scheduler import StepLR
 import torch.nn.functional as F
 from IPython.display import display, clear_output
-
 import matplotlib.pyplot as plt
 from PIL import Image
 
@@ -200,11 +199,20 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
     from .io import _save_settings, _copy_missclassified
     from .utils import pick_best_model
     from .core import generate_loaders
-
+    from .settings import set_default_train_test_model
+
+    torch.cuda.empty_cache()
+    torch.cuda.memory.empty_cache()
+    gc.collect()
+
+    settings = set_default_train_test_model(settings)
+    channels_str = ''.join(settings['channels'])
+    dst = os.path.join(src,'model', settings['model_type'], channels_str, str(f"epochs_{settings['epochs']}"))
+    os.makedirs(dst, exist_ok=True)
     settings['src'] = src
+    settings['dst'] = dst
     settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
-    settings_csv = os.path.join(
-    os.makedirs(os.path.join(src,'settings'), exist_ok=True)
+    settings_csv = os.path.join(dst,'train_test_model_settings.csv')
     settings_df.to_csv(settings_csv, index=False)
 
     if custom_model:
@@ -212,15 +220,9 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
 
     if settings['train']:
         _save_settings(settings, src)
-
-        torch.cuda.memory.empty_cache()
-        gc.collect()
-        dst = os.path.join(src,'model')
-        os.makedirs(dst, exist_ok=True)
-        settings['src'] = src
-        settings['dst'] = dst
+
     if settings['train']:
-        train, val, plate_names = generate_loaders(src,
+        train, val, plate_names, train_fig = generate_loaders(src,
                                                    train_mode=settings['train_mode'],
                                                    mode='train',
                                                    image_size=settings['image_size'],
@@ -231,11 +233,42 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
                                                    pin_memory=settings['pin_memory'],
                                                    normalize=settings['normalize'],
                                                    channels=settings['channels'],
+                                                   augment=settings['augment'],
                                                    verbose=settings['verbose'])
-
-
+
+        train_batch_1_figure = os.path.join(dst, 'batch_1.pdf')
+        train_fig.savefig(train_batch_1_figure, format='pdf', dpi=600)
+
+    if settings['train']:
+        model = train_model(dst = settings['dst'],
+                            model_type=settings['model_type'],
+                            train_loaders = train,
+                            train_loader_names = plate_names,
+                            train_mode = settings['train_mode'],
+                            epochs = settings['epochs'],
+                            learning_rate = settings['learning_rate'],
+                            init_weights = settings['init_weights'],
+                            weight_decay = settings['weight_decay'],
+                            amsgrad = settings['amsgrad'],
+                            optimizer_type = settings['optimizer_type'],
+                            use_checkpoint = settings['use_checkpoint'],
+                            dropout_rate = settings['dropout_rate'],
+                            num_workers = settings['num_workers'],
+                            val_loaders = val,
+                            test_loaders = None,
+                            intermedeate_save = settings['intermedeate_save'],
+                            schedule = settings['schedule'],
+                            loss_type=settings['loss_type'],
+                            gradient_accumulation=settings['gradient_accumulation'],
+                            gradient_accumulation_steps=settings['gradient_accumulation_steps'],
+                            channels=settings['channels'])
+
+        torch.cuda.empty_cache()
+        torch.cuda.memory.empty_cache()
+        gc.collect()
+
     if settings['test']:
-        test, _, plate_names_test = generate_loaders(src,
+        test, _, plate_names_test, train_fig = generate_loaders(src,
                                                    train_mode=settings['train_mode'],
                                                    mode='test',
                                                    image_size=settings['image_size'],
@@ -246,6 +279,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
                                                    pin_memory=settings['pin_memory'],
                                                    normalize=settings['normalize'],
                                                    channels=settings['channels'],
+                                                   augment=False,
                                                    verbose=settings['verbose'])
         if model == None:
             model_path = pick_best_model(src+'/model')
@@ -258,10 +292,10 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
             print(type(model))
             print(model)
 
-        model_fldr =
+        model_fldr = dst
         time_now = datetime.date.today().strftime('%y%m%d')
-        result_loc = f
-        acc_loc = f
+        result_loc = f"{model_fldr}/{settings['model_type']}_time_{time_now}_test_result.csv"
+        acc_loc = f"{model_fldr}/{settings['model_type']}_time_{time_now}_test_acc.csv"
         print(f'Results wil be saved in: {result_loc}')
 
         result, accuracy = test_model_performance(loaders=test,
@@ -274,37 +308,12 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
         result.to_csv(result_loc, index=True, header=True, mode='w')
         accuracy.to_csv(acc_loc, index=True, header=True, mode='w')
         _copy_missclassified(accuracy)
-    else:
-        test = None
-
-    if settings['train']:
-        train_model(dst = settings['dst'],
-                    model_type=settings['model_type'],
-                    train_loaders = train,
-                    train_loader_names = plate_names,
-                    train_mode = settings['train_mode'],
-                    epochs = settings['epochs'],
-                    learning_rate = settings['learning_rate'],
-                    init_weights = settings['init_weights'],
-                    weight_decay = settings['weight_decay'],
-                    amsgrad = settings['amsgrad'],
-                    optimizer_type = settings['optimizer_type'],
-                    use_checkpoint = settings['use_checkpoint'],
-                    dropout_rate = settings['dropout_rate'],
-                    num_workers = settings['num_workers'],
-                    val_loaders = val,
-                    test_loaders = test,
-                    intermedeate_save = settings['intermedeate_save'],
-                    schedule = settings['schedule'],
-                    loss_type=settings['loss_type'],
-                    gradient_accumulation=settings['gradient_accumulation'],
-                    gradient_accumulation_steps=settings['gradient_accumulation_steps'])
 
     torch.cuda.empty_cache()
     torch.cuda.memory.empty_cache()
     gc.collect()
 
-def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, num_workers=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4):
+def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, num_workers=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b']):
     """
     Trains a model using the specified parameters.
 
@@ -349,7 +358,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
     kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
 
     for idx, (images, labels, filenames) in enumerate(train_loaders):
-        batch,
+        batch, chans, height, width = images.shape
         break
 
     model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
@@ -432,10 +441,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
             if schedule == 'step_lr':
                 scheduler.step()
 
-        _save_progress(dst, results_df, train_metrics_df)
+        _save_progress(dst, results_df, train_metrics_df, epoch, epochs)
         clear_output(wait=True)
         display(results_df)
-        _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
+        _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=channels)
 
     if train_mode == 'irm':
         dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
@@ -505,10 +514,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
 
             clear_output(wait=True)
             display(results_df)
-            _save_progress(dst, results_df, train_metrics_df)
+            _save_progress(dst, results_df, train_metrics_df, epoch, epochs)
             _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
             print(f'Saved model: {dst}')
-    return
+    return model
 
 def visualize_saliency_map(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
 
@@ -693,4 +702,82 @@ def visualize_integrated_gradients(src, model_path, target_label_idx=0, image_si
     if save_integrated_grads:
         os.makedirs(save_dir, exist_ok=True)
         integrated_grads_image = Image.fromarray((integrated_grads * 255).astype(np.uint8))
-        integrated_grads_image.save(os.path.join(save_dir, f'integrated_grads_{file}'))
+        integrated_grads_image.save(os.path.join(save_dir, f'integrated_grads_{file}'))
+
+class SmoothGrad:
+    def __init__(self, model, n_samples=50, stdev_spread=0.15):
+        self.model = model
+        self.n_samples = n_samples
+        self.stdev_spread = stdev_spread
+
+    def compute_smooth_grad(self, input_tensor, target_class):
+        self.model.eval()
+        stdev = self.stdev_spread * (input_tensor.max() - input_tensor.min())
+        total_gradients = torch.zeros_like(input_tensor)
+
+        for i in range(self.n_samples):
+            noise = torch.normal(mean=0, std=stdev, size=input_tensor.shape).to(input_tensor.device)
+            noisy_input = input_tensor + noise
+            noisy_input.requires_grad_()
+            output = self.model(noisy_input)
+            self.model.zero_grad()
+            output[0, target_class].backward()
+            total_gradients += noisy_input.grad
+
+        avg_gradients = total_gradients / self.n_samples
+        return avg_gradients.abs()
+
+def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, channels=[1,2,3], normalize=True, save_smooth_grad=False, save_dir='smooth_grad'):
+
+    from .utils import preprocess_image
+
+    use_cuda = torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    model = torch.load(model_path)
+    model.to(device)
+    smooth_grad = SmoothGrad(model)
+
+    if save_smooth_grad and not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    images = []
+    filenames = []
+    for file in os.listdir(src):
+        if not file.endswith('.png'):
+            continue
+        image_path = os.path.join(src, file)
+        image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
+        images.append(image)
+        filenames.append(file)
+
+        input_tensor = input_tensor.to(device)
+        smooth_grad_map = smooth_grad.compute_smooth_grad(input_tensor, target_label_idx)
+        smooth_grad_map = np.mean(smooth_grad_map.cpu().data.numpy(), axis=1).squeeze()
+
+        fig, ax = plt.subplots(1, 3, figsize=(20, 5))
+        ax[0].imshow(image)
+        ax[0].axis('off')
+        ax[0].set_title("Original Image")
+        ax[1].imshow(smooth_grad_map, cmap='hot')
+        ax[1].axis('off')
+        ax[1].set_title("SmoothGrad")
+        overlay = np.array(image)
+        overlay = overlay / overlay.max()
+        smooth_grad_map_rgb = np.stack([smooth_grad_map] * 3, axis=-1) # Convert smooth grad map to RGB
+        overlay = (overlay * 0.5 + smooth_grad_map_rgb * 0.5).clip(0, 1)
+        ax[2].imshow(overlay)
+        ax[2].axis('off')
+        ax[2].set_title("Overlay")
+        plt.show()
+
+        if save_smooth_grad:
+            os.makedirs(save_dir, exist_ok=True)
+            smooth_grad_image = Image.fromarray((smooth_grad_map * 255).astype(np.uint8))
+            smooth_grad_image.save(os.path.join(save_dir, f'smooth_grad_{file}'))
+
+# Usage
+#src = '/path/to/images'
+#model_path = '/path/to/model.pth'
+#target_label_idx = 0 # Change this to the target class index
+#visualize_smooth_grad(src, model_path, target_label_idx)
```
spacr/graph_learning.py
CHANGED
```diff
@@ -9,6 +9,7 @@ from PIL import Image
 import dgl.nn.pytorch as dglnn
 from sklearn.datasets import make_classification
 from .utils import SelectChannels
+from IPython.display import display
 
 # approach outline
 #
@@ -241,6 +242,31 @@ def analyze_associations(probabilities, sequencing_data):
     sequencing_data['positive_prob'] = probabilities
     return sequencing_data.groupby('gRNA').positive_prob.mean().sort_values(ascending=False)
 
+def process_sequencing_df(seq):
+
+    if isinstance(seq, pd.DataFrame):
+        sequencing_df = seq
+    elif isinstance(seq, str):
+        sequencing_df = pd.read_csv(seq)
+
+    # Check if 'plate_row' column exists and split into 'plate' and 'row'
+    if 'plate_row' in sequencing_df.columns:
+        sequencing_df[['plate', 'row']] = sequencing_df['plate_row'].str.split('_', expand=True)
+
+    # Check if 'plate', 'row' and 'col' or 'plate', 'row' and 'column' exist
+    if {'plate', 'row', 'col'}.issubset(sequencing_df.columns) or {'plate', 'row', 'column'}.issubset(sequencing_df.columns):
+        if 'col' in sequencing_df.columns:
+            sequencing_df['prc'] = sequencing_df[['plate', 'row', 'col']].agg('_'.join, axis=1)
+        elif 'column' in sequencing_df.columns:
+            sequencing_df['prc'] = sequencing_df[['plate', 'row', 'column']].agg('_'.join, axis=1)
+
+    # Check if 'count', 'total_reads', 'read_fraction', 'grna' exist and create new dataframe
+    if {'count', 'total_reads', 'read_fraction', 'grna'}.issubset(sequencing_df.columns):
+        new_df = sequencing_df[['grna', 'prc', 'count', 'total_reads', 'read_fraction']]
+        return new_df
+
+    return sequencing_df
+
 def train_graph_transformer(src, lr=0.01, epochs=100, hidden_feats=128, n_classes=2, row_limit=None, image_size=224, channels=[1,2,3], normalize=True, test_mode=False):
     if test_mode:
         # Load MNIST data
@@ -260,7 +286,6 @@ def train_graph_transformer(src, lr=0.01, epochs=100, hidden_feats=128, n_classe
 
         # Normalize synthetic sequencing data
        sequencing_data = normalize_sequencing_data(sequencing_data)
-
     else:
         from .io import _read_and_join_tables
         from .utils import get_db_paths, get_sequencing_paths, correct_paths
@@ -274,18 +299,13 @@ def train_graph_transformer(src, lr=0.01, epochs=100, hidden_feats=128, n_classe
         sequencing_data = pd.DataFrame()
         for seq in seq_paths:
             sequencing_df = pd.read_csv(seq)
+            sequencing_df = process_sequencing_df(sequencing_df)
             sequencing_data = pd.concat([sequencing_data, sequencing_df], axis=0)
 
-        all_df = pd.DataFrame()
-        for db_path in db_paths:
-            df = _read_and_join_tables(db_path, table_names=['png_list'])
-            all_df = pd.concat([all_df, df], axis=0)
-
-        tables = ['png_list']
         all_df = pd.DataFrame()
         image_paths = []
         for i, db_path in enumerate(db_paths):
-            df = _read_and_join_tables(db_path, table_names=
+            df = _read_and_join_tables(db_path, table_names=['png_list'])
             df, image_paths_tmp = correct_paths(df, src[i])
             all_df = pd.concat([all_df, df], axis=0)
             image_paths.extend(image_paths_tmp)
```
spacr/gui.py
CHANGED
```diff
@@ -59,10 +59,10 @@ class MainApp(tk.Tk):
 
         # Load the logo image
         if not self.load_logo(logo_frame):
-            tk.Label(logo_frame, text="Logo not found", bg="black", fg="white", font=('
+            tk.Label(logo_frame, text="Logo not found", bg="black", fg="white", font=('Helvetica', 24, tkFont.NORMAL)).pack(padx=10, pady=10)
 
         # Add SpaCr text below the logo with padding for sharper text
-        tk.Label(logo_frame, text="SpaCr", bg="black", fg="#008080", font=('
+        tk.Label(logo_frame, text="SpaCr", bg="black", fg="#008080", font=('Helvetica', 24, tkFont.NORMAL)).pack(padx=10, pady=10)
 
         # Create a frame for the buttons and descriptions
         buttons_frame = tk.Frame(self.content_frame, bg="black")
@@ -72,10 +72,10 @@ class MainApp(tk.Tk):
             app_func, app_desc = app_data
 
             # Create custom button with text
-            button = CustomButton(buttons_frame, text=app_name, command=lambda app_name=app_name: self.load_app(app_name))
+            button = CustomButton(buttons_frame, text=app_name, command=lambda app_name=app_name: self.load_app(app_name), font=('Helvetica', 12))
             button.grid(row=i, column=0, pady=10, padx=10, sticky="w")
 
-            description_label = tk.Label(buttons_frame, text=app_desc, bg="black", fg="white", wraplength=800, justify="left", font=('
+            description_label = tk.Label(buttons_frame, text=app_desc, bg="black", fg="white", wraplength=800, justify="left", font=('Helvetica', 10, tkFont.NORMAL))
             description_label.grid(row=i, column=1, pady=10, padx=10, sticky="w")
 
         # Ensure buttons have a fixed width
@@ -131,7 +131,7 @@ class MainApp(tk.Tk):
 
         app_frame = tk.Frame(self.content_frame, bg="black")
         app_frame.pack(fill=tk.BOTH, expand=True)
-        selected_app_func(app_frame
+        selected_app_func(app_frame)#, self.winfo_width(), self.winfo_height())
 
     def clear_frame(self, frame):
         for widget in frame.winfo_children():
```
spacr/gui_2.py
CHANGED
```diff
@@ -1,52 +1,93 @@
-import
+import tkinter as tk
+from tkinter import ttk
+from tkinter import font as tkFont
 from PIL import Image, ImageTk
 import os
 import requests
 
-
+# Import your GUI apps
+from .gui_mask_app import initiate_mask_root
+from .gui_measure_app import initiate_measure_root
+from .annotate_app import initiate_annotation_app_root
+from .mask_app import initiate_mask_app_root
+from .gui_classify_app import initiate_classify_root
+
+from .gui_utils import CustomButton, style_text_boxes
+
+class MainApp(tk.Tk):
     def __init__(self):
         super().__init__()
         self.title("SpaCr GUI Collection")
-        self.geometry("
-
-
-
-
-
-
+        self.geometry("1100x1500")
+        self.configure(bg="black")
+        #self.attributes('-fullscreen', True)
+
+        style = ttk.Style()
+        style_text_boxes(style)
+
+        self.gui_apps = {
+            "Mask": (initiate_mask_root, "Generate cellpose masks for cells, nuclei and pathogen images."),
+            "Measure": (initiate_measure_root, "Measure single object intensity and morphological feature. Crop and save single object image"),
+            "Annotate": (initiate_annotation_app_root, "Annotation single object images on a grid. Annotations are saved to database."),
+            "Make Masks": (initiate_mask_app_root, "Adjust pre-existing Cellpose models to your specific dataset for improved performance"),
+            "Classify": (initiate_classify_root, "Train Torch Convolutional Neural Networks (CNNs) or Transformers to classify single object images.")
+        }
+
+        self.selected_app = tk.StringVar()
         self.create_widgets()
 
     def create_widgets(self):
-
-        self
+        # Create the menu bar
+        #create_menu_bar(self)
+        # Create a canvas to hold the selected app and other elements
+        self.canvas = tk.Canvas(self, bg="black", highlightthickness=0, width=4000, height=4000)
+        self.canvas.grid(row=0, column=0, sticky="nsew")
+        self.grid_rowconfigure(0, weight=1)
+        self.grid_columnconfigure(0, weight=1)
+        # Create a frame inside the canvas to hold the main content
+        self.content_frame = tk.Frame(self.canvas, bg="black")
+        self.content_frame.pack(fill=tk.BOTH, expand=True)
+        # Create startup screen with buttons for each GUI app
+        self.create_startup_screen()
 
-
+    def create_startup_screen(self):
+        self.clear_frame(self.content_frame)
+
+        # Create a frame for the logo and description
+        logo_frame = tk.Frame(self.content_frame, bg="black")
         logo_frame.pack(pady=20, expand=True)
 
+        # Load the logo image
         if not self.load_logo(logo_frame):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            tk.Label(logo_frame, text="Logo not found", bg="black", fg="white", font=('Helvetica', 24)).pack(padx=10, pady=10)
+
+        # Add SpaCr text below the logo with padding for sharper text
+        tk.Label(logo_frame, text="SpaCr", bg="black", fg="#008080", font=('Helvetica', 24)).pack(padx=10, pady=10)
+
+        # Create a frame for the buttons and descriptions
+        buttons_frame = tk.Frame(self.content_frame, bg="black")
+        buttons_frame.pack(pady=10, expand=True, padx=10)
+
+        for i, (app_name, app_data) in enumerate(self.gui_apps.items()):
+            app_func, app_desc = app_data
+
+            # Create custom button with text
+            button = CustomButton(buttons_frame, text=app_name, command=lambda app_name=app_name: self.load_app(app_name, app_func), font=('Helvetica', 12))
+            button.grid(row=i, column=0, pady=10, padx=10, sticky="w")
+
+            description_label = tk.Label(buttons_frame, text=app_desc, bg="black", fg="white", wraplength=800, justify="left", font=('Helvetica', 12))
+            description_label.grid(row=i, column=1, pady=10, padx=10, sticky="w")
+
+        # Ensure buttons have a fixed width
+        buttons_frame.grid_columnconfigure(0, minsize=150)
+        # Ensure descriptions expand as needed
+        buttons_frame.grid_columnconfigure(1, weight=1)
 
     def load_logo(self, frame):
         def download_image(url, save_path):
             try:
                 response = requests.get(url, stream=True)
-                response.raise_for_status()
+                response.raise_for_status() # Raise an HTTPError for bad responses
                 with open(save_path, 'wb') as f:
                     for chunk in response.iter_content(chunk_size=8192):
                         f.write(chunk)
@@ -57,34 +98,63 @@ class MainApp(ctk.CTk):
 
         try:
             img_path = os.path.join(os.path.dirname(__file__), 'logo_spacr.png')
+            print(f"Trying to load logo from {img_path}")
             logo_image = Image.open(img_path)
         except (FileNotFoundError, Image.UnidentifiedImageError):
+            print(f"File {img_path} not found or is not a valid image. Attempting to download from GitHub.")
             if download_image('https://raw.githubusercontent.com/EinarOlafsson/spacr/main/spacr/logo_spacr.png', img_path):
                 try:
+                    print(f"Downloaded file size: {os.path.getsize(img_path)} bytes")
                     logo_image = Image.open(img_path)
                 except Image.UnidentifiedImageError as e:
+                    print(f"Downloaded file is not a valid image: {e}")
                     return False
             else:
                 return False
         except Exception as e:
+            print(f"An error occurred while loading the logo: {e}")
             return False
-
         try:
-            logo_image = logo_image.resize((
+            logo_image = logo_image.resize((800, 800), Image.Resampling.LANCZOS)
             logo_photo = ImageTk.PhotoImage(logo_image)
-            logo_label =
+            logo_label = tk.Label(frame, image=logo_photo, bg="black")
             logo_label.image = logo_photo # Keep a reference to avoid garbage collection
             logo_label.pack()
             return True
         except Exception as e:
+            print(f"An error occurred while processing the logo image: {e}")
             return False
 
-    def
-
+    def load_app_v1(self, app_name):
+        selected_app_func, _ = self.gui_apps[app_name]
+        self.clear_frame(self.content_frame)
+
+        app_frame = tk.Frame(self.content_frame, bg="black")
+        app_frame.pack(fill=tk.BOTH, expand=True)
+        selected_app_func(app_frame)
+
+    def load_app(root, app_name, app_func):
+        if hasattr(root, 'current_app_id'):
+            root.after_cancel(root.current_app_id)
+            root.current_app_id = None
+
+        # Clear the current content frame
+        for widget in root.content_frame.winfo_children():
+            widget.destroy()
+
+        # Initialize the selected app
+        app_frame = tk.Frame(root.content_frame, bg="black")
+        app_frame.pack(fill=tk.BOTH, expand=True)
+        app_func(app_frame)
+
+    def clear_frame(self, frame):
+        for widget in frame.winfo_children():
+            widget.destroy()
+
 
 def gui_app():
     app = MainApp()
     app.mainloop()
 
 if __name__ == "__main__":
-    gui_app()
+    gui_app()
```
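gui_2.py rebuilds the launcher as a plain `tkinter` app: a startup screen of buttons, each bound to its app via a lambda default argument, and a `load_app` that clears the content frame before handing it to the selected app. A self-contained sketch of that pattern with stand-in apps (none of the functions below are spacr's own):

```python
import tkinter as tk

# Stand-in "apps": each takes the frame it should draw into (spacr passes the
# content frame to its initiate_*_root functions in the same way).
apps = {
    "Mask":    (lambda frame: tk.Label(frame, text="mask app goes here", fg="white", bg="black").pack(), "Generate masks."),
    "Measure": (lambda frame: tk.Label(frame, text="measure app goes here", fg="white", bg="black").pack(), "Measure objects."),
}

root = tk.Tk()
content = tk.Frame(root, bg="black")
content.pack(fill=tk.BOTH, expand=True)

def load_app(name, func):
    for widget in content.winfo_children():   # clear the startup screen
        widget.destroy()
    app_frame = tk.Frame(content, bg="black")
    app_frame.pack(fill=tk.BOTH, expand=True)
    func(app_frame)                            # hand the fresh frame to the selected app

buttons = tk.Frame(content, bg="black")
buttons.pack(pady=10, padx=10)
for i, (name, (func, desc)) in enumerate(apps.items()):
    # Lambda default arguments freeze the current name/func for each button.
    tk.Button(buttons, text=name,
              command=lambda n=name, f=func: load_app(n, f)).grid(row=i, column=0, padx=10, pady=10, sticky="w")
    tk.Label(buttons, text=desc, bg="black", fg="white").grid(row=i, column=1, sticky="w")

root.mainloop()
```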
spacr/gui_classify_app.py
CHANGED
```diff
@@ -69,13 +69,13 @@ def import_settings(scrollable_frame):
     vars_dict = generate_fields(new_settings, scrollable_frame)
 
 #@log_function_call
-def initiate_classify_root(parent_frame)
+def initiate_classify_root(parent_frame):
     global vars_dict, q, canvas, fig_queue, canvas_widget, thread_control
 
     style = ttk.Style(parent_frame)
     set_dark_style(style)
     style_text_boxes(style)
-    set_default_font(parent_frame, font_name="
+    set_default_font(parent_frame, font_name="Helvetica", size=8)
 
     parent_frame.configure(bg='#333333')
     parent_frame.grid_rowconfigure(0, weight=1)
@@ -179,7 +179,7 @@ def gui_classify():
     root = tk.Tk()
     root.geometry("1000x800")
     root.title("SpaCer: generate masks")
-    initiate_classify_root(root,
+    initiate_classify_root(root),
     create_menu_bar(root)
     root.mainloop()
 
```