spacr 0.0.81__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
spacr/deep_spacr.py CHANGED
@@ -6,7 +6,6 @@ from torch.autograd import grad
 from torch.optim.lr_scheduler import StepLR
 import torch.nn.functional as F
 from IPython.display import display, clear_output
-
 import matplotlib.pyplot as plt
 from PIL import Image
 
@@ -200,11 +199,20 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
     from .io import _save_settings, _copy_missclassified
     from .utils import pick_best_model
     from .core import generate_loaders
-
+    from .settings import set_default_train_test_model
+
+    torch.cuda.empty_cache()
+    torch.cuda.memory.empty_cache()
+    gc.collect()
+
+    settings = set_default_train_test_model(settings)
+    channels_str = ''.join(settings['channels'])
+    dst = os.path.join(src,'model', settings['model_type'], channels_str, str(f"epochs_{settings['epochs']}"))
+    os.makedirs(dst, exist_ok=True)
     settings['src'] = src
+    settings['dst'] = dst
     settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
-    settings_csv = os.path.join(src,'settings','train_test_model_settings.csv')
-    os.makedirs(os.path.join(src,'settings'), exist_ok=True)
+    settings_csv = os.path.join(dst,'train_test_model_settings.csv')
     settings_df.to_csv(settings_csv, index=False)
 
     if custom_model:
@@ -212,15 +220,9 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
 
     if settings['train']:
         _save_settings(settings, src)
-        torch.cuda.empty_cache()
-        torch.cuda.memory.empty_cache()
-        gc.collect()
-        dst = os.path.join(src,'model')
-        os.makedirs(dst, exist_ok=True)
-        settings['src'] = src
-        settings['dst'] = dst
+
     if settings['train']:
-        train, val, plate_names = generate_loaders(src,
+        train, val, plate_names, train_fig = generate_loaders(src,
                                             train_mode=settings['train_mode'],
                                             mode='train',
                                             image_size=settings['image_size'],
@@ -231,11 +233,42 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
                                             pin_memory=settings['pin_memory'],
                                             normalize=settings['normalize'],
                                             channels=settings['channels'],
+                                            augment=settings['augment'],
                                             verbose=settings['verbose'])
-
-
+
+        train_batch_1_figure = os.path.join(dst, 'batch_1.pdf')
+        train_fig.savefig(train_batch_1_figure, format='pdf', dpi=600)
+
+    if settings['train']:
+        model = train_model(dst = settings['dst'],
+                            model_type=settings['model_type'],
+                            train_loaders = train,
+                            train_loader_names = plate_names,
+                            train_mode = settings['train_mode'],
+                            epochs = settings['epochs'],
+                            learning_rate = settings['learning_rate'],
+                            init_weights = settings['init_weights'],
+                            weight_decay = settings['weight_decay'],
+                            amsgrad = settings['amsgrad'],
+                            optimizer_type = settings['optimizer_type'],
+                            use_checkpoint = settings['use_checkpoint'],
+                            dropout_rate = settings['dropout_rate'],
+                            num_workers = settings['num_workers'],
+                            val_loaders = val,
+                            test_loaders = None,
+                            intermedeate_save = settings['intermedeate_save'],
+                            schedule = settings['schedule'],
+                            loss_type=settings['loss_type'],
+                            gradient_accumulation=settings['gradient_accumulation'],
+                            gradient_accumulation_steps=settings['gradient_accumulation_steps'],
+                            channels=settings['channels'])
+
+    torch.cuda.empty_cache()
+    torch.cuda.memory.empty_cache()
+    gc.collect()
+
     if settings['test']:
-        test, _, plate_names_test = generate_loaders(src,
+        test, _, plate_names_test, train_fig = generate_loaders(src,
                                             train_mode=settings['train_mode'],
                                             mode='test',
                                             image_size=settings['image_size'],
@@ -246,6 +279,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
                                             pin_memory=settings['pin_memory'],
                                             normalize=settings['normalize'],
                                             channels=settings['channels'],
+                                            augment=False,
                                             verbose=settings['verbose'])
         if model == None:
             model_path = pick_best_model(src+'/model')
@@ -258,10 +292,10 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
         print(type(model))
         print(model)
 
-        model_fldr = os.path.join(src,'model')
+        model_fldr = dst
         time_now = datetime.date.today().strftime('%y%m%d')
-        result_loc = f'{model_fldr}/{model_type}_time_{time_now}_result.csv'
-        acc_loc = f'{model_fldr}/{model_type}_time_{time_now}_acc.csv'
+        result_loc = f"{model_fldr}/{settings['model_type']}_time_{time_now}_test_result.csv"
+        acc_loc = f"{model_fldr}/{settings['model_type']}_time_{time_now}_test_acc.csv"
         print(f'Results wil be saved in: {result_loc}')
 
         result, accuracy = test_model_performance(loaders=test,
@@ -274,37 +308,12 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
         result.to_csv(result_loc, index=True, header=True, mode='w')
         accuracy.to_csv(acc_loc, index=True, header=True, mode='w')
         _copy_missclassified(accuracy)
-    else:
-        test = None
-
-    if settings['train']:
-        train_model(dst = settings['dst'],
-                    model_type=settings['model_type'],
-                    train_loaders = train,
-                    train_loader_names = plate_names,
-                    train_mode = settings['train_mode'],
-                    epochs = settings['epochs'],
-                    learning_rate = settings['learning_rate'],
-                    init_weights = settings['init_weights'],
-                    weight_decay = settings['weight_decay'],
-                    amsgrad = settings['amsgrad'],
-                    optimizer_type = settings['optimizer_type'],
-                    use_checkpoint = settings['use_checkpoint'],
-                    dropout_rate = settings['dropout_rate'],
-                    num_workers = settings['num_workers'],
-                    val_loaders = val,
-                    test_loaders = test,
-                    intermedeate_save = settings['intermedeate_save'],
-                    schedule = settings['schedule'],
-                    loss_type=settings['loss_type'],
-                    gradient_accumulation=settings['gradient_accumulation'],
-                    gradient_accumulation_steps=settings['gradient_accumulation_steps'])
 
     torch.cuda.empty_cache()
     torch.cuda.memory.empty_cache()
     gc.collect()
 
-def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, num_workers=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4):
+def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, num_workers=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b']):
     """
     Trains a model using the specified parameters.
 
@@ -349,7 +358,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
     kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
 
     for idx, (images, labels, filenames) in enumerate(train_loaders):
-        batch, channels, height, width = images.shape
+        batch, chans, height, width = images.shape
         break
 
     model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
@@ -432,10 +441,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
             if schedule == 'step_lr':
                 scheduler.step()
 
-            _save_progress(dst, results_df, train_metrics_df)
+            _save_progress(dst, results_df, train_metrics_df, epoch, epochs)
             clear_output(wait=True)
             display(results_df)
-            _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
+            _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=channels)
 
     if train_mode == 'irm':
         dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
@@ -505,10 +514,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
 
             clear_output(wait=True)
             display(results_df)
-            _save_progress(dst, results_df, train_metrics_df)
+            _save_progress(dst, results_df, train_metrics_df, epoch, epochs)
             _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
     print(f'Saved model: {dst}')
-    return
+    return model
 
 def visualize_saliency_map(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
 
@@ -693,4 +702,82 @@ def visualize_integrated_gradients(src, model_path, target_label_idx=0, image_si
         if save_integrated_grads:
             os.makedirs(save_dir, exist_ok=True)
             integrated_grads_image = Image.fromarray((integrated_grads * 255).astype(np.uint8))
-            integrated_grads_image.save(os.path.join(save_dir, f'integrated_grads_{file}'))
+            integrated_grads_image.save(os.path.join(save_dir, f'integrated_grads_{file}'))
+
+class SmoothGrad:
+    def __init__(self, model, n_samples=50, stdev_spread=0.15):
+        self.model = model
+        self.n_samples = n_samples
+        self.stdev_spread = stdev_spread
+
+    def compute_smooth_grad(self, input_tensor, target_class):
+        self.model.eval()
+        stdev = self.stdev_spread * (input_tensor.max() - input_tensor.min())
+        total_gradients = torch.zeros_like(input_tensor)
+
+        for i in range(self.n_samples):
+            noise = torch.normal(mean=0, std=stdev, size=input_tensor.shape).to(input_tensor.device)
+            noisy_input = input_tensor + noise
+            noisy_input.requires_grad_()
+            output = self.model(noisy_input)
+            self.model.zero_grad()
+            output[0, target_class].backward()
+            total_gradients += noisy_input.grad
+
+        avg_gradients = total_gradients / self.n_samples
+        return avg_gradients.abs()
+
+def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, channels=[1,2,3], normalize=True, save_smooth_grad=False, save_dir='smooth_grad'):
+
+    from .utils import preprocess_image
+
+    use_cuda = torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    model = torch.load(model_path)
+    model.to(device)
+    smooth_grad = SmoothGrad(model)
+
+    if save_smooth_grad and not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    images = []
+    filenames = []
+    for file in os.listdir(src):
+        if not file.endswith('.png'):
+            continue
+        image_path = os.path.join(src, file)
+        image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
+        images.append(image)
+        filenames.append(file)
+
+        input_tensor = input_tensor.to(device)
+        smooth_grad_map = smooth_grad.compute_smooth_grad(input_tensor, target_label_idx)
+        smooth_grad_map = np.mean(smooth_grad_map.cpu().data.numpy(), axis=1).squeeze()
+
+        fig, ax = plt.subplots(1, 3, figsize=(20, 5))
+        ax[0].imshow(image)
+        ax[0].axis('off')
+        ax[0].set_title("Original Image")
+        ax[1].imshow(smooth_grad_map, cmap='hot')
+        ax[1].axis('off')
+        ax[1].set_title("SmoothGrad")
+        overlay = np.array(image)
+        overlay = overlay / overlay.max()
+        smooth_grad_map_rgb = np.stack([smooth_grad_map] * 3, axis=-1)  # Convert smooth grad map to RGB
+        overlay = (overlay * 0.5 + smooth_grad_map_rgb * 0.5).clip(0, 1)
+        ax[2].imshow(overlay)
+        ax[2].axis('off')
+        ax[2].set_title("Overlay")
+        plt.show()
+
+        if save_smooth_grad:
+            os.makedirs(save_dir, exist_ok=True)
+            smooth_grad_image = Image.fromarray((smooth_grad_map * 255).astype(np.uint8))
+            smooth_grad_image.save(os.path.join(save_dir, f'smooth_grad_{file}'))
+
+# Usage
+#src = '/path/to/images'
+#model_path = '/path/to/model.pth'
+#target_label_idx = 0  # Change this to the target class index
+#visualize_smooth_grad(src, model_path, target_label_idx)
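
The hunks above rework how train_test_model lays out its outputs: defaults are now filled in by set_default_train_test_model, and everything is written under a per-run folder derived from the settings. A minimal sketch of that layout, assuming illustrative values for model_type, channels and epochs (the real defaults come from spacr.settings.set_default_train_test_model; the paths below are hypothetical):

    import os

    # hypothetical values for illustration only
    src = '/path/to/experiment'
    settings = {'model_type': 'maxvit', 'channels': ['r', 'g', 'b'], 'epochs': 100}

    channels_str = ''.join(settings['channels'])  # 'rgb'
    dst = os.path.join(src, 'model', settings['model_type'], channels_str, f"epochs_{settings['epochs']}")
    print(dst)  # /path/to/experiment/model/maxvit/rgb/epochs_100

    # train_test_model_settings.csv, batch_1.pdf, saved checkpoints and the
    # <model_type>_time_<yymmdd>_test_result.csv / _test_acc.csv files all land in dst.
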
spacr/graph_learning.py CHANGED
@@ -9,6 +9,7 @@ from PIL import Image
 import dgl.nn.pytorch as dglnn
 from sklearn.datasets import make_classification
 from .utils import SelectChannels
+from IPython.display import display
 
 # approach outline
 #
@@ -241,6 +242,31 @@ def analyze_associations(probabilities, sequencing_data):
     sequencing_data['positive_prob'] = probabilities
     return sequencing_data.groupby('gRNA').positive_prob.mean().sort_values(ascending=False)
 
+def process_sequencing_df(seq):
+
+    if isinstance(seq, pd.DataFrame):
+        sequencing_df = seq
+    elif isinstance(seq, str):
+        sequencing_df = pd.read_csv(seq)
+
+    # Check if 'plate_row' column exists and split into 'plate' and 'row'
+    if 'plate_row' in sequencing_df.columns:
+        sequencing_df[['plate', 'row']] = sequencing_df['plate_row'].str.split('_', expand=True)
+
+    # Check if 'plate', 'row' and 'col' or 'plate', 'row' and 'column' exist
+    if {'plate', 'row', 'col'}.issubset(sequencing_df.columns) or {'plate', 'row', 'column'}.issubset(sequencing_df.columns):
+        if 'col' in sequencing_df.columns:
+            sequencing_df['prc'] = sequencing_df[['plate', 'row', 'col']].agg('_'.join, axis=1)
+        elif 'column' in sequencing_df.columns:
+            sequencing_df['prc'] = sequencing_df[['plate', 'row', 'column']].agg('_'.join, axis=1)
+
+    # Check if 'count', 'total_reads', 'read_fraction', 'grna' exist and create new dataframe
+    if {'count', 'total_reads', 'read_fraction', 'grna'}.issubset(sequencing_df.columns):
+        new_df = sequencing_df[['grna', 'prc', 'count', 'total_reads', 'read_fraction']]
+        return new_df
+
+    return sequencing_df
+
 def train_graph_transformer(src, lr=0.01, epochs=100, hidden_feats=128, n_classes=2, row_limit=None, image_size=224, channels=[1,2,3], normalize=True, test_mode=False):
     if test_mode:
         # Load MNIST data
@@ -260,7 +286,6 @@ def train_graph_transformer(src, lr=0.01, epochs=100, hidden_feats=128, n_classe
 
         # Normalize synthetic sequencing data
        sequencing_data = normalize_sequencing_data(sequencing_data)
-
     else:
         from .io import _read_and_join_tables
         from .utils import get_db_paths, get_sequencing_paths, correct_paths
@@ -274,18 +299,13 @@ def train_graph_transformer(src, lr=0.01, epochs=100, hidden_feats=128, n_classe
         sequencing_data = pd.DataFrame()
         for seq in seq_paths:
             sequencing_df = pd.read_csv(seq)
+            sequencing_df = process_sequencing_df(sequencing_df)
             sequencing_data = pd.concat([sequencing_data, sequencing_df], axis=0)
 
-        all_df = pd.DataFrame()
-        for db_path in db_paths:
-            df = _read_and_join_tables(db_path, table_names=['png_list'])
-            all_df = pd.concat([all_df, df], axis=0)
-
-        tables = ['png_list']
         all_df = pd.DataFrame()
         image_paths = []
         for i, db_path in enumerate(db_paths):
-            df = _read_and_join_tables(db_path, table_names=tables)
+            df = _read_and_join_tables(db_path, table_names=['png_list'])
             df, image_paths_tmp = correct_paths(df, src[i])
             all_df = pd.concat([all_df, df], axis=0)
             image_paths.extend(image_paths_tmp)
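
The process_sequencing_df helper added above normalises a sequencing table before it is concatenated in train_graph_transformer: it splits a combined 'plate_row' column, builds a 'prc' (plate_row_column) key, and trims the table down to the gRNA count columns when they are all present. A small sketch of the expected behaviour on a toy table, assuming spacr 0.1.0 and its graph dependencies are installed (all values below are made up for illustration):

    import pandas as pd
    from spacr.graph_learning import process_sequencing_df

    toy = pd.DataFrame({
        'plate_row': ['plate1_A01', 'plate1_A02'],  # split into 'plate' and 'row'
        'col': ['c01', 'c02'],
        'grna': ['gRNA_a', 'gRNA_b'],
        'count': [120, 80],
        'total_reads': [200, 200],
        'read_fraction': [0.6, 0.4],
    })

    out = process_sequencing_df(toy)
    # out['prc'] -> 'plate1_A01_c01', 'plate1_A02_c02'
    # out keeps only ['grna', 'prc', 'count', 'total_reads', 'read_fraction']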