spacr 0.2.53__py3-none-any.whl → 0.2.56__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/core.py +282 -10
- spacr/deep_spacr.py +101 -41
- spacr/gui.py +1 -1
- spacr/gui_core.py +8 -10
- spacr/gui_elements.py +70 -0
- spacr/gui_utils.py +30 -10
- spacr/io.py +12 -4
- spacr/sequencing.py +443 -643
- spacr/settings.py +176 -44
- spacr/utils.py +13 -5
- {spacr-0.2.53.dist-info → spacr-0.2.56.dist-info}/METADATA +2 -1
- {spacr-0.2.53.dist-info → spacr-0.2.56.dist-info}/RECORD +16 -16
- {spacr-0.2.53.dist-info → spacr-0.2.56.dist-info}/LICENSE +0 -0
- {spacr-0.2.53.dist-info → spacr-0.2.56.dist-info}/WHEEL +0 -0
- {spacr-0.2.53.dist-info → spacr-0.2.56.dist-info}/entry_points.txt +0 -0
- {spacr-0.2.53.dist-info → spacr-0.2.56.dist-info}/top_level.txt +0 -0
spacr/core.py
CHANGED
@@ -877,7 +877,106 @@ def annotate_results(pred_loc):
|
|
877
877
|
display(df)
|
878
878
|
return df
|
879
879
|
|
880
|
-
def generate_dataset(
|
880
|
+
def generate_dataset(settings=None):
    """Pack the PNG crops listed in a spacr measurement database into one tar archive.

    Image paths are read from the ``png_list`` table of
    ``<settings['src']>/measurements/measurements.db``, optionally filtered and
    sub-sampled, written in parallel (``add_images_to_tar``, one chunk per worker)
    to temporary tar files under ``<settings['src']>/datasets/temp_tars`` and then
    merged into ``<settings['src']>/datasets/<yymmdd>_<experiment>[_<file_metadata>].tar``.
    If that archive already exists, a numeric suffix (``_1``, ``_2``, ...) is
    appended instead of overwriting it.

    Args:
        settings (dict): Keys read here:
            ``src`` -- project folder containing ``measurements/measurements.db``;
            ``file_metadata`` -- falsy to keep every path, a substring that
            ``png_path`` must contain, or a list/tuple of substrings of which at
            least one must be contained (matched literally: ``%`` and ``_`` are
            not treated as SQL wildcards);
            ``sample`` -- an ``int`` number of paths drawn at random (capped at
            the number available); any other value keeps all paths, shuffled;
            ``experiment`` -- used in the archive name.

    Returns:
        str | None: Path of the written archive, or ``None`` when the database is
        missing/unreadable or no image path was selected.
    """
    from contextlib import closing
    from .utils import initiate_counter, add_images_to_tar

    settings = {} if settings is None else settings
    db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
    dst = os.path.join(settings['src'], 'datasets')
    file_metadata = settings['file_metadata']

    # Normalise the metadata filter into a list of literal substrings.
    if not file_metadata:
        patterns = []
    elif isinstance(file_metadata, (list, tuple)):
        patterns = [str(p) for p in file_metadata]
    else:
        patterns = [str(file_metadata)]

    # Parameterised query; LIKE special characters are escaped so that e.g. the
    # underscores common in file names only match themselves.
    query = "SELECT png_path FROM png_list"
    if patterns:
        query += " WHERE " + " OR ".join(["png_path LIKE ? ESCAPE '\\'"] * len(patterns))
    params = tuple(
        "%" + p.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_') + "%"
        for p in patterns
    )

    all_paths = []
    print(f"Reading DataBase: {db_path}")
    if not os.path.exists(db_path):
        # sqlite3.connect would otherwise silently create an empty database file.
        print(f"Database error: {db_path} does not exist")
        return None
    try:
        # closing() is required: the sqlite3 connection context manager only
        # commits/rolls back the transaction, it does not close the connection.
        with closing(sqlite3.connect(db_path)) as conn:
            cursor = conn.execute(query, params)
            while True:
                rows = cursor.fetchmany(1000)
                if not rows:
                    break
                all_paths.extend(row[0] for row in rows)
    except sqlite3.Error as e:
        print(f"Database error: {e}")
        return None
    except Exception as e:
        print(f"Error: {e}")
        return None

    sample = settings['sample']
    if isinstance(sample, int):
        if sample > len(all_paths):
            print(f"Warning: requested {sample} images but only {len(all_paths)} are available; using all of them")
        selected_paths = random.sample(all_paths, min(sample, len(all_paths)))
        print(f"Random selection of {len(selected_paths)} paths")
    else:
        selected_paths = all_paths
        random.shuffle(selected_paths)
        print(f"All paths: {len(selected_paths)} paths")

    total_images = len(selected_paths)
    print(f"Found {total_images} images")
    if total_images == 0:
        print("No images matched the selection; no dataset was generated")
        return None

    # Split into contiguous, near-equal chunks; never start more workers than images.
    num_procs = min(max(2, cpu_count() - 2), total_images)
    chunk_size, remainder = divmod(total_images, num_procs)
    paths_chunks = []
    start = 0
    for i in range(num_procs):
        end = start + chunk_size + (1 if i < remainder else 0)
        paths_chunks.append(selected_paths[start:end])
        start = end

    # Final archive name; never overwrite an existing archive.
    date_name = datetime.date.today().strftime('%y%m%d')
    base_name = f"{date_name}_{settings['experiment']}"
    if patterns:
        base_name = f"{base_name}_{'_'.join(patterns)}"
    tar_name = os.path.join(dst, f"{base_name}.tar")
    if os.path.exists(tar_name):
        number = 1
        while os.path.exists(os.path.join(dst, f"{base_name}_{number}.tar")):
            number += 1
        new_tar_name = os.path.join(dst, f"{base_name}_{number}.tar")
        print(f"Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(new_tar_name)} ")
        tar_name = new_tar_name

    temp_dir = os.path.join(dst, "temp_tars")
    os.makedirs(temp_dir, exist_ok=True)
    temp_tar_files = [os.path.join(temp_dir, f"temp_{i}.tar") for i in range(num_procs)]

    print(f"Generating temporary tar files in {dst}")

    # Shared counter/lock handed to every worker through initiate_counter.
    counter = Value('i', 0)
    lock = Lock()

    try:
        with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
            pool.starmap(add_images_to_tar, [(chunk, temp_tar, total_images) for chunk, temp_tar in zip(paths_chunks, temp_tar_files)])

        print("Merging temporary files")
        with tarfile.open(tar_name, 'w') as final_tar:
            for temp_tar_path in temp_tar_files:
                with tarfile.open(temp_tar_path, 'r') as temp_tar:
                    for member in temp_tar.getmembers():
                        final_tar.addfile(member, temp_tar.extractfile(member))
    finally:
        # Remove the per-worker archives even when a worker or the merge failed.
        shutil.rmtree(temp_dir, ignore_errors=True)

    print(f"\nSaved {total_images} images to {tar_name}")

    return tar_name
|
978
|
+
|
979
|
+
def generate_dataset_v1(src, file_metadata=None, experiment='TSG101_screen', sample=None):
|
881
980
|
|
882
981
|
from .utils import initiate_counter, add_images_to_tar
|
883
982
|
|
@@ -974,7 +1073,7 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample
|
|
974
1073
|
shutil.rmtree(temp_dir)
|
975
1074
|
print(f"\nSaved {total_images} images to {tar_name}")
|
976
1075
|
|
977
|
-
def
|
1076
|
+
def apply_model_to_tar_v1(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', n_jobs=10, threshold=0.5, verbose=False):
|
978
1077
|
|
979
1078
|
from .io import TarImageDataset
|
980
1079
|
from .utils import process_vision_results, print_progress
|
@@ -1044,6 +1143,76 @@ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=22
|
|
1044
1143
|
torch.cuda.memory.empty_cache()
|
1045
1144
|
return df
|
1046
1145
|
|
1146
|
+
def apply_model_to_tar(settings=None):
    """Score every image of a tar archive with a saved PyTorch model.

    Images are loaded with ``TarImageDataset``, converted to tensors,
    center-cropped to ``settings['image_size']`` and optionally normalised with
    mean/std 0.5 per channel. The model output is passed through a sigmoid. The
    ``(path, pred)`` table is post-processed by ``process_vision_results`` and
    written to ``<dir of tar_path>/<yymmdd>_<tar name>_<model name>_result.csv``.

    Args:
        settings (dict): Keys read here: ``model_path``, ``tar_path``,
            ``image_size``, ``normalize``, ``batch_size``, ``n_jobs`` (loader
            workers), ``score_threshold`` and ``verbose``.

    Returns:
        pandas.DataFrame: The table returned by ``process_vision_results``
        (also saved to the CSV above).
    """
    from .io import TarImageDataset
    from .utils import process_vision_results, print_progress

    settings = {} if settings is None else settings
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    image_size = settings['image_size']
    transform_steps = [transforms.ToTensor(), transforms.CenterCrop(size=(image_size, image_size))]
    if settings['normalize']:
        transform_steps.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = transforms.Compose(transform_steps)

    if settings['verbose']:
        print(f"Loading model from {settings['model_path']}")
        print(f"Loading dataset from {settings['tar_path']}")

    # map_location lets a model that was saved on a GPU load on a CPU-only host.
    model = torch.load(settings['model_path'], map_location=device)

    dataset = TarImageDataset(settings['tar_path'], transform=transform)
    # Pure inference: keep archive order so the output is deterministic, and pin
    # host memory only when batches are copied to a CUDA device.
    data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=False, num_workers=settings['n_jobs'], pin_memory=device.type == 'cuda')

    model_name = os.path.splitext(os.path.basename(settings['model_path']))[0]
    dataset_name = os.path.splitext(os.path.basename(settings['tar_path']))[0]
    date_name = datetime.date.today().strftime('%y%m%d')
    dst = os.path.dirname(settings['tar_path'])
    # os.path.join avoids producing "/<name>" when tar_path has no directory part.
    result_loc = os.path.join(dst, f'{date_name}_{dataset_name}_{model_name}_result.csv')

    model.eval()
    model = model.to(device)

    if settings['verbose']:
        print(model)
        print(f'Generated dataset with {len(dataset)} images')
        print(f'Generating loader from {len(data_loader)} batches')
        print(f'Results wil be saved in: {result_loc}')
        print(f'Model is in eval mode')
        print(f'Model loaded to device')

    prediction_pos_probs = []
    filenames_list = []
    time_ls = []
    total_files = len(dataset)
    gc.collect()
    with torch.no_grad():
        for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
            start = time.time()
            images = batch_images.to(torch.float).to(device)
            outputs = model(images)
            batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
            prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
            filenames_list.extend(filenames)
            time_ls.append(time.time() - start)
            # Both counts are in images (not batches); the last batch may be partial.
            files_processed = min(batch_idx * settings['batch_size'], total_files)
            print_progress(files_processed, total_files, n_jobs=settings['n_jobs'], time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Tar dataset")

    data = {'path': filenames_list, 'pred': prediction_pos_probs}
    df = pd.DataFrame(data, index=None)
    df = process_vision_results(df, settings['score_threshold'])

    df.to_csv(result_loc, index=True, header=True, mode='w')
    # Release cached GPU memory (a no-op when CUDA was never initialised).
    torch.cuda.empty_cache()
    return df
|
1215
|
+
|
1047
1216
|
def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, n_jobs=10):
|
1048
1217
|
|
1049
1218
|
from .io import NoClassDataset
|
@@ -1206,19 +1375,19 @@ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
|
|
1206
1375
|
for path in train_data:
|
1207
1376
|
start = time.time()
|
1208
1377
|
shutil.copy(path, os.path.join(train_class_dir, os.path.basename(path)))
|
1209
|
-
processed_files += 1
|
1210
1378
|
duration = time.time() - start
|
1211
1379
|
time_ls.append(duration)
|
1212
1380
|
print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Train dataset")
|
1381
|
+
processed_files += 1
|
1213
1382
|
|
1214
1383
|
# Copy test files
|
1215
1384
|
for path in test_data:
|
1216
1385
|
start = time.time()
|
1217
1386
|
shutil.copy(path, os.path.join(test_class_dir, os.path.basename(path)))
|
1218
|
-
processed_files += 1
|
1219
1387
|
duration = time.time() - start
|
1220
1388
|
time_ls.append(duration)
|
1221
1389
|
print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Test dataset")
|
1390
|
+
processed_files += 1
|
1222
1391
|
|
1223
1392
|
# Print summary
|
1224
1393
|
for cls in classes:
|
@@ -1226,9 +1395,9 @@ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
|
|
1226
1395
|
test_class_dir = os.path.join(dst, f'test/{cls}')
|
1227
1396
|
print(f'Train class {cls}: {len(os.listdir(train_class_dir))}, Test class {cls}: {len(os.listdir(test_class_dir))}')
|
1228
1397
|
|
1229
|
-
return
|
1398
|
+
return os.path.join(dst, 'train'), os.path.join(dst, 'test')
|
1230
1399
|
|
1231
|
-
def
|
1400
|
+
def generate_training_dataset_v1(src, mode='annotation', annotation_column='test', annotated_classes=[1,2], classes=['nc','pc'], size=200, test_split=0.1, class_metadata=[['c1'],['c2']], metadata_type_by='col', channel_of_interest=3, custom_measurement=None, tables=None, png_type='cell_png'):
|
1232
1401
|
|
1233
1402
|
from .io import _read_and_merge_data, _read_db
|
1234
1403
|
from .utils import get_paths_from_db, annotate_conditions
|
@@ -1329,6 +1498,110 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
|
|
1329
1498
|
|
1330
1499
|
return
|
1331
1500
|
|
1501
|
+
def generate_training_dataset(settings):
    """Build a class-balanced train/test image dataset from a spacr project.

    ``settings`` is completed with ``set_generate_training_dataset_defaults``.
    Per-class image path lists are collected according to
    ``settings['dataset_mode']``:

    * ``'annotation'`` -- lists returned by ``training_dataset_from_annotation``
      for ``settings['annotation_column']`` / ``settings['annotated_classes']``;
    * ``'metadata'`` -- rows of the ``png_list`` table whose
      ``settings['metadata_type_by']`` column is in ``settings['class_metadata'][i]``
      form class ``settings['classes'][i]``;
    * ``'recruitment'`` -- objects at or below the 25th percentile (first list)
      and at or above the 75th percentile (second list) of a recruitment score:
      ``custom_measurement[0] / custom_measurement[1]``, ``custom_measurement[0]``
      alone, or by default pathogen/cytoplasm mean intensity in channel
      ``settings['channel_of_interest']``.

    The same number of paths is drawn at random from every class:
    ``settings['size']``, or the size of the smallest class when ``size`` is
    ``None`` or larger than that class. The samples are handed to
    ``generate_dataset_from_lists`` with destination ``<src>/datasets/training``,
    or the first free ``<src>/datasets/training_<i>`` if that folder exists.

    Args:
        settings (dict): Training-dataset settings (see above).

    Returns:
        tuple[str, str] | None: ``(train_dir, test_dir)`` as returned by
        ``generate_dataset_from_lists``, or ``None`` when
        ``settings['custom_measurement']`` is not a list of one or two names.

    Raises:
        ValueError: If ``settings['dataset_mode']`` is not a supported mode.
    """
    from .io import _read_and_merge_data, _read_db
    from .utils import get_paths_from_db, annotate_conditions
    from .settings import set_generate_training_dataset_defaults

    settings = set_generate_training_dataset_defaults(settings)

    db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
    datasets_dir = os.path.join(settings['src'], 'datasets')
    dst = os.path.join(datasets_dir, 'training')

    # Never reuse an existing training folder: take the first free training_<i>.
    if os.path.exists(dst):
        i = 1
        while os.path.exists(os.path.join(datasets_dir, f'training_{i}')):
            i += 1
        dst = os.path.join(datasets_dir, f'training_{i}')
        print(f'Creating new directory for training: {dst}')

    def _balanced_sample(class_path_lists):
        # Draw the same number of random paths from every class list.
        smallest = min((len(paths) for paths in class_path_lists), default=0)
        size = settings['size']
        if size is None:
            size = smallest
            print(f'Using the smallest class size: {size}')
        elif size > smallest:
            print(f'Warning: size {size} exceeds the smallest class ({smallest} images); using {smallest}')
            size = smallest
        return [random.sample(list(paths), size) for paths in class_path_lists]

    mode = settings['dataset_mode']

    if mode == 'annotation':
        class_paths_ls = training_dataset_from_annotation(db_path, dst, settings['annotation_column'], annotated_classes=settings['annotated_classes'])
        class_paths_ls = _balanced_sample(class_paths_ls)

    elif mode == 'metadata':
        [df] = _read_db(db_loc=db_path, tables=['png_list'])
        df['metadata_based_class'] = pd.NA
        for i, class_ in enumerate(settings['classes']):
            metadata_values = settings['class_metadata'][i]
            if isinstance(metadata_values, str):
                # Series.isin requires a list-like, not a bare string.
                metadata_values = [metadata_values]
            df.loc[df[settings['metadata_type_by']].isin(metadata_values), 'metadata_based_class'] = class_

        class_paths_ls = []
        for class_ in settings['classes']:
            class_temp_df = df[df['metadata_based_class'] == class_]
            print(f'Found {len(class_temp_df)} images for class {class_}')
            class_paths_ls.append(class_temp_df['png_path'].tolist())
        class_paths_ls = _balanced_sample(class_paths_ls)

    elif mode == 'recruitment':
        tables = settings['tables']
        if not isinstance(tables, list):
            tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']

        df, _ = _read_and_merge_data(locs=[db_path],
                                     tables=tables,
                                     verbose=False,
                                     include_multinucleated=True,
                                     include_multiinfected=True,
                                     include_noninfected=True)

        print('length df 1', len(df))

        df = annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['pathogen'], pathogen_loc=None, treatments=settings['classes'], treatment_loc=settings['class_metadata'], types=settings['metadata_type_by'])
        print('length df 2', len(df))
        [png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])

        measurement = settings['custom_measurement']
        if measurement is not None:
            if not isinstance(measurement, list) or len(measurement) not in (1, 2):
                print('custom_measurement should be a list, add [ measurement_1, measurement_2 ] or [ measurement ]')
                return
            if len(measurement) == 2:
                print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment ({measurement[0]}/{measurement[1]})")
                df['recruitment'] = df[f"{measurement[0]}"] / df[f"{measurement[1]}"]
            else:
                print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment ({measurement[0]})")
                df['recruitment'] = df[f"{measurement[0]}"]
        else:
            channel = settings['channel_of_interest']
            print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {channel})")
            df['recruitment'] = df[f"pathogen_channel_{channel}_mean_intensity"] / df[f"cytoplasm_channel_{channel}_mean_intensity"]

        q25 = df['recruitment'].quantile(0.25)
        q75 = df['recruitment'].quantile(0.75)
        df_lower = df[df['recruitment'] <= q25]
        df_upper = df[df['recruitment'] >= q75]

        lower_paths = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=settings['png_type'])['png_path'].tolist()
        upper_paths = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=settings['png_type'])['png_path'].tolist()
        class_paths_ls = _balanced_sample([lower_paths, upper_paths])

    else:
        raise ValueError(f"Unsupported dataset_mode {mode!r}; expected 'annotation', 'metadata' or 'recruitment'")

    train_class_dir, test_class_dir = generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=settings['classes'], test_split=settings['test_split'])

    return train_class_dir, test_class_dir
|
1604
|
+
|
1332
1605
|
def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):
|
1333
1606
|
|
1334
1607
|
"""
|
@@ -2497,7 +2770,6 @@ def ml_analysis(df, channel_of_interest=3, location_column='col', positive_contr
|
|
2497
2770
|
df_metadata = df[[location_column]].copy()
|
2498
2771
|
df, features = filter_dataframe_features(df, channel_of_interest, exclude, remove_low_variance_features, remove_highly_correlated_features, verbose)
|
2499
2772
|
|
2500
|
-
|
2501
2773
|
if verbose:
|
2502
2774
|
print(f'Found {len(features)} numerical features in the dataframe')
|
2503
2775
|
print(f'Features used in training: {features}')
|
@@ -2642,7 +2914,6 @@ def check_index(df, elements=5, split_char='_'):
|
|
2642
2914
|
print(idx)
|
2643
2915
|
raise ValueError(f"Found {len(problematic_indices)} problematic indices that do not split into {elements} parts.")
|
2644
2916
|
|
2645
|
-
#def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c2', neg='c1', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
|
2646
2917
|
def generate_ml_scores(src, settings):
|
2647
2918
|
|
2648
2919
|
from .io import _read_and_merge_data
|
@@ -2680,7 +2951,7 @@ def generate_ml_scores(src, settings):
|
|
2680
2951
|
settings['top_features'],
|
2681
2952
|
settings['n_estimators'],
|
2682
2953
|
settings['test_size'],
|
2683
|
-
settings['
|
2954
|
+
settings['model_type_ml'],
|
2684
2955
|
settings['n_jobs'],
|
2685
2956
|
settings['remove_low_variance_features'],
|
2686
2957
|
settings['remove_highly_correlated_features'],
|
@@ -2701,7 +2972,7 @@ def generate_ml_scores(src, settings):
|
|
2701
2972
|
min_count=settings['minimum_cell_count'],
|
2702
2973
|
verbose=settings['verbose'])
|
2703
2974
|
|
2704
|
-
data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['
|
2975
|
+
data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['model_type_ml'], settings['channel_of_interest'])
|
2705
2976
|
df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df = output
|
2706
2977
|
|
2707
2978
|
settings_df.to_csv(settings_csv, index=False)
|
@@ -2858,6 +3129,7 @@ def generate_image_umap(settings={}):
|
|
2858
3129
|
settings['plot_outlines'] = False
|
2859
3130
|
settings['smooth_lines'] = False
|
2860
3131
|
|
3132
|
+
print(f'Generating Image UMAP ...')
|
2861
3133
|
settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
|
2862
3134
|
settings_dir = os.path.join(settings['src'][0],'settings')
|
2863
3135
|
settings_csv = os.path.join(settings_dir,'embedding_settings.csv')
|
spacr/deep_spacr.py
CHANGED
@@ -196,7 +196,7 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,
|
|
196
196
|
test_time = end_time - start_time
|
197
197
|
return result, results_df
|
198
198
|
|
199
|
-
def train_test_model(
|
199
|
+
def train_test_model(settings):
|
200
200
|
|
201
201
|
from .io import _save_settings, _copy_missclassified
|
202
202
|
from .utils import pick_best_model
|
@@ -208,7 +208,10 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
208
208
|
gc.collect()
|
209
209
|
|
210
210
|
settings = set_default_train_test_model(settings)
|
211
|
-
|
211
|
+
|
212
|
+
src = settings['src']
|
213
|
+
|
214
|
+
channels_str = ''.join(settings['train_channels'])
|
212
215
|
dst = os.path.join(src,'model', settings['model_type'], channels_str, str(f"epochs_{settings['epochs']}"))
|
213
216
|
os.makedirs(dst, exist_ok=True)
|
214
217
|
settings['src'] = src
|
@@ -217,8 +220,8 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
217
220
|
settings_csv = os.path.join(dst,'train_test_model_settings.csv')
|
218
221
|
settings_df.to_csv(settings_csv, index=False)
|
219
222
|
|
220
|
-
if custom_model:
|
221
|
-
model = torch.load(custom_model_path)
|
223
|
+
if settings['custom_model']:
|
224
|
+
model = torch.load(settings['custom_model_path'])
|
222
225
|
|
223
226
|
if settings['train']:
|
224
227
|
_save_settings(settings, src)
|
@@ -234,7 +237,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
234
237
|
validation_split=settings['val_split'],
|
235
238
|
pin_memory=settings['pin_memory'],
|
236
239
|
normalize=settings['normalize'],
|
237
|
-
channels=settings['
|
240
|
+
channels=settings['train_channels'],
|
238
241
|
augment=settings['augment'],
|
239
242
|
verbose=settings['verbose'])
|
240
243
|
|
@@ -242,28 +245,28 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
242
245
|
train_fig.savefig(train_batch_1_figure, format='pdf', dpi=600)
|
243
246
|
|
244
247
|
if settings['train']:
|
245
|
-
model = train_model(dst = settings['dst'],
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
248
|
+
model, model_path = train_model(dst = settings['dst'],
|
249
|
+
model_type=settings['model_type'],
|
250
|
+
train_loaders = train,
|
251
|
+
train_loader_names = plate_names,
|
252
|
+
train_mode = settings['train_mode'],
|
253
|
+
epochs = settings['epochs'],
|
254
|
+
learning_rate = settings['learning_rate'],
|
255
|
+
init_weights = settings['init_weights'],
|
256
|
+
weight_decay = settings['weight_decay'],
|
257
|
+
amsgrad = settings['amsgrad'],
|
258
|
+
optimizer_type = settings['optimizer_type'],
|
259
|
+
use_checkpoint = settings['use_checkpoint'],
|
260
|
+
dropout_rate = settings['dropout_rate'],
|
261
|
+
n_jobs = settings['n_jobs'],
|
262
|
+
val_loaders = val,
|
263
|
+
test_loaders = None,
|
264
|
+
intermedeate_save = settings['intermedeate_save'],
|
265
|
+
schedule = settings['schedule'],
|
266
|
+
loss_type=settings['loss_type'],
|
267
|
+
gradient_accumulation=settings['gradient_accumulation'],
|
268
|
+
gradient_accumulation_steps=settings['gradient_accumulation_steps'],
|
269
|
+
channels=settings['train_channels'])
|
267
270
|
|
268
271
|
torch.cuda.empty_cache()
|
269
272
|
torch.cuda.memory.empty_cache()
|
@@ -280,7 +283,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
280
283
|
validation_split=0.0,
|
281
284
|
pin_memory=settings['pin_memory'],
|
282
285
|
normalize=settings['normalize'],
|
283
|
-
channels=settings['
|
286
|
+
channels=settings['train_channels'],
|
284
287
|
augment=False,
|
285
288
|
verbose=settings['verbose'])
|
286
289
|
if model == None:
|
@@ -314,6 +317,8 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
314
317
|
torch.cuda.empty_cache()
|
315
318
|
torch.cuda.memory.empty_cache()
|
316
319
|
gc.collect()
|
320
|
+
|
321
|
+
return model_path
|
317
322
|
|
318
323
|
def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, n_jobs=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b']):
|
319
324
|
"""
|
@@ -348,7 +353,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
348
353
|
"""
|
349
354
|
|
350
355
|
from .io import _save_model, _save_progress
|
351
|
-
from .utils import compute_irm_penalty, calculate_loss, choose_model
|
356
|
+
from .utils import compute_irm_penalty, calculate_loss, choose_model, print_progress
|
352
357
|
|
353
358
|
print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')
|
354
359
|
|
@@ -386,6 +391,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
386
391
|
else:
|
387
392
|
scheduler = None
|
388
393
|
|
394
|
+
time_ls = []
|
389
395
|
if train_mode == 'erm':
|
390
396
|
for epoch in range(1, epochs+1):
|
391
397
|
model.train()
|
@@ -412,7 +418,13 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
412
418
|
optimizer.zero_grad()
|
413
419
|
|
414
420
|
avg_loss = running_loss / batch_idx
|
415
|
-
print(f'\rTrain: epoch: {epoch} batch: {batch_idx}/{len(train_loaders)} avg_loss: {avg_loss:.5f} time: {(time.time()-start_time):.5f}', end='\r', flush=True)
|
421
|
+
#print(f'\rTrain: epoch: {epoch} batch: {batch_idx}/{len(train_loaders)} avg_loss: {avg_loss:.5f} time: {(time.time()-start_time):.5f}', end='\r', flush=True)
|
422
|
+
|
423
|
+
batch_size = len(train_loaders)
|
424
|
+
duration = time.time() - start_time
|
425
|
+
time_ls.append(duration)
|
426
|
+
metricks = f"Loss: {avg_loss:.5f}"
|
427
|
+
print_progress(files_processed=epoch, files_to_process=epochs, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type=f"Training {model_type} model", metricks=metricks)
|
416
428
|
|
417
429
|
end_time = time.time()
|
418
430
|
train_time = end_time - start_time
|
@@ -421,6 +433,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
421
433
|
train_names = 'train'
|
422
434
|
results_df, train_test_time = evaluate_model_performance(train_loaders, model, train_names, epoch, train_mode='erm', loss_type=loss_type)
|
423
435
|
train_metrics_df['train_test_time'] = train_test_time
|
436
|
+
|
424
437
|
if val_loaders != None:
|
425
438
|
val_names = 'val'
|
426
439
|
result, val_time = evaluate_model_performance(val_loaders, model, val_names, epoch, train_mode='erm', loss_type=loss_type)
|
@@ -430,6 +443,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
430
443
|
|
431
444
|
results_df = pd.concat([results_df, result])
|
432
445
|
train_metrics_df['val_time'] = val_time
|
446
|
+
|
433
447
|
if test_loaders != None:
|
434
448
|
test_names = 'test'
|
435
449
|
result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='erm', loss_type=loss_type)
|
@@ -444,9 +458,30 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
444
458
|
scheduler.step()
|
445
459
|
|
446
460
|
_save_progress(dst, results_df, train_metrics_df, epoch, epochs)
|
447
|
-
clear_output(wait=True)
|
448
|
-
display(results_df)
|
449
|
-
|
461
|
+
#clear_output(wait=True)
|
462
|
+
#display(results_df)
|
463
|
+
|
464
|
+
train_idx = f"{epoch}_train"
|
465
|
+
val_idx = f"{epoch}_val"
|
466
|
+
train_acc = results_df.loc[train_idx, 'accuracy']
|
467
|
+
neg_train_acc = results_df.loc[train_idx, 'neg_accuracy']
|
468
|
+
pos_train_acc = results_df.loc[train_idx, 'pos_accuracy']
|
469
|
+
val_acc = results_df.loc[val_idx, 'accuracy']
|
470
|
+
neg_val_acc = results_df.loc[val_idx, 'neg_accuracy']
|
471
|
+
pos_val_acc = results_df.loc[val_idx, 'pos_accuracy']
|
472
|
+
train_loss = results_df.loc[train_idx, 'loss']
|
473
|
+
train_prauc = results_df.loc[train_idx, 'prauc']
|
474
|
+
val_loss = results_df.loc[val_idx, 'loss']
|
475
|
+
val_prauc = results_df.loc[val_idx, 'prauc']
|
476
|
+
|
477
|
+
metricks = f"Train Acc: {train_acc:.5f} Val Acc: {val_acc:.5f} Train Loss: {train_loss:.5f} Val Loss: {val_loss:.5f} Train PRAUC: {train_prauc:.5f} Val PRAUC: {val_prauc:.5f}, Nc Train Acc: {neg_train_acc:.5f} Nc Val Acc: {neg_val_acc:.5f} Pc Train Acc: {pos_train_acc:.5f} Pc Val Acc: {pos_val_acc:.5f}"
|
478
|
+
|
479
|
+
batch_size = len(train_loaders)
|
480
|
+
duration = time.time() - start_time
|
481
|
+
time_ls.append(duration)
|
482
|
+
print_progress(files_processed=epoch, files_to_process=epochs, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type=f"Training {model_type} model", metricks=metricks)
|
483
|
+
|
484
|
+
model_path = _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=channels)
|
450
485
|
|
451
486
|
if train_mode == 'irm':
|
452
487
|
dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
|
@@ -517,9 +552,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
517
552
|
clear_output(wait=True)
|
518
553
|
display(results_df)
|
519
554
|
_save_progress(dst, results_df, train_metrics_df, epoch, epochs)
|
520
|
-
_save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
|
521
|
-
print(f'Saved model: {
|
522
|
-
|
555
|
+
model_path = _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
|
556
|
+
print(f'Saved model: {model_path}')
|
557
|
+
|
558
|
+
return model, model_path
|
523
559
|
|
524
560
|
def visualize_saliency_map(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
|
525
561
|
|
@@ -778,8 +814,32 @@ def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, cha
|
|
778
814
|
smooth_grad_image = Image.fromarray((smooth_grad_map * 255).astype(np.uint8))
|
779
815
|
smooth_grad_image.save(os.path.join(save_dir, f'smooth_grad_{file}'))
|
780
816
|
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
817
|
+
def deep_spacr(settings=None):
    """Run the spacr deep-learning pipeline steps enabled in *settings*.

    Missing keys are filled in by ``deep_spacr_defaults``. When ``train`` or
    ``test`` is set, this optionally builds a train/test dataset
    (``generate_training_dataset``) and trains/tests a model
    (``train_test_model``). When ``apply_model_to_dataset`` is set, it packs an
    image archive with ``generate_dataset`` if ``tar_path`` does not exist yet,
    and then scores it with ``apply_model_to_tar`` if ``model_path`` exists.

    Args:
        settings (dict | None): Pipeline settings. Keys read here: ``src``,
            ``train``, ``test``, ``generate_training_dataset``, ``train_DL_model``,
            ``apply_model_to_dataset``, ``tar_path`` and ``model_path``. The dict
            is updated in place: ``model_path`` after training and ``tar_path``
            after archive generation. ``src`` is temporarily pointed at the
            training folder and restored afterwards.
    """
    from .settings import deep_spacr_defaults
    from .core import generate_training_dataset, generate_dataset, apply_model_to_tar

    # A fresh dict per call: the settings are filled in and mutated below.
    settings = deep_spacr_defaults({} if settings is None else settings)
    src = settings['src']

    if settings['train'] or settings['test']:
        if settings['generate_training_dataset']:
            print(f"Generating train and test datasets ...")
            result = generate_training_dataset(settings)
            if result is None:
                print("Could not generate the training dataset; stopping.")
                return
            train_path, test_path = result
            print(f'Generated Train set: {train_path}')
            print(f'Generated Test set: {test_path}')
            # train_test_model reads its images from settings['src'].
            settings['src'] = os.path.dirname(train_path)

        if settings['train_DL_model']:
            print(f"Training model ...")
            settings['model_path'] = train_test_model(settings)

        # Restore the project folder: later steps read measurements.db from it.
        settings['src'] = src

    if settings['apply_model_to_dataset']:
        tar_path = settings['tar_path']
        if not tar_path or not os.path.exists(tar_path):
            print(f"Generating dataset ...")
            settings['tar_path'] = generate_dataset(settings)

        model_path = settings['model_path']
        if not settings['tar_path']:
            print("No image archive is available; skipping model application.")
        elif model_path and os.path.exists(model_path):
            apply_model_to_tar(settings)
        else:
            print(f"Model not found: {model_path}; skipping model application.")
|
spacr/gui.py
CHANGED
@@ -27,7 +27,7 @@ class MainApp(tk.Tk):
|
|
27
27
|
}
|
28
28
|
|
29
29
|
self.additional_gui_apps = {
|
30
|
-
"Sequencing": (lambda frame: initiate_root(self, 'sequencing'), "Analyze sequencing data."),
|
30
|
+
#"Sequencing": (lambda frame: initiate_root(self, 'sequencing'), "Analyze sequencing data."),
|
31
31
|
"Umap": (lambda frame: initiate_root(self, 'umap'), "Generate UMAP embeddings with datapoints represented as images."),
|
32
32
|
"Train Cellpose": (lambda frame: initiate_root(self, 'train_cellpose'), "Train custom Cellpose models."),
|
33
33
|
"ML Analyze": (lambda frame: initiate_root(self, 'ml_analyze'), "Machine learning analysis of data."),
|