spacr 0.0.81__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/measure.py CHANGED
@@ -920,7 +920,8 @@ def measure_crop(settings):
920
920
  from .io import _save_settings_to_db
921
921
  from .timelapse import _timelapse_masks_to_gif, _scmovie
922
922
  from .plot import _save_scimg_plot
923
- from .utils import _list_endpoint_subdirectories, _generate_representative_images, get_measure_crop_settings, measure_test_mode
923
+ from .utils import _list_endpoint_subdirectories, _generate_representative_images, measure_test_mode
924
+ from .settings import get_measure_crop_settings
924
925
 
925
926
  settings = get_measure_crop_settings(settings)
926
927
  settings = measure_test_mode(settings)
@@ -0,0 +1,23 @@
1
+ Key,Value
2
+ img_src,/nas_mnt/carruthers/patrick/Plaque_assay_training/train
3
+ model_name,toxo_plaque
4
+ model_type,cyto
5
+ Signal_to_noise,10
6
+ background,200
7
+ remove_background,False
8
+ learning_rate,0.2
9
+ weight_decay,1e-05
10
+ batch_size,8
11
+ n_epochs,25000
12
+ from_scratch,False
13
+ diameter,30
14
+ resize,True
15
+ width_height,"[1120, 1120]"
16
+ verbose,True
17
+ channels,"[0, 0]"
18
+ normalize,True
19
+ percentiles,
20
+ circular,False
21
+ invert,False
22
+ grayscale,True
23
+ test,False
spacr/plot.py CHANGED
@@ -1110,7 +1110,7 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
1110
1110
  idx = i * n_col + j
1111
1111
  if idx < n_images:
1112
1112
  canvas[i * img_height:(i + 1) * img_height, j * img_width:(j + 1) * img_width] = np.transpose(img[idx], (1, 2, 0))
1113
- plt.figure(figsize=(50, 50))
1113
+ fig = plt.figure(figsize=(50, 50))
1114
1114
  plt.imshow(canvas)
1115
1115
  plt.axis("off")
1116
1116
  for i, label in enumerate(labels):
@@ -1119,7 +1119,7 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
1119
1119
  x = col * img_width + 2
1120
1120
  y = row * img_height + 15
1121
1121
  plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
1122
- plt.show()
1122
+ return fig
1123
1123
 
1124
1124
  def _plot_histograms_and_stats(df):
1125
1125
  conditions = df['condition'].unique()
@@ -1218,8 +1218,7 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co
1218
1218
  df = df[df['count'] >= min_count]
1219
1219
 
1220
1220
  # Explicitly set observed=True to avoid FutureWarning
1221
- grouped = df.groupby(['row', 'col'], observed=True)
1222
-
1221
+ grouped = df.groupby(['row', 'col'], observed=True) # Group by row and column
1223
1222
 
1224
1223
  if grouping == 'mean':
1225
1224
  plate = grouped[variable].mean().reset_index()
@@ -1245,7 +1244,7 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co
1245
1244
 
1246
1245
  return plate_map, min_max
1247
1246
 
1248
- def _plot_plates(df, variable, grouping, min_max, cmap, min_count=0):
1247
+ def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True):
1249
1248
  plates = df['prc'].str.split('_', expand=True)[0].unique()
1250
1249
  n_rows, n_cols = (len(plates) + 3) // 4, 4
1251
1250
  fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
@@ -1260,7 +1259,8 @@ def _plot_plates(df, variable, grouping, min_max, cmap, min_count=0):
1260
1259
  fig.delaxes(ax[i])
1261
1260
 
1262
1261
  plt.subplots_adjust(wspace=0.1, hspace=0.4)
1263
- plt.show()
1262
+ if verbose:
1263
+ plt.show()
1264
1264
  return fig
1265
1265
 
1266
1266
  def print_mask_and_flows(stack, mask, flows, overlay=False):
@@ -1470,3 +1470,99 @@ def plot_object_outlines(src, objects=['nucleus','cell','pathogen'], channels=[0
1470
1470
  overlay=True,
1471
1471
  max_nr=10,
1472
1472
  randomize=True)
1473
+
1474
def volcano_plot(coef_df, filename='volcano_plot.pdf'):
    """Draw, save and display a volcano plot of regression coefficients.

    Parameters
    ----------
    coef_df : pandas.DataFrame
        Must provide the columns ``'coefficient'`` (x axis),
        ``'-log10(p_value)'`` (y axis) and ``'highlight'`` (boolean; ``True``
        points are drawn red, ``False`` points blue).
    filename : str, bytes, os.PathLike or file-like, optional
        Where to save the figure. For paths, the output format is taken from the
        file extension when matplotlib supports it. Otherwise, including
        file-like targets, the figure is written as PDF.

    Returns
    -------
    matplotlib.figure.Figure
        The drawn figure (returned in addition to being saved and shown).
    """
    import os

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.scatterplot(
        data=coef_df,
        x='coefficient',
        y='-log10(p_value)',
        hue='highlight',
        palette={True: 'red', False: 'blue'},
        legend=False,  # colour alone encodes 'highlight'; no legend is wanted
        ax=ax,
    )
    ax.set_title('Volcano Plot of Coefficients')
    ax.set_xlabel('Coefficient')
    ax.set_ylabel('-log10(p-value)')
    # Dashed reference line at p = 0.05 on the -log10 scale.
    ax.axhline(y=-np.log10(0.05), color='red', linestyle='--')

    # Choose the output format from the file extension. Forcing format='pdf'
    # wrote PDF bytes into files such as 'plot.png'. Unrecognised or missing
    # extensions keep the previous PDF output.
    fmt = 'pdf'
    if isinstance(filename, (str, bytes, os.PathLike)):
        extension = os.path.splitext(os.fsdecode(filename))[1].lstrip('.').lower()
        if extension in fig.canvas.get_supported_filetypes():
            fmt = extension
    fig.savefig(filename, format=fmt)
    print(f'Saved Volcano plot: {filename}')
    plt.show()
    return fig
1492
+
1493
def plot_histogram(df, dependent_variable):
    """Display a histogram, with a KDE overlay, of a single column of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data.
    dependent_variable : str
        Column to plot. It is also used in the title and as the x-axis label.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    plt.figure(figsize=(10, 6))
    axis = sns.histplot(df[dependent_variable], kde=True)
    # Override seaborn's default labels ('Count' on y) with the original text.
    axis.set_title(f'Histogram of {dependent_variable}')
    axis.set_xlabel(dependent_variable)
    axis.set_ylabel('Frequency')
    plt.show()
1501
+
1502
def plot_lorenz_curves(csv_files, remove_keys=('TGGT1_220950_1', 'TGGT1_233460_4')):
    """Plot one Lorenz curve per CSV file plus a pooled curve, then save and show.

    Each CSV is read with ``pd.read_csv`` and must provide ``key`` and ``value``
    columns. Rows whose ``key`` appears in ``remove_keys`` are dropped first.

    Curves drawn:
    - one per file, labelled with the first three characters of its file name;
    - one over all files pooled together, as a dashed black line.

    The figure is written to ``lorenz_curve.pdf`` in the directory of the first
    CSV file and then displayed with ``plt.show()``.

    Parameters
    ----------
    csv_files : iterable of str
        Paths to the input CSV files.
    remove_keys : iterable of str or None, optional
        ``key`` values to exclude; ``None`` excludes nothing. The default is a
        tuple to avoid a shared mutable default argument. Lists are still
        accepted.

    Raises
    ------
    ValueError
        If ``csv_files`` is empty, or no values remain in any file after
        filtering.
    """

    def lorenz_curve(data):
        """Return the Lorenz curve of ``data`` with a leading 0 (length ``len(data) + 1``)."""
        cumulative = np.cumsum(np.sort(data))
        return np.insert(cumulative / cumulative[-1], 0, 0)

    csv_files = list(csv_files)
    if not csv_files:
        raise ValueError('csv_files must contain at least one CSV file path')
    excluded = [] if remove_keys is None else list(remove_keys)

    # Save next to the first input file. Previously save_path was only assigned
    # on the *second* iteration, so a single-file call failed at savefig time.
    save_path = os.path.join(os.path.dirname(csv_files[0]), 'lorenz_curve.pdf')

    per_file_values = []
    plt.figure(figsize=(10, 6))

    for csv_file in csv_files:
        df = pd.read_csv(csv_file)
        values = df.loc[~df['key'].isin(excluded), 'value'].to_numpy()
        if values.size == 0:
            # An empty array has no cumulative total to normalise by.
            print(f'Skipping {csv_file}: no values left after filtering')
            continue
        per_file_values.append(values)

        lorenz = lorenz_curve(values)
        plt.plot(np.linspace(0, 1, len(lorenz)), lorenz,
                 label=os.path.basename(csv_file)[:3])

    if not per_file_values:
        raise ValueError('No values left in any CSV file after removing remove_keys')

    # Pool every file's values into one distribution.
    combined_lorenz = lorenz_curve(np.concatenate(per_file_values))
    plt.plot(np.linspace(0, 1, len(combined_lorenz)), combined_lorenz,
             label="Combined Lorenz Curve", linestyle='--', color='black')

    plt.title('Lorenz Curves')
    plt.xlabel('Cumulative Share of Individuals')
    plt.ylabel('Cumulative Share of Value')
    plt.legend()
    plt.grid(False)
    plt.savefig(save_path)
    plt.show()
1543
+
1544
def plot_permutation(permutation_df):
    """Build a horizontal bar chart of permutation importances with std error bars.

    Parameters
    ----------
    permutation_df : pandas.DataFrame
        Needs the columns ``'feature'`` (bar labels), ``'importance_mean'``
        (bar lengths) and ``'importance_std'`` (horizontal error bars).

    Returns
    -------
    matplotlib.figure.Figure
        The figure; it is not shown or saved here.
    """
    n_rows = len(permutation_df)
    # Fixed 10-inch width; the height grows 0.3 in per feature, at least 8 in.
    figure_size = (10, max(8, n_rows * 0.3))
    # 12 pt with no features, shrinking 0.2 pt per feature, never below 10 pt.
    label_size = max(10, 12 - n_rows * 0.2)

    fig, ax = plt.subplots(figsize=figure_size)
    ax.barh(
        permutation_df['feature'],
        permutation_df['importance_mean'],
        xerr=permutation_df['importance_std'],
        color="teal",
        align="center",
        alpha=0.6,
    )
    ax.set_xlabel('Permutation Importance', fontsize=label_size)
    ax.tick_params(axis='both', which='major', labelsize=label_size)
    fig.tight_layout()
    return fig
1556
+
1557
def plot_feature_importance(feature_importance_df):
    """Build a horizontal bar chart of per-feature importance scores.

    Parameters
    ----------
    feature_importance_df : pandas.DataFrame
        Needs the columns ``'feature'`` (bar labels) and ``'importance'``
        (bar lengths).

    Returns
    -------
    matplotlib.figure.Figure
        The figure; it is not shown or saved here.
    """
    n_rows = len(feature_importance_df)
    # Fixed 10-inch width; the height grows 0.3 in per feature, at least 8 in.
    figure_size = (10, max(8, n_rows * 0.3))
    # 12 pt with no features, shrinking 0.2 pt per feature, never below 10 pt.
    label_size = max(10, 12 - n_rows * 0.2)

    fig, ax = plt.subplots(figsize=figure_size)
    ax.barh(
        feature_importance_df['feature'],
        feature_importance_df['importance'],
        color="blue",
        align="center",
        alpha=0.6,
    )
    ax.set_xlabel('Feature Importance', fontsize=label_size)
    ax.tick_params(axis='both', which='major', labelsize=label_size)
    fig.tight_layout()
    return fig