napistu 0.2.5.dev7__py3-none-any.whl → 0.3.1.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108):
  1. napistu/__init__.py +1 -3
  2. napistu/__main__.py +126 -96
  3. napistu/constants.py +35 -41
  4. napistu/context/__init__.py +10 -0
  5. napistu/context/discretize.py +462 -0
  6. napistu/context/filtering.py +387 -0
  7. napistu/gcs/__init__.py +1 -1
  8. napistu/identifiers.py +74 -15
  9. napistu/indices.py +68 -0
  10. napistu/ingestion/__init__.py +1 -1
  11. napistu/ingestion/bigg.py +47 -62
  12. napistu/ingestion/constants.py +18 -133
  13. napistu/ingestion/gtex.py +113 -0
  14. napistu/ingestion/hpa.py +147 -0
  15. napistu/ingestion/sbml.py +0 -97
  16. napistu/ingestion/string.py +2 -2
  17. napistu/matching/__init__.py +10 -0
  18. napistu/matching/constants.py +18 -0
  19. napistu/matching/interactions.py +518 -0
  20. napistu/matching/mount.py +529 -0
  21. napistu/matching/species.py +510 -0
  22. napistu/mcp/__init__.py +7 -4
  23. napistu/mcp/__main__.py +128 -72
  24. napistu/mcp/client.py +16 -25
  25. napistu/mcp/codebase.py +201 -145
  26. napistu/mcp/component_base.py +170 -0
  27. napistu/mcp/config.py +223 -0
  28. napistu/mcp/constants.py +45 -2
  29. napistu/mcp/documentation.py +253 -136
  30. napistu/mcp/documentation_utils.py +13 -48
  31. napistu/mcp/execution.py +372 -305
  32. napistu/mcp/health.py +47 -65
  33. napistu/mcp/profiles.py +10 -6
  34. napistu/mcp/server.py +161 -80
  35. napistu/mcp/tutorials.py +139 -87
  36. napistu/modify/__init__.py +1 -1
  37. napistu/modify/gaps.py +1 -1
  38. napistu/network/__init__.py +1 -1
  39. napistu/network/constants.py +101 -34
  40. napistu/network/data_handling.py +388 -0
  41. napistu/network/ig_utils.py +351 -0
  42. napistu/network/napistu_graph_core.py +354 -0
  43. napistu/network/neighborhoods.py +40 -40
  44. napistu/network/net_create.py +373 -309
  45. napistu/network/net_propagation.py +47 -19
  46. napistu/network/{net_utils.py → ng_utils.py} +124 -272
  47. napistu/network/paths.py +67 -51
  48. napistu/network/precompute.py +11 -11
  49. napistu/ontologies/__init__.py +10 -0
  50. napistu/ontologies/constants.py +129 -0
  51. napistu/ontologies/dogma.py +243 -0
  52. napistu/ontologies/genodexito.py +649 -0
  53. napistu/ontologies/mygene.py +369 -0
  54. napistu/ontologies/renaming.py +198 -0
  55. napistu/rpy2/__init__.py +229 -86
  56. napistu/rpy2/callr.py +47 -77
  57. napistu/rpy2/constants.py +24 -23
  58. napistu/rpy2/rids.py +61 -648
  59. napistu/sbml_dfs_core.py +587 -222
  60. napistu/scverse/__init__.py +15 -0
  61. napistu/scverse/constants.py +28 -0
  62. napistu/scverse/loading.py +727 -0
  63. napistu/utils.py +118 -10
  64. {napistu-0.2.5.dev7.dist-info → napistu-0.3.1.dev1.dist-info}/METADATA +8 -3
  65. napistu-0.3.1.dev1.dist-info/RECORD +133 -0
  66. tests/conftest.py +22 -0
  67. tests/test_context_discretize.py +56 -0
  68. tests/test_context_filtering.py +267 -0
  69. tests/test_identifiers.py +100 -0
  70. tests/test_indices.py +65 -0
  71. tests/{test_edgelist.py → test_ingestion_napistu_edgelist.py} +2 -2
  72. tests/test_matching_interactions.py +108 -0
  73. tests/test_matching_mount.py +305 -0
  74. tests/test_matching_species.py +394 -0
  75. tests/test_mcp_config.py +193 -0
  76. tests/test_mcp_documentation_utils.py +12 -3
  77. tests/test_mcp_server.py +156 -19
  78. tests/test_network_data_handling.py +397 -0
  79. tests/test_network_ig_utils.py +23 -0
  80. tests/test_network_neighborhoods.py +19 -0
  81. tests/test_network_net_create.py +459 -0
  82. tests/test_network_ng_utils.py +30 -0
  83. tests/test_network_paths.py +56 -0
  84. tests/{test_precomputed_distances.py → test_network_precompute.py} +8 -6
  85. tests/test_ontologies_genodexito.py +58 -0
  86. tests/test_ontologies_mygene.py +39 -0
  87. tests/test_ontologies_renaming.py +110 -0
  88. tests/test_rpy2_callr.py +79 -0
  89. tests/test_rpy2_init.py +151 -0
  90. tests/test_sbml.py +0 -31
  91. tests/test_sbml_dfs_core.py +134 -10
  92. tests/test_scverse_loading.py +778 -0
  93. tests/test_set_coverage.py +2 -2
  94. tests/test_utils.py +121 -1
  95. napistu/mechanism_matching.py +0 -1353
  96. napistu/rpy2/netcontextr.py +0 -467
  97. napistu-0.2.5.dev7.dist-info/RECORD +0 -98
  98. tests/test_igraph.py +0 -367
  99. tests/test_mechanism_matching.py +0 -784
  100. tests/test_net_utils.py +0 -149
  101. tests/test_netcontextr.py +0 -105
  102. tests/test_rpy2.py +0 -61
  103. /napistu/ingestion/{cpr_edgelist.py → napistu_edgelist.py} +0 -0
  104. {napistu-0.2.5.dev7.dist-info → napistu-0.3.1.dev1.dist-info}/WHEEL +0 -0
  105. {napistu-0.2.5.dev7.dist-info → napistu-0.3.1.dev1.dist-info}/entry_points.txt +0 -0
  106. {napistu-0.2.5.dev7.dist-info → napistu-0.3.1.dev1.dist-info}/licenses/LICENSE +0 -0
  107. {napistu-0.2.5.dev7.dist-info → napistu-0.3.1.dev1.dist-info}/top_level.txt +0 -0
  108. /tests/{test_obo.py → test_ingestion_obo.py} +0 -0
@@ -0,0 +1,778 @@
1
+ import pytest
2
+
3
+ import anndata
4
+ import mudata
5
+ import numpy as np
6
+ import pandas as pd
7
+ from scipy import sparse
8
+
9
+ from napistu.scverse import loading
10
+ from napistu.scverse.constants import ADATA, SCVERSE_DEFS
11
+
12
+
13
@pytest.fixture
def minimal_adata():
    """Create a small, reproducible AnnData object for testing.

    The object has 10 observations and 5 variables and carries:

    - ``X``: dense float matrix
    - ``obs``: a ``cell_type`` column
    - ``var``: ``gene_name`` and ``ensembl_transcript`` columns
    - ``layers``: ``counts`` (integers) and ``normalized`` (floats)
    - ``varm["gene_scores"]``: an (n_vars x 3) array whose column names are
      kept in ``uns["gene_scores_features"]``
    - ``varp``: a dense symmetric ``correlations`` matrix and a sparse
      symmetric ``adjacency`` matrix

    Returns
    -------
    anndata.AnnData
        The populated AnnData object.
    """
    # A seeded generator keeps the fixture identical across test runs, so a
    # failure can always be reproduced with the same data.
    rng = np.random.default_rng(0)

    n_obs, n_vars = 10, 5
    X = rng.standard_normal((n_obs, n_vars))

    # Observation and variable annotations
    obs = pd.DataFrame(
        {"cell_type": ["type_" + str(i) for i in range(n_obs)]},
        index=["cell_" + str(i) for i in range(n_obs)],
    )
    var = pd.DataFrame(
        {
            "gene_name": ["gene_" + str(i) for i in range(n_vars)],
            "ensembl_transcript": [f"ENST{i:011d}" for i in range(n_vars)],
        },
        index=["gene_" + str(i) for i in range(n_vars)],
    )

    adata = anndata.AnnData(X=X, obs=obs, var=var)

    # Two layers so that tests can exercise explicit table_name selection
    adata.layers["counts"] = rng.integers(0, 100, size=(n_obs, n_vars))
    adata.layers["normalized"] = rng.standard_normal((n_obs, n_vars))

    # varm matrix (n_vars x n_features); varm holds a bare numpy array, so the
    # column names are stored separately in uns
    n_features = 3
    adata.varm["gene_scores"] = rng.standard_normal((n_vars, n_features))
    adata.uns["gene_scores_features"] = ["score1", "score2", "score3"]

    # Dense pairwise matrix (n_vars x n_vars), symmetrised
    correlations = rng.random((n_vars, n_vars))
    correlations = (correlations + correlations.T) / 2
    adata.varp["correlations"] = correlations

    # Sparse pairwise matrix, symmetrised
    adjacency = sparse.random(n_vars, n_vars, density=0.2, random_state=rng)
    adjacency = (adjacency + adjacency.T) / 2
    adata.varp["adjacency"] = adjacency

    return adata
63
+
64
+
65
def test_load_raw_table_success(minimal_adata):
    """_load_raw_table returns a numpy array of matching shape for X and a named layer."""
    # Each case pairs the positional table arguments with the source table
    # whose shape the loaded result must match.
    cases = [
        (("X",), minimal_adata.X),
        (("layers", "counts"), minimal_adata.layers["counts"]),
    ]
    for table_args, source_table in cases:
        loaded = loading._load_raw_table(minimal_adata, *table_args)
        assert isinstance(loaded, np.ndarray)
        assert loaded.shape == source_table.shape
76
+
77
+
78
def test_load_raw_table_errors(minimal_adata):
    """_load_raw_table raises ValueError for an unknown table type or an ambiguous request."""
    unknown_type_msg = "is not a valid AnnData attribute"
    ambiguous_msg = "Multiple tables found.*and table_name is not specified"

    # A table type that is not an AnnData attribute
    with pytest.raises(ValueError, match=unknown_type_msg):
        loading._load_raw_table(minimal_adata, "invalid_type")

    # "layers" holds more than one table, so omitting table_name is an error
    with pytest.raises(ValueError, match=ambiguous_msg):
        loading._load_raw_table(minimal_adata, "layers")
89
+
90
+
91
def test_get_table_from_dict_attr_success(minimal_adata):
    """_get_table_from_dict_attr returns the named varm table as a numpy array."""
    expected_shape = (minimal_adata.n_vars, 3)

    gene_scores = loading._get_table_from_dict_attr(
        minimal_adata, "varm", "gene_scores"
    )

    assert isinstance(gene_scores, np.ndarray)
    assert gene_scores.shape == expected_shape
97
+
98
+
99
def test_get_table_from_dict_attr_errors(minimal_adata):
    """_get_table_from_dict_attr raises ValueError for ambiguous or unknown table names."""
    ambiguous_msg = "Multiple tables found.*and table_name is not specified"
    unknown_name_msg = "table_name 'nonexistent' not found"

    # No table_name although "layers" contains several tables
    with pytest.raises(ValueError, match=ambiguous_msg):
        loading._get_table_from_dict_attr(minimal_adata, "layers")

    # A table_name that is absent from "layers"
    with pytest.raises(ValueError, match=unknown_name_msg):
        loading._get_table_from_dict_attr(minimal_adata, "layers", "nonexistent")
110
+
111
+
112
def test_select_results_attrs_success(minimal_adata):
    """_select_results_attrs yields var-indexed DataFrames for X, varm and varp inputs."""
    var_names = minimal_adata.var.index.tolist()
    n_vars = minimal_adata.n_vars

    # --- X-shaped array (n_obs x n_vars): selected observations become columns
    x_like = np.random.randn(minimal_adata.n_obs, minimal_adata.n_vars)
    selected_obs = minimal_adata.obs.index[:3].tolist()
    x_df = loading._select_results_attrs(minimal_adata, x_like, "X", selected_obs)
    assert isinstance(x_df, pd.DataFrame)
    assert x_df.shape[0] == minimal_adata.var.shape[0]
    assert len(x_df.columns) == len(selected_obs)
    # Orientation check: variables are the rows
    assert list(x_df.index) == var_names

    # --- varm: pick a subset of the named score columns
    score_names = minimal_adata.uns["gene_scores_features"]
    wanted_scores = ["score1", "score2"]
    wanted_positions = [score_names.index(name) for name in wanted_scores]
    gene_scores = minimal_adata.varm["gene_scores"]
    scores_df = loading._select_results_attrs(
        minimal_adata,
        gene_scores,
        "varm",
        wanted_scores,
        table_colnames=score_names,
    )
    assert isinstance(scores_df, pd.DataFrame)
    assert scores_df.shape == (n_vars, 2)  # all vars x selected features
    assert list(scores_df.columns) == wanted_scores
    assert list(scores_df.index) == var_names
    np.testing.assert_array_equal(
        scores_df.values, gene_scores[:, wanted_positions]
    )

    # --- varp: dense first, then sparse, selecting the first two genes
    first_two_genes = minimal_adata.var.index[:2].tolist()
    for varp_key in ("correlations", "adjacency"):
        pairwise_df = loading._select_results_attrs(
            minimal_adata, minimal_adata.varp[varp_key], "varp", first_two_genes
        )
        assert isinstance(pairwise_df, pd.DataFrame)
        assert pairwise_df.shape == (n_vars, 2)  # all vars x selected genes
        assert list(pairwise_df.columns) == first_two_genes
        assert list(pairwise_df.index) == var_names
        if varp_key == "correlations":
            # Values are only compared for the dense matrix
            np.testing.assert_array_equal(
                pairwise_df.values,
                minimal_adata.varp["correlations"][:, :2],  # first two columns
            )

    # --- varm without results_attrs: the whole table comes back
    all_scores_df = loading._select_results_attrs(
        minimal_adata,
        minimal_adata.varm["gene_scores"],
        "varm",
        table_colnames=minimal_adata.uns["gene_scores_features"],
    )
    assert isinstance(all_scores_df, pd.DataFrame)
    assert all_scores_df.shape == (n_vars, 3)
    assert list(all_scores_df.columns) == minimal_adata.uns["gene_scores_features"]
    assert list(all_scores_df.index) == var_names
187
+
188
+
189
def test_select_results_attrs_errors(minimal_adata):
    """_select_results_attrs raises ValueError for bad attributes or bad table inputs."""
    not_found_msg = "The following results attributes were not found"
    gene_scores = minimal_adata.varm["gene_scores"]

    # Unknown observation name against an X-shaped array
    x_like = np.random.randn(minimal_adata.n_obs, minimal_adata.n_vars)
    with pytest.raises(ValueError, match=not_found_msg):
        loading._select_results_attrs(minimal_adata, x_like, "X", ["nonexistent_attr"])

    # Unknown gene name against a varp matrix
    with pytest.raises(ValueError, match=not_found_msg):
        loading._select_results_attrs(
            minimal_adata,
            minimal_adata.varp["correlations"],
            "varp",
            results_attrs=["nonexistent_gene"],
        )

    # varm tables need table_colnames
    with pytest.raises(ValueError, match="table_colnames is required for varm tables"):
        loading._select_results_attrs(minimal_adata, gene_scores, "varm", ["score1"])

    # A DataFrame is rejected where a numpy array is expected
    with pytest.raises(ValueError, match="must be a numpy array, not a DataFrame"):
        loading._select_results_attrs(
            minimal_adata,
            pd.DataFrame(gene_scores),
            "varm",
            ["score1"],
            table_colnames=minimal_adata.uns["gene_scores_features"],
        )
224
+
225
+
226
def test_create_results_df(minimal_adata):
    """_create_results_df builds var-indexed DataFrames for varm, varp, X and layers input."""
    var_index = minimal_adata.var.index

    def build_and_check(array, attrs, table_type, expected_values):
        # Shared assertions: vars as rows, attrs as columns, values as expected
        frame = loading._create_results_df(
            array=array,
            attrs=attrs,
            var_index=var_index,
            table_type=table_type,
        )
        assert frame.shape == (minimal_adata.n_vars, len(attrs))
        pd.testing.assert_index_equal(frame.index, var_index)
        pd.testing.assert_index_equal(frame.columns, pd.Index(attrs))
        np.testing.assert_array_equal(frame.values, expected_values)

    # varm: subset of score columns, values used as-is
    score_attrs = ["score1", "score2"]
    score_names = minimal_adata.uns["gene_scores_features"]
    score_positions = [score_names.index(name) for name in score_attrs]
    score_block = minimal_adata.varm["gene_scores"][:, score_positions]
    build_and_check(score_block, score_attrs, ADATA.VARM, score_block)

    # varp: first two genes of the dense correlations, values used as-is
    gene_attrs = minimal_adata.var.index[:2].tolist()
    corr_block = minimal_adata.varp["correlations"][:, :2]
    build_and_check(corr_block, gene_attrs, ADATA.VARP, corr_block)

    # X: first three observations; the result is the transpose of the input
    x_attrs = minimal_adata.obs.index[:3].tolist()
    x_block = minimal_adata.X[0:3, :]
    build_and_check(x_block, x_attrs, ADATA.X, x_block.T)

    # layers: first two observations; the result is the transpose of the input
    counts_attrs = minimal_adata.obs.index[:2].tolist()
    counts_block = minimal_adata.layers["counts"][0:2, :]
    build_and_check(counts_block, counts_attrs, ADATA.LAYERS, counts_block.T)
288
+
289
+
290
@pytest.fixture
def minimal_mudata(minimal_adata):
    """Create a small, reproducible MuData object for testing.

    ``minimal_adata`` serves as the "rna" modality. A three-variable "protein"
    modality is built alongside it, sharing the same ``obs`` table and carrying
    a ``uniprot`` column in ``var``. A MuData-level ``varm["gene_scores"]``
    table is added, with its column names in ``uns["gene_scores_features"]``.

    Parameters
    ----------
    minimal_adata : anndata.AnnData
        The RNA modality, supplied by the ``minimal_adata`` fixture.

    Returns
    -------
    mudata.MuData
        MuData with "rna" and "protein" modalities.
    """
    # A seeded generator keeps the fixture identical across test runs
    rng = np.random.default_rng(1)

    # Protein modality with minimal features
    n_vars_protein = 3
    adata_protein = anndata.AnnData(
        X=rng.standard_normal((minimal_adata.n_obs, n_vars_protein)),
        obs=minimal_adata.obs,  # share obs to ensure alignment
        var=pd.DataFrame(
            {"uniprot": [f"P{i:05d}" for i in range(n_vars_protein)]},
            index=[f"protein_{i}" for i in range(n_vars_protein)],
        ),
    )

    mdata = mudata.MuData({"rna": minimal_adata, "protein": adata_protein})

    # varm table at MuData level, spanning the variables of both modalities
    n_features = 3
    mdata.varm["gene_scores"] = rng.standard_normal((mdata.n_vars, n_features))
    mdata.uns["gene_scores_features"] = ["score1", "score2", "score3"]

    return mdata
322
+
323
+
324
def test_prepare_anndata_results_df_anndata(minimal_adata):
    """prepare_anndata_results_df handles var and varm tables and index ontologies."""
    n_vars = minimal_adata.n_vars
    score_names = minimal_adata.uns["gene_scores_features"]

    # --- var table
    var_df = loading.prepare_anndata_results_df(minimal_adata, table_type=ADATA.VAR)
    assert isinstance(var_df, pd.DataFrame)
    assert var_df.shape[0] == n_vars
    for column in ("gene_name", "ensembl_transcript"):
        assert column in var_df.columns

    # --- varm table with explicit column names
    varm_df = loading.prepare_anndata_results_df(
        minimal_adata,
        table_type=ADATA.VARM,
        table_name="gene_scores",
        results_attrs=score_names,
        table_colnames=score_names,
    )
    assert isinstance(varm_df, pd.DataFrame)
    assert varm_df.shape[0] == n_vars
    # score1, score2, score3 + gene_name + ensembl_transcript from the var table
    assert varm_df.shape[1] == 5
    # Both the scores and the systematic identifiers must be present
    for score in minimal_adata.uns["gene_scores_features"]:
        assert score in varm_df.columns
    assert "gene_name" in varm_df.columns
    assert "ensembl_transcript" in varm_df.columns

    # --- index exposed under a new ontology name
    ontology_df = loading.prepare_anndata_results_df(
        minimal_adata,
        table_type=ADATA.VAR,
        index_which_ontology="ensembl_gene",
    )
    assert "ensembl_gene" in ontology_df.columns
    expected_ontology = pd.Series(
        minimal_adata.var.index, index=minimal_adata.var.index, name="ensembl_gene"
    )
    pd.testing.assert_series_equal(ontology_df["ensembl_gene"], expected_ontology)

    # --- an ontology name that is already a column must be rejected
    with pytest.raises(
        ValueError, match="Cannot use 'gene_name' as index_which_ontology"
    ):
        loading.prepare_anndata_results_df(
            minimal_adata,
            table_type=ADATA.VAR,
            index_which_ontology="gene_name",
        )
380
+
381
+
382
def test_split_mdata_results_by_modality(minimal_mudata):
    """_split_mdata_results_by_modality returns one var-aligned frame per modality."""
    score_columns = ["score1", "score2"]

    # Combined results table covering every variable of the MuData object
    combined = pd.DataFrame(
        np.random.randn(len(minimal_mudata.var_names), 2),
        index=minimal_mudata.var_names,
        columns=score_columns,
    )

    by_modality = loading._split_mdata_results_by_modality(minimal_mudata, combined)

    # Both modalities are present
    assert set(by_modality.keys()) == {"rna", "protein"}

    # Each modality's frame matches that modality's variables and keeps all columns
    for modality_name in ("rna", "protein"):
        modality_adata = minimal_mudata.mod[modality_name]
        modality_df = by_modality[modality_name]
        assert isinstance(modality_df, pd.DataFrame)
        assert modality_df.shape[0] == modality_adata.n_vars
        assert list(modality_df.index) == list(modality_adata.var.index)
        assert list(modality_df.columns) == ["score1", "score2"]

    # No rows were lost or duplicated, and the values are untouched
    total_rows = sum(len(frame) for frame in by_modality.values())
    assert len(combined) == total_rows
    for modality_name, modality_df in by_modality.items():
        expected = combined.loc[minimal_mudata.mod[modality_name].var_names]
        pd.testing.assert_frame_equal(modality_df, expected)
419
+
420
+
421
def test_split_mdata_results_by_modality_errors(minimal_mudata):
    """_split_mdata_results_by_modality raises ValueError when the index does not match."""
    n_rows = len(minimal_mudata.var_names)

    # Right length, but labels that do not overlap with any real variable
    mismatched = pd.DataFrame(
        np.random.randn(n_rows, 2),
        index=[f"wrong_var_{i}" for i in range(n_rows)],
        columns=["score1", "score2"],
    )

    with pytest.raises(ValueError, match="Index mismatch in rna"):
        loading._split_mdata_results_by_modality(minimal_mudata, mismatched)
434
+
435
+
436
def test_multimodality_ontology_config():
    """Exercise MultiModalityOntologyConfig construction, access and validation."""
    # One modality per supported ontologies form: None, set, dict
    raw_config = {
        "transcriptomics": {
            "ontologies": None,  # auto-detect
            "index_which_ontology": None,
        },
        "proteomics": {
            "ontologies": {"uniprot", "pharos"},  # set of columns
            "index_which_ontology": "uniprot",
        },
        "atac": {
            "ontologies": {"peak1": "peak_id"},  # dict mapping
            "index_which_ontology": None,
        },
    }
    parsed = loading.MultiModalityOntologyConfig.from_dict(raw_config)

    # Dictionary-like behaviour
    assert len(parsed) == 3
    assert set(parsed) == {"transcriptomics", "proteomics", "atac"}

    # Per-modality access keeps the original types
    rna_cfg = parsed["transcriptomics"]
    assert rna_cfg.ontologies is None
    assert rna_cfg.index_which_ontology is None

    protein_cfg = parsed["proteomics"]
    assert isinstance(protein_cfg.ontologies, set)
    assert protein_cfg.ontologies == {"uniprot", "pharos"}
    assert protein_cfg.index_which_ontology == "uniprot"

    atac_cfg = parsed["atac"]
    assert isinstance(atac_cfg.ontologies, dict)
    assert atac_cfg.ontologies == {"peak1": "peak_id"}
    assert atac_cfg.index_which_ontology is None

    # items() round-trips the input dictionary
    for name, entry in parsed.items():
        assert name in raw_config
        source_entry = raw_config[name]
        assert entry.ontologies == source_entry["ontologies"]
        assert entry.index_which_ontology == source_entry["index_which_ontology"]

    # Invalid payloads: a missing "ontologies" field, then a wrong type for it
    invalid_payloads = [
        {"transcriptomics": {"index_which_ontology": "ensembl_gene"}},
        {
            "transcriptomics": {
                "ontologies": "ensembl_gene",  # should be None, set, or dict
                "index_which_ontology": "ensembl_gene",
            }
        },
    ]
    for payload in invalid_payloads:
        with pytest.raises(ValueError):
            loading.MultiModalityOntologyConfig.from_dict(payload)

    # An empty configuration is allowed
    empty = loading.MultiModalityOntologyConfig(root={})
    assert len(empty) == 0

    # index_which_ontology is optional and ends up as None
    without_index = loading.MultiModalityOntologyConfig.from_dict(
        {"transcriptomics": {"ontologies": None}}
    )
    assert without_index["transcriptomics"].index_which_ontology is None
513
+
514
+
515
def test_prepare_mudata_results_df(minimal_mudata):
    """Check the per-modality output of prepare_mudata_results_df for var and varm tables.

    The function is expected to:
    1. Take the actual data from MuData's var/varm tables
    2. Take ontology information from each modality's own var table
    3. Return a dictionary holding one DataFrame per modality
    4. Put the ontology columns and any requested data columns in every DataFrame
    """
    rna = minimal_mudata.mod["rna"]
    protein = minimal_mudata.mod["protein"]

    # Arrange
    ontology_config = {
        "rna": {
            "ontologies": None,  # auto-detect
            "index_which_ontology": "ensembl_gene",  # rename index to this
        },
        "protein": {
            # Map source column to a valid ontology name
            "ontologies": {
                "protein_id": "ensembl_protein",
                "uniprot": "uniprot",
            },
            "index_which_ontology": None,
        },
    }

    # The protein modality needs the source column named in the mapping above
    protein.var["protein_id"] = protein.var["uniprot"]

    # ensembl_gene comes from the index; the other two from the RNA var table
    rna_ontology_cols = {"ensembl_gene", "ensembl_transcript", "gene_name"}
    # Both the original and the renamed column
    protein_ontology_cols = {"uniprot", "ensembl_protein"}

    # Act - var table extraction
    var_by_modality = loading.prepare_mudata_results_df(
        minimal_mudata, mudata_ontologies=ontology_config, table_type=ADATA.VAR
    )

    # Assert - basic structure
    assert set(var_by_modality.keys()) == {"rna", "protein"}

    # Assert - RNA modality
    rna_var = var_by_modality["rna"]
    assert isinstance(rna_var, pd.DataFrame)
    assert rna_var.shape[0] == rna.n_vars
    missing = rna_ontology_cols - set(rna_var.columns)
    assert not missing, f"Missing ontology columns: {missing}"

    # Assert - protein modality
    protein_var = var_by_modality["protein"]
    assert isinstance(protein_var, pd.DataFrame)
    assert protein_var.shape[0] == protein.n_vars
    missing = protein_ontology_cols - set(protein_var.columns)
    assert not missing, f"Missing ontology columns: {missing}"
    # The source column must have been renamed, not duplicated
    assert "protein_id" not in protein_var.columns
    assert "ensembl_protein" in protein_var.columns
    pd.testing.assert_series_equal(
        protein_var["ensembl_protein"],
        protein.var["protein_id"],
        check_names=False,  # ignore Series names in the comparison
    )

    # Act - varm table extraction with explicit results_attrs
    varm_by_modality = loading.prepare_mudata_results_df(
        minimal_mudata,
        mudata_ontologies=ontology_config,
        table_type=ADATA.VARM,
        table_name="gene_scores",
        results_attrs=["score1", "score2"],
        table_colnames=["score1", "score2", "score3"],
    )
    requested_scores = {"score1", "score2"}

    # Assert - RNA varm results
    rna_varm = varm_by_modality["rna"]
    assert isinstance(rna_varm, pd.DataFrame)
    assert rna_varm.shape[0] == rna.n_vars
    missing = (rna_ontology_cols | requested_scores) - set(rna_varm.columns)
    assert not missing, f"Missing columns: {missing}"

    # Assert - protein varm results
    protein_varm = varm_by_modality["protein"]
    assert isinstance(protein_varm, pd.DataFrame)
    assert protein_varm.shape[0] == protein.n_vars
    missing = (protein_ontology_cols | requested_scores) - set(protein_varm.columns)
    assert not missing, f"Missing columns: {missing}"
613
+
614
+
615
def test_prepare_mudata_results_df_errors(minimal_mudata):
    """Test error cases for prepare_mudata_results_df.

    Covers three failures, each expected to raise ValueError:
    a modality without an ontology configuration, an invalid ``table_type``,
    and a varm ``table_name`` that the lookup cannot find.
    """
    # Test missing modality configuration ("protein" is not configured)
    with pytest.raises(
        ValueError, match="Missing ontology configurations for modalities"
    ):
        loading.prepare_mudata_results_df(
            minimal_mudata,
            mudata_ontologies={"rna": {"ontologies": None}},
            table_type=ADATA.VAR,
        )

    # Test invalid table type
    with pytest.raises(ValueError, match="table_type must be one of"):
        loading.prepare_mudata_results_df(
            minimal_mudata,
            mudata_ontologies={
                "rna": {"ontologies": None},
                "protein": {"ontologies": None},
            },
            table_type="invalid_type",
        )

    # Test a varm table name that cannot be found.
    # "scores" is added to the RNA modality only, and the call below does not
    # pass `level`. The asserted error is the table-name lookup failure, not a
    # complaint about table_colnames (which is also omitted here).
    # NOTE(review): presumably the default level reads MuData-level varm, which
    # has no "scores" table - confirm against loading.prepare_mudata_results_df.
    minimal_mudata.mod["rna"].varm["scores"] = np.random.randn(
        minimal_mudata.mod["rna"].n_vars, 3
    )
    minimal_mudata.mod["rna"].uns["scores_features"] = ["score1", "score2", "score3"]
    with pytest.raises(ValueError, match="table_name 'scores' not found in adata.varm"):
        loading.prepare_mudata_results_df(
            minimal_mudata,
            mudata_ontologies={
                "rna": {"ontologies": None},
                "protein": {"ontologies": None},
            },
            table_type=ADATA.VARM,
            table_name="scores",
            results_attrs=["score1", "score2"],
        )
655
+
656
+
657
def test_prepare_mudata_results_df_adata_level(minimal_mudata):
    """With level='adata', prepare_mudata_results_df returns modality-specific var columns."""
    # Arrange - give every modality a var column that only exists on that modality.
    # Maps modality -> (column name, value prefix).
    specific_columns = {
        "rna": ("rna_specific", "rna_val_"),
        "protein": ("protein_specific", "prot_val_"),
    }
    for modality_name, (column, prefix) in specific_columns.items():
        modality_adata = minimal_mudata.mod[modality_name]
        modality_adata.var[column] = [
            f"{prefix}{i}" for i in range(modality_adata.n_vars)
        ]

    ontology_config = {
        "rna": {
            "ontologies": None,  # auto-detect
            "index_which_ontology": None,
        },
        "protein": {
            "ontologies": {"uniprot"},
            "index_which_ontology": None,
        },
    }

    # Act - var table extraction at adata level
    by_modality = loading.prepare_mudata_results_df(
        minimal_mudata,
        mudata_ontologies=ontology_config,
        table_type=ADATA.VAR,
        level=SCVERSE_DEFS.ADATA,
    )
    rna_df = by_modality["rna"]
    protein_df = by_modality["protein"]

    # Assert - each modality exposes its own column plus its original var columns
    assert "rna_specific" in rna_df.columns
    assert "gene_name" in rna_df.columns
    assert "ensembl_transcript" in rna_df.columns
    assert "protein_specific" in protein_df.columns
    assert "uniprot" in protein_df.columns

    # Assert - the values are carried over unchanged
    for modality_name, (column, _) in specific_columns.items():
        pd.testing.assert_series_equal(
            by_modality[modality_name][column],
            minimal_mudata.mod[modality_name].var[column],
            check_names=False,
        )
710
+
711
+
712
def test_prepare_mudata_results_df_mdata_vs_adata_level(minimal_mudata):
    """level='mdata' and level='adata' each resolve a varm table that exists at that level."""
    rna = minimal_mudata.mod["rna"]
    protein = minimal_mudata.mod["protein"]

    # Arrange - a varm table that exists on the modalities only
    rna_scores = np.random.randn(rna.n_vars, 2)
    protein_scores = np.random.randn(protein.n_vars, 2)
    rna.varm["modality_scores"] = rna_scores
    protein.varm["modality_scores"] = protein_scores

    ontology_config = {
        "rna": {"ontologies": None},
        "protein": {"ontologies": {"uniprot"}},
    }

    # Act - "gene_scores" exists at MuData level
    from_mdata = loading.prepare_mudata_results_df(
        minimal_mudata,
        mudata_ontologies=ontology_config,
        table_type=ADATA.VARM,
        table_name="gene_scores",
        results_attrs=["score1"],
        table_colnames=["score1", "score2", "score3"],
        level=SCVERSE_DEFS.MDATA,
    )

    # Act - "modality_scores" exists at modality level; it has no explicit
    # column names, so stringified positions are supplied instead
    from_adata = loading.prepare_mudata_results_df(
        minimal_mudata,
        mudata_ontologies=ontology_config,
        table_type=ADATA.VARM,
        table_name="modality_scores",
        results_attrs=["0"],
        table_colnames=["0", "1"],
        level=SCVERSE_DEFS.ADATA,
    )

    # Assert - both calls succeed and expose the requested column per modality
    for modality_name in ("rna", "protein"):
        assert "score1" in from_mdata[modality_name].columns
    for modality_name in ("rna", "protein"):
        assert "0" in from_adata[modality_name].columns

    # Only the structure is checked; the underlying values are random, so they
    # are not compared between the two levels
    assert from_mdata["rna"].shape[0] == rna.n_vars
    assert from_adata["rna"].shape[0] == rna.n_vars
760
+
761
+
762
def test_prepare_mudata_results_df_level_validation(minimal_mudata):
    """An unrecognised ``level`` value makes prepare_mudata_results_df raise ValueError."""
    ontology_config = {
        "rna": {"ontologies": None},
        "protein": {"ontologies": {"uniprot"}},
    }
    expected_msg = r"level must be one of \['adata', 'mdata'\], got invalid_level"

    with pytest.raises(ValueError, match=expected_msg):
        loading.prepare_mudata_results_df(
            minimal_mudata,
            mudata_ontologies=ontology_config,
            table_type=ADATA.VAR,
            level="invalid_level",
        )