spacr 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. spacr/__init__.py +19 -3
  2. spacr/cellpose.py +311 -0
  3. spacr/core.py +140 -2493
  4. spacr/deep_spacr.py +151 -29
  5. spacr/gui.py +1 -0
  6. spacr/gui_core.py +74 -63
  7. spacr/gui_elements.py +110 -5
  8. spacr/gui_utils.py +346 -6
  9. spacr/io.py +624 -44
  10. spacr/logger.py +28 -9
  11. spacr/measure.py +107 -95
  12. spacr/mediar.py +0 -3
  13. spacr/ml.py +964 -0
  14. spacr/openai.py +37 -0
  15. spacr/plot.py +280 -15
  16. spacr/resources/data/lopit.csv +3833 -0
  17. spacr/resources/data/toxoplasma_metadata.csv +8843 -0
  18. spacr/resources/icons/convert.png +0 -0
  19. spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
  20. spacr/sequencing.py +241 -1311
  21. spacr/settings.py +129 -43
  22. spacr/sim.py +0 -2
  23. spacr/submodules.py +348 -0
  24. spacr/timelapse.py +0 -2
  25. spacr/toxo.py +233 -0
  26. spacr/utils.py +271 -171
  27. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/METADATA +7 -1
  28. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/RECORD +32 -33
  29. spacr/chris.py +0 -50
  30. spacr/graph_learning.py +0 -340
  31. spacr/resources/MEDIAR/.git +0 -1
  32. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  33. spacr/resources/icons/.DS_Store +0 -0
  34. spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  35. spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
  36. spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
  37. spacr/sim_app.py +0 -0
  38. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/LICENSE +0 -0
  39. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/WHEEL +0 -0
  40. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/entry_points.txt +0 -0
  41. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- spacr/__init__.py,sha256=bKVlCCJatRBMqjfxSXds2D_Swjf30T6agvAQ0Usz80o,1176
1
+ spacr/__init__.py,sha256=3TNo4PgxHZTHOhyPc8AORvG3tzdPFEc30KAtsOou174,1618
2
2
  spacr/__main__.py,sha256=bkAJJD2kjIqOP-u1kLvct9jQQCeUXzlEjdgitwi1Lm8,75
3
3
  spacr/app_annotate.py,sha256=nEIL7Fle9CDKGo3sucG_03DgjUQt5W1M1IHBIpVBr08,2171
4
4
  spacr/app_classify.py,sha256=urTP_wlZ58hSyM5a19slYlBxN0PdC-9-ga0hvq8CGWc,165
@@ -7,27 +7,28 @@ spacr/app_mask.py,sha256=l-dBY8ftzCMdDe6-pXc2Nh_u-idNL9G7UOARiLJBtds,153
7
7
  spacr/app_measure.py,sha256=_K7APYIeOKpV6e_LcqabBjvEi7mfq9Fch8175x1x0k8,162
8
8
  spacr/app_sequencing.py,sha256=DjG26jy4cpddnV8WOOAIiExtOe9MleVMY4MFa5uTo5w,157
9
9
  spacr/app_umap.py,sha256=ZWAmf_OsIKbYvolYuWPMYhdlVe-n2CADoJulAizMiEo,153
10
- spacr/chris.py,sha256=YlBjSgeZaY8HPy6jkrT_ISAnCMAKVfvCxF0I9eAZLFM,2418
11
- spacr/core.py,sha256=uyZdJ93ysd8oXgRX9b-6iTCldJ0CM1dq5VxF-xbZyN8,150703
12
- spacr/deep_spacr.py,sha256=a2YewgkQvLV-95NYJAutnojvJmX4S8z_wv6Tb-XIgUI,34484
13
- spacr/graph_learning.py,sha256=1tR-ZxvXE3dBz1Saw7BeVFcrsUFu9OlUZeZVifih9eo,13070
14
- spacr/gui.py,sha256=zUkIyAuOwwoMDoExxtI-QHRfOhE1R2rulXJDNxwSLGc,7947
15
- spacr/gui_core.py,sha256=ZUIqvK7x6NzgrmuTRbvwCTTSpU3yWUaId6MZjXv16us,40128
16
- spacr/gui_elements.py,sha256=OA514FUVRKAcdu9CFVOt7UEzn1vztakQ-rDyKqV0b9A,129771
17
- spacr/gui_utils.py,sha256=DCI--DNoYDWY1q0Aohd0XwFqjdPM3K5kCgRKiJGTnfc,30697
18
- spacr/io.py,sha256=VjH_1zXmf0yEdtABnsoabCEIpU0S3wB-7Hog7_ntCdE,117267
19
- spacr/logger.py,sha256=7Zqr3TuuOQLWT32gYr2q1qvv7x0a2JhLANmZcnBXAW8,670
20
- spacr/measure.py,sha256=ooMOP2OE0BHUNqIkg0ltwV2FiO6hZDIcRC6A0YmGcws,54875
21
- spacr/mediar.py,sha256=5HaCyZYiOff74PCvHwKj-jSRua0QRoIv1mvElPfVKtY,14830
22
- spacr/plot.py,sha256=yFC5m54lB8xVnC7vQp50-FvRn6lCjCt2mgi2GeiRwSs,73979
23
- spacr/sequencing.py,sha256=92KmjFa8Ptwmpf-GtyH3-uX6djFOYR5lJjMBHeciqhs,66921
24
- spacr/settings.py,sha256=PfIPLyMyBAfOodtdgNT8QzbysNDxTnsONXdI-fKtIDQ,68038
25
- spacr/sim.py,sha256=FveaVgBi3eypO2oVB5Dx-v0CC1Ny7UPfXkJiiRRodAk,71212
26
- spacr/sim_app.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
- spacr/timelapse.py,sha256=KMYCgHzf9LTZe-lWl5mvH2EjbKRE6OhpwdY13wEumGc,39504
28
- spacr/utils.py,sha256=JGjL_Tg5ec1qmaf9BV3gqlREUpxjK8hutgPEGj5mAEs,189141
10
+ spacr/cellpose.py,sha256=zv4BzhaP2O-mtQ-pUfYvpOyxgn1ke_bDWgdHD5UWm9I,13942
11
+ spacr/core.py,sha256=KNsjpQR_L5mAABrJ4U_JMNsUMav81ILvnjL0Jf2ohU8,43858
12
+ spacr/deep_spacr.py,sha256=HT7x_xHoM8s-AX6aSEHv4mf4wMld9PNDTgzsX-0FqO4,39416
13
+ spacr/gui.py,sha256=ndmWP4F0QO5j6DM6MNzoGtzv_7Yj4LTW2SLi9URBZIQ,8055
14
+ spacr/gui_core.py,sha256=OJQxzpehIyDzjSjIsvxSHat4NIjkqjX0VZAUQTnzEzg,40921
15
+ spacr/gui_elements.py,sha256=3ru8FPZtXCZSj7167GJj18-Zo6TVebhAzkit-mmqmTI,135342
16
+ spacr/gui_utils.py,sha256=76utRICvY0k_6X8CA1P_TmYBJARp4b87OkI9t39tldA,45822
17
+ spacr/io.py,sha256=_P0Rb1ftBlheb9Yd0Bm8py_MeV2bh0ZxDlVRYOZAfBY,144579
18
+ spacr/logger.py,sha256=lJhTqt-_wfAunCPl93xE65Wr9Y1oIHJWaZMjunHUeIw,1538
19
+ spacr/measure.py,sha256=8MRjQdB-2n8JVLjEpF3cxvfT-Udug27uJ2ErJJ5t1ic,56000
20
+ spacr/mediar.py,sha256=FwLvbLQW5LQzPgvJZG8Lw7GniA2vbZx6Jv6vIKu7I5c,14743
21
+ spacr/ml.py,sha256=uf43GIxxgSasq9OiWQYnQO0fV3d5yTPEHkV78jHb-i4,42540
22
+ spacr/openai.py,sha256=5vBZ3Jl2llYcW3oaTEXgdyCB2aJujMUIO5K038z7w_A,1246
23
+ spacr/plot.py,sha256=kPVUlMaRSH2zZIm64o8FmhHILMCRRmzs-uFVPQEIupw,85281
24
+ spacr/sequencing.py,sha256=t18mgpK6rhWuB1LtFOsPxqgpFXxuUmrD06ecsaVQ0Gw,19655
25
+ spacr/settings.py,sha256=IaRnHXRVrkOOfA0ymavXOJTjPE56_BdOPrbEHq7oONQ,73508
26
+ spacr/sim.py,sha256=1xKhXimNU3ukzIw-3l9cF3Znc_brW8h20yv8fSTzvss,71173
27
+ spacr/submodules.py,sha256=ojBFEnOZ_YTGcOvSFEAPP6J7-w08QIAFJLEPYjUCkMM,18337
28
+ spacr/timelapse.py,sha256=FSYpUtAVy6xc3lwprRYgyDTT9ysUhfRQ4zrP9_h2mvg,39465
29
+ spacr/toxo.py,sha256=nkV8w1wvaEsDFoSIOVvU5UcIIaTngHfgKNhhZYzBDyY,9893
30
+ spacr/utils.py,sha256=nW4s7wy4spk3T3yYTQGLvonDz5KX6t_d0NKiHcqMPN0,193044
29
31
  spacr/version.py,sha256=axH5tnGwtgSnJHb5IDhiu4Zjk5GhLyAEDRe-rnaoFOA,409
30
- spacr/resources/MEDIAR/.git,sha256=nHbNNUgehWnXyS2LbJZitX4kbpd1urzYgE0WZYvdMfc,53
31
32
  spacr/resources/MEDIAR/.gitignore,sha256=Ff1q9Nme14JUd-4Q3jZ65aeQ5X4uttptssVDgBVHYo8,152
32
33
  spacr/resources/MEDIAR/LICENSE,sha256=yEj_TRDLUfDpHDNM0StALXIt6mLqSgaV2hcCwa6_TcY,1065
33
34
  spacr/resources/MEDIAR/README.md,sha256=TlL2XhmmNhYTtaBlMCnlJRW-K7qOVeqH1ABLabZAe2k,11877
@@ -79,7 +80,8 @@ spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py,sha256=SalCNvPy
79
80
  spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl,sha256=C1D7NkUZ5er7Kdeyhhwjo0IGUvCsVfKPBzcwfaORd8Q,3762
80
81
  spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py,sha256=UN8BYjraTNNdZUAGjl3yF566ERHAHQvj3GAQ6OETUOI,3615
81
82
  spacr/resources/MEDIAR/train_tools/models/__init__.py,sha256=CkY6rZxr-c9XxXNpQbYUYvHXDpf9E6rUmY1bQ47aEP8,28
82
- spacr/resources/MEDIAR_weights/.DS_Store,sha256=1lFlJ5EFymdzGAUAaI30vcaaLHt3F1LwpG7xILf9jsM,6148
83
+ spacr/resources/data/lopit.csv,sha256=ERI5f9W8RdJGiSx_khoaylD374f8kmvLia1xjhD_mII,4421709
84
+ spacr/resources/data/toxoplasma_metadata.csv,sha256=9TXx0VlClDHAxQmaLhoklE8NuETduXaGHZjhR_6lZfs,2969409
83
85
  spacr/resources/font/open_sans/OFL.txt,sha256=bGMoWBRrE2RcdzDiuYiB8A9OVFlJ0sA2imWwce2DAdo,4484
84
86
  "spacr/resources/font/open_sans/OpenSans-Italic-VariableFont_wdth,wght.ttf",sha256=QSoWv9h46CRX_fdlqFM3O2d3-PF3R1srnb4zUezcLm0,580280
85
87
  "spacr/resources/font/open_sans/OpenSans-VariableFont_wdth,wght.ttf",sha256=E3RLvAefD0kuT7OxShXSQrjZYA-qzUI9WM35N_6nzms,529700
@@ -120,13 +122,14 @@ spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-MediumItalic.ttf,sh
120
122
  spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Regular.ttf,sha256=skg4DCl15zL9ZD4MAL9fOt4WjonKYBUOMj46ItSAe5Q,130848
121
123
  spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-SemiBold.ttf,sha256=uCiR97jg6sUHtGKVPNtJEg1zZG5Y9ArQ-raqBGjaeGg,130856
122
124
  spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-SemiBoldItalic.ttf,sha256=a5-0oOIrtJltQRa64uFKCdtcjzPvEJ71f_cYavG2i3E,137132
123
- spacr/resources/icons/.DS_Store,sha256=1lFlJ5EFymdzGAUAaI30vcaaLHt3F1LwpG7xILf9jsM,6148
124
125
  spacr/resources/icons/abort.png,sha256=avtIRT7aCJsdZ1WnY_rZStm6cCji5bYPLnlptdcTNcM,6583
125
126
  spacr/resources/icons/annotate.png,sha256=GFgh7DiUMwPG_-xE6W1qU8V_qzSwBi1xKenfoaQxeFA,15495
126
127
  spacr/resources/icons/cellpose_all.png,sha256=HVWOIOBF8p3-On-2UahwMyQXp7awsoC5yWExU1ahDag,20271
127
128
  spacr/resources/icons/cellpose_masks.png,sha256=HVWOIOBF8p3-On-2UahwMyQXp7awsoC5yWExU1ahDag,20271
128
129
  spacr/resources/icons/classify.png,sha256=-iv4sqAwUVJO3CG6fHKHf3_BB0s-I2i4prg-iR7dSBM,35897
130
+ spacr/resources/icons/convert.png,sha256=vLyTkQeUZ9q-pirhtZeXDq3-DzfjoPMjLlgKl5Wv6R0,7069
129
131
  spacr/resources/icons/default.png,sha256=KoNhaSHukO4wDyivyYEgSbb5mGj-sAxmhKikLLtNpWs,20341
132
+ spacr/resources/icons/dna_matrix.mp4,sha256=NegOQkn4q4kHhFgqcIX2dd58wVytBtnkmbgg0ZegL8U,23462876
130
133
  spacr/resources/icons/download.png,sha256=1nUoWRaTc4vIsK6gompdeqk0cIv2GdH-gCNHaEBX6Mc,20467
131
134
  spacr/resources/icons/logo.pdf,sha256=VB4cS41V3VV_QxD7l6CwdQKQiYLErugLBxWoCoxjQU0,377925
132
135
  spacr/resources/icons/logo_spacr.png,sha256=qG3e3bdrAefhl1281rfo0R2XP0qA-c-oaBCXjxMGXkw,42587
@@ -142,18 +145,14 @@ spacr/resources/icons/regression.png,sha256=WIrKY4fSojBOCDkHno4Qb-KH7jcHh6G67dOK
142
145
  spacr/resources/icons/run.png,sha256=ICzyAvsRBCXNAbdn5N3PxCxxVyqxkfC4zOI5Zc8vbxQ,8974
143
146
  spacr/resources/icons/sequencing.png,sha256=P9E_Y76ZysWMKst3_hAw-_4F510XPW1l1TsDElVzt4o,17775
144
147
  spacr/resources/icons/settings.png,sha256=y5Ow5BxJDDsrqom0VNbOMDGGUs6odxbSMDy6y4r_F0w,22269
145
- spacr/resources/icons/spacr_logo_rotation.gif,sha256=bgIx1Hx41Ob90SY-q3PBa3CSxtVRnF9XX-ApUSr0wvY,1502560
146
148
  spacr/resources/icons/train_cellpose.png,sha256=_PZ_R_B6azuUACmscScAkugmgLZvCPKQFGIAsszqNLk,3858
147
149
  spacr/resources/icons/umap.png,sha256=dOLF3DeLYy9k0nkUybiZMe1wzHQwLJFRmgccppw-8bI,27457
148
150
  spacr/resources/images/plate1_E01_T0001F001L01A01Z01C02.tif,sha256=Tl0ZUfZ_AYAbu0up_nO0tPRtF1BxXhWQ3T3pURBCCRo,7958528
149
151
  spacr/resources/images/plate1_E01_T0001F001L01A02Z01C01.tif,sha256=m8N-V71rA1TT4dFlENNg8s0Q0YEXXs8slIn7yObmZJQ,7958528
150
152
  spacr/resources/images/plate1_E01_T0001F001L01A03Z01C03.tif,sha256=Pbhk7xn-KUP6RSIhJsxQcrHFImBm3GEpLkzx7WOc-5M,7958528
151
- spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model,sha256=z8BbHWZPRnE9D_BHO0fBREE85c1vkltDs-incs2ytXQ,26566572
152
- spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv,sha256=fBAGuL_B8ERVdVizO3BHozTDSbZUh1yFzsYK3wkQN68,420
153
- spacr/resources/models/cp/toxo_pv_lumen.CP_model,sha256=2y_CindYhmTvVwBH39SNILF3rI3x9SsRn6qrMxHy3l0,26562451
154
- spacr-0.3.1.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
155
- spacr-0.3.1.dist-info/METADATA,sha256=Fi3B8Vxgz4IzfYodV9zMgRdFLXTtVtel4OCdE06o5cI,5646
156
- spacr-0.3.1.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
157
- spacr-0.3.1.dist-info/entry_points.txt,sha256=BMC0ql9aNNpv8lUZ8sgDLQMsqaVnX5L535gEhKUP5ho,296
158
- spacr-0.3.1.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
159
- spacr-0.3.1.dist-info/RECORD,,
153
+ spacr-0.3.2.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
154
+ spacr-0.3.2.dist-info/METADATA,sha256=T4leZtOLZyUy19ueD2g_TmAQdTRCbkS29QJXcVf6ciw,5837
155
+ spacr-0.3.2.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
156
+ spacr-0.3.2.dist-info/entry_points.txt,sha256=BMC0ql9aNNpv8lUZ8sgDLQMsqaVnX5L535gEhKUP5ho,296
157
+ spacr-0.3.2.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
158
+ spacr-0.3.2.dist-info/RECORD,,
spacr/chris.py DELETED
@@ -1,50 +0,0 @@
1
- import pandas as pd
2
- import numpy as np
3
- from .core import _permutation_importance, _shap_analysis
4
-
5
- def join_measurments_and_annotation(src, tables = ['cell', 'nucleus', 'pathogen','cytoplasm']):
6
-
7
- from .io import _read_and_merge_data, _read_db
8
-
9
- db_loc = [src+'/measurements/measurements.db']
10
- loc = src+'/measurements/measurements.db'
11
- df, _ = _read_and_merge_data(db_loc,
12
- tables,
13
- verbose=True,
14
- include_multinucleated=True,
15
- include_multiinfected=True,
16
- include_noninfected=True)
17
-
18
- paths_df = _read_db(loc, tables=['png_list'])
19
-
20
- merged_df = pd.merge(df, paths_df[0], on='prcfo', how='left')
21
-
22
- return merged_df
23
-
24
- def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
25
- from .io import _read_and_merge_data
26
- from .plot import _plot_plates
27
-
28
- db_loc = [src+'/measurements/measurements.db']
29
- tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
30
- include_multinucleated, include_multiinfected, include_noninfected = True, 2.0, True
31
-
32
- df = join_measurments_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
33
-
34
- if not channel_of_interest is None:
35
- df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
36
- feature_string = f'channel_{channel_of_interest}'
37
- else:
38
- feature_string = None
39
-
40
- output = _permutation_importance(df, feature_string, col_to_compare, pos, neg, exclude, n_repeats, clean, nr_to_plot, n_estimators=n_estimators, random_state=42, model_type=model_type, n_jobs=n_jobs)
41
-
42
- _shap_analysis(output[3], output[4], output[5])
43
-
44
- features = output[0].select_dtypes(include=[np.number]).columns.tolist()
45
-
46
- if not variable in features:
47
- raise ValueError(f"Variable {variable} not found in the dataframe. Please choose one of the following: {features}")
48
-
49
- plate_heatmap = _plot_plates(output[0], variable, grouping, min_max, cmap, min_count)
50
- return [output, plate_heatmap]
spacr/graph_learning.py DELETED
@@ -1,340 +0,0 @@
1
- import os
2
- os.environ['DGLBACKEND'] = 'pytorch'
3
- import torch, dgl
4
- import pandas as pd
5
- import torch.nn as nn
6
- from torchvision import datasets, transforms
7
- from sklearn.preprocessing import StandardScaler
8
- from PIL import Image
9
- import dgl.nn.pytorch as dglnn
10
- from sklearn.datasets import make_classification
11
- from .utils import SelectChannels
12
- from IPython.display import display
13
-
14
- # approach outline
15
- #
16
- # 1. Data Preparation:
17
- # Test Mode: Load MNIST data and generate synthetic gRNA data.
18
- # Real Data: Load image paths and sequencing data as fractions.
19
- #
20
- # 2. Graph Construction:
21
- # Each well is represented as a graph.
22
- # Each graph has cell nodes (with image features) and gRNA nodes (with gRNA fraction features).
23
- # Each cell node is connected to each gRNA node within the same well.
24
- #
25
- # 3. Model Training:
26
- # Use an encoder-decoder architecture with the Graph Transformer model.
27
- # The encoder processes the cell and gRNA nodes.
28
- # The decoder outputs the phenotype score for each cell node.
29
- # The model is trained on all wells (including positive and negative controls).
30
- # The model learns to score the gRNA in column 1 (negative control) as 0 and the gRNA in column 2 (positive control) as 1 based on the cell features.
31
- #
32
- # 4. Model Application:
33
- # Apply the trained model to all wells to get classification probabilities.
34
- #
35
- # 5. Evaluation:
36
- # Evaluate the model's performance using the control wells.
37
- #
38
- # 6. Association Analysis:
39
- # Analyze the association between gRNAs and the classification scores.
40
- #
41
- # The model learns the associations between cell features and phenotype scores based on the controls and then generalizes this learning to the screening wells.
42
-
43
- # Load MNIST data for testing
44
- def load_mnist_data():
45
- transform = transforms.Compose([
46
- transforms.Resize((28, 28)),
47
- transforms.ToTensor(),
48
- transforms.Normalize((0.1307,), (0.3081,))
49
- ])
50
- mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
51
- mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
52
- return mnist_train, mnist_test
53
-
54
- # Generate synthetic gRNA data
55
- def generate_synthetic_grna_data(n_samples, n_features):
56
- X, y = make_classification(n_samples=n_samples, n_features=n_features, n_informative=5, n_redundant=0, n_classes=2, random_state=42)
57
- synthetic_data = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(n_features)])
58
- synthetic_data['label'] = y
59
- return synthetic_data
60
-
61
- # Preprocess image
62
- def preprocess_image(image_path, image_size=224, channels=[1,2,3], normalize=True):
63
-
64
- if normalize:
65
- preprocess = transforms.Compose([
66
- transforms.ToTensor(),
67
- transforms.CenterCrop(size=(image_size, image_size)),
68
- SelectChannels(channels),
69
- transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
70
- else:
71
- preprocess = transforms.Compose([
72
- transforms.ToTensor(),
73
- transforms.CenterCrop(size=(image_size, image_size)),
74
- SelectChannels(channels)])
75
-
76
- image = Image.open(image_path).convert('RGB')
77
- return preprocess(image)
78
-
79
- def extract_metadata_from_path(path):
80
- """
81
- Extract metadata from the image path.
82
- The path format is expected to be plate_well_field_objectnumber.png
83
-
84
- Parameters:
85
- path (str): The path to the image file.
86
-
87
- Returns:
88
- dict: A dictionary with the extracted metadata.
89
- """
90
- filename = os.path.basename(path)
91
- name, ext = os.path.splitext(filename)
92
-
93
- # Ensure the file has the correct extension
94
- if ext.lower() != '.png':
95
- raise ValueError("Expected a .png file")
96
-
97
- # Split the name by underscores
98
- parts = name.split('_')
99
- if len(parts) != 4:
100
- raise ValueError("Expected filename format: plate_well_field_objectnumber.png")
101
-
102
- plate, well, field, object_number = parts
103
-
104
- return {'plate': plate, 'well': well,'field': field, 'object_number': object_number}
105
-
106
- # Load images
107
- def load_images(image_paths, image_size=224, channels=[1,2,3], normalize=True):
108
- images = []
109
- metadata_list = []
110
- for path in image_paths:
111
- image = preprocess_image(path, image_size, channels, normalize)
112
- images.append(image)
113
- metadata = extract_metadata_from_path(path) # Extract metadata from image path or database
114
- metadata_list.append(metadata)
115
- return torch.stack(images), metadata_list
116
-
117
- # Normalize sequencing data
118
- def normalize_sequencing_data(sequencing_data):
119
- scaler = StandardScaler()
120
- sequencing_data.iloc[:, 2:] = scaler.fit_transform(sequencing_data.iloc[:, 2:])
121
- return sequencing_data
122
-
123
- # Construct graph for each well
124
- def construct_well_graph(images, image_metadata, grna_data):
125
- cell_nodes = len(images)
126
- grna_nodes = grna_data.shape[0]
127
-
128
- graph = dgl.DGLGraph()
129
- graph.add_nodes(cell_nodes + grna_nodes)
130
-
131
- cell_features = torch.stack(images)
132
- grna_features = torch.tensor(grna_data).float()
133
-
134
- features = torch.cat([cell_features, grna_features], dim=0)
135
- graph.ndata['features'] = features
136
-
137
- for i in range(cell_nodes):
138
- for j in range(cell_nodes, cell_nodes + grna_nodes):
139
- graph.add_edge(i, j)
140
- graph.add_edge(j, i)
141
-
142
- return graph
143
-
144
- def create_graphs_for_wells(images, metadata_list, sequencing_data):
145
- graphs = []
146
- labels = []
147
-
148
- for well in sequencing_data['well'].unique():
149
- well_images = [img for img, meta in zip(images, metadata_list) if meta['well'] == well]
150
- well_metadata = [meta for meta in metadata_list if meta['well'] == well]
151
- well_grna_data = sequencing_data[sequencing_data['well'] == well].iloc[:, 2:].values
152
-
153
- graph = construct_well_graph(well_images, well_metadata, well_grna_data)
154
- graphs.append(graph)
155
-
156
- if well_metadata[0]['column'] == 1: # Negative control
157
- labels.append(0)
158
- elif well_metadata[0]['column'] == 2: # Positive control
159
- labels.append(1)
160
- else:
161
- labels.append(-1) # Screen wells, will be used for evaluation
162
-
163
- return graphs, labels
164
-
165
- # Define Encoder-Decoder Transformer Model
166
- class Encoder(nn.Module):
167
- def __init__(self, in_feats, hidden_feats):
168
- super(Encoder, self).__init__()
169
- self.conv1 = dglnn.GraphConv(in_feats, hidden_feats)
170
- self.conv2 = dglnn.GraphConv(hidden_feats, hidden_feats)
171
-
172
- def forward(self, g, features):
173
- x = self.conv1(g, features)
174
- x = torch.relu(x)
175
- x = self.conv2(g, x)
176
- x = torch.relu(x)
177
- return x
178
-
179
- class Decoder(nn.Module):
180
- def __init__(self, hidden_feats, out_feats):
181
- super(Decoder, self).__init__()
182
- self.linear = nn.Linear(hidden_feats, out_feats)
183
-
184
- def forward(self, x):
185
- return self.linear(x)
186
-
187
- class GraphTransformer(nn.Module):
188
- def __init__(self, in_feats, hidden_feats, out_feats):
189
- super(GraphTransformer, self).__init__()
190
- self.encoder = Encoder(in_feats, hidden_feats)
191
- self.decoder = Decoder(hidden_feats, out_feats)
192
-
193
- def forward(self, g, features):
194
- x = self.encoder(g, features)
195
- with g.local_scope():
196
- g.ndata['h'] = x
197
- hg = dgl.mean_nodes(g, 'h')
198
- return self.decoder(hg)
199
-
200
- def train(graphs, labels, model, loss_fn, optimizer, epochs=100):
201
- for epoch in range(epochs):
202
- model.train()
203
- total_loss = 0
204
- correct = 0
205
- total = 0
206
-
207
- for graph, label in zip(graphs, labels):
208
- if label == -1:
209
- continue # Skip screen wells for training
210
-
211
- features = graph.ndata['features']
212
- logits = model(graph, features)
213
- loss = loss_fn(logits, torch.tensor([label]))
214
-
215
- optimizer.zero_grad()
216
- loss.backward()
217
- optimizer.step()
218
-
219
- total_loss += loss.item()
220
- _, predicted = torch.max(logits, 1)
221
- correct += (predicted == label).sum().item()
222
- total += 1
223
-
224
- accuracy = correct / total if total > 0 else 0
225
- print(f'Epoch {epoch}, Loss: {total_loss / total:.4f}, Accuracy: {accuracy * 100:.2f}%')
226
-
227
- def apply_model(graphs, model):
228
- model.eval()
229
- results = []
230
-
231
- with torch.no_grad():
232
- for graph in graphs:
233
- features = graph.ndata['features']
234
- logits = model(graph, features)
235
- probabilities = torch.softmax(logits, dim=1)
236
- results.append(probabilities[:, 1].item())
237
-
238
- return results
239
-
240
- def analyze_associations(probabilities, sequencing_data):
241
- # Analyze associations between gRNAs and classification scores
242
- sequencing_data['positive_prob'] = probabilities
243
- return sequencing_data.groupby('gRNA').positive_prob.mean().sort_values(ascending=False)
244
-
245
- def process_sequencing_df(seq):
246
-
247
- if isinstance(seq, pd.DataFrame):
248
- sequencing_df = seq
249
- elif isinstance(seq, str):
250
- sequencing_df = pd.read_csv(seq)
251
-
252
- # Check if 'plate_row' column exists and split into 'plate' and 'row'
253
- if 'plate_row' in sequencing_df.columns:
254
- sequencing_df[['plate', 'row']] = sequencing_df['plate_row'].str.split('_', expand=True)
255
-
256
- # Check if 'plate', 'row' and 'col' or 'plate', 'row' and 'column' exist
257
- if {'plate', 'row', 'col'}.issubset(sequencing_df.columns) or {'plate', 'row', 'column'}.issubset(sequencing_df.columns):
258
- if 'col' in sequencing_df.columns:
259
- sequencing_df['prc'] = sequencing_df[['plate', 'row', 'col']].agg('_'.join, axis=1)
260
- elif 'column' in sequencing_df.columns:
261
- sequencing_df['prc'] = sequencing_df[['plate', 'row', 'column']].agg('_'.join, axis=1)
262
-
263
- # Check if 'count', 'total_reads', 'read_fraction', 'grna' exist and create new dataframe
264
- if {'count', 'total_reads', 'read_fraction', 'grna'}.issubset(sequencing_df.columns):
265
- new_df = sequencing_df[['grna', 'prc', 'count', 'total_reads', 'read_fraction']]
266
- return new_df
267
-
268
- return sequencing_df
269
-
270
- def train_graph_transformer(src, lr=0.01, epochs=100, hidden_feats=128, n_classes=2, row_limit=None, image_size=224, channels=[1,2,3], normalize=True, test_mode=False):
271
- if test_mode:
272
- # Load MNIST data
273
- mnist_train, mnist_test = load_mnist_data()
274
-
275
- # Generate synthetic gRNA data
276
- synthetic_grna_data = generate_synthetic_grna_data(len(mnist_train), 10) # 10 synthetic features
277
- sequencing_data = synthetic_grna_data
278
-
279
- # Load MNIST images and metadata
280
- images = []
281
- metadata_list = []
282
- for idx, (img, label) in enumerate(mnist_train):
283
- images.append(img)
284
- metadata_list.append({'index': idx, 'plate': 'plate1', 'well': idx, 'column': label})
285
- images = torch.stack(images)
286
-
287
- # Normalize synthetic sequencing data
288
- sequencing_data = normalize_sequencing_data(sequencing_data)
289
- else:
290
- from .io import _read_and_join_tables
291
- from .utils import get_db_paths, get_sequencing_paths, correct_paths
292
-
293
- db_paths = get_db_paths(src)
294
- seq_paths = get_sequencing_paths(src)
295
-
296
- if isinstance(src, str):
297
- src = [src]
298
-
299
- sequencing_data = pd.DataFrame()
300
- for seq in seq_paths:
301
- sequencing_df = pd.read_csv(seq)
302
- sequencing_df = process_sequencing_df(sequencing_df)
303
- sequencing_data = pd.concat([sequencing_data, sequencing_df], axis=0)
304
-
305
- all_df = pd.DataFrame()
306
- image_paths = []
307
- for i, db_path in enumerate(db_paths):
308
- df = _read_and_join_tables(db_path, table_names=['png_list'])
309
- df, image_paths_tmp = correct_paths(df, src[i])
310
- all_df = pd.concat([all_df, df], axis=0)
311
- image_paths.extend(image_paths_tmp)
312
-
313
- if row_limit is not None:
314
- all_df = all_df.sample(n=row_limit, random_state=42)
315
-
316
- images, metadata_list = load_images(image_paths, image_size, channels, normalize)
317
- sequencing_data = normalize_sequencing_data(sequencing_data)
318
-
319
- # Step 1: Create graphs for each well
320
- graphs, labels = create_graphs_for_wells(images, metadata_list, sequencing_data)
321
-
322
- # Step 2: Train Graph Transformer Model
323
- in_feats = graphs[0].ndata['features'].shape[1]
324
- model = GraphTransformer(in_feats, hidden_feats, n_classes)
325
- loss_fn = nn.CrossEntropyLoss()
326
- optimizer = torch.optim.Adam(model.parameters(), lr=lr)
327
-
328
- # Train the model
329
- train(graphs, labels, model, loss_fn, optimizer, epochs)
330
-
331
- # Step 3: Apply the model to all wells (including screen wells)
332
- screen_graphs = [graph for graph, label in zip(graphs, labels) if label == -1]
333
- probabilities = apply_model(screen_graphs, model)
334
-
335
- # Step 4: Analyze associations between gRNAs and classification scores
336
- associations = analyze_associations(probabilities, sequencing_data)
337
- print("Top associated gRNAs with positive control phenotype:")
338
- print(associations.head())
339
-
340
- return model, associations
@@ -1 +0,0 @@
1
- gitdir: ../../../.git/modules/spacr/resources/MEDIAR
Binary file
Binary file
@@ -1,23 +0,0 @@
1
- Key,Value
2
- img_src,/nas_mnt/carruthers/patrick/Plaque_assay_training/train
3
- model_name,toxo_plaque
4
- model_type,cyto
5
- Signal_to_noise,10
6
- background,200
7
- remove_background,False
8
- learning_rate,0.2
9
- weight_decay,1e-05
10
- batch_size,8
11
- n_epochs,25000
12
- from_scratch,False
13
- diameter,30
14
- resize,True
15
- width_height,"[1120, 1120]"
16
- verbose,True
17
- channels,"[0, 0]"
18
- normalize,True
19
- percentiles,
20
- circular,False
21
- invert,False
22
- grayscale,True
23
- test,False
spacr/sim_app.py DELETED
File without changes
File without changes
File without changes