napistu 0.2.5.dev7__py3-none-any.whl → 0.3.1.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. napistu/__init__.py +1 -3
  2. napistu/__main__.py +126 -96
  3. napistu/constants.py +35 -41
  4. napistu/context/__init__.py +10 -0
  5. napistu/context/discretize.py +462 -0
  6. napistu/context/filtering.py +387 -0
  7. napistu/gcs/__init__.py +1 -1
  8. napistu/identifiers.py +74 -15
  9. napistu/indices.py +68 -0
  10. napistu/ingestion/__init__.py +1 -1
  11. napistu/ingestion/bigg.py +47 -62
  12. napistu/ingestion/constants.py +18 -133
  13. napistu/ingestion/gtex.py +113 -0
  14. napistu/ingestion/hpa.py +147 -0
  15. napistu/ingestion/sbml.py +0 -97
  16. napistu/ingestion/string.py +2 -2
  17. napistu/matching/__init__.py +10 -0
  18. napistu/matching/constants.py +18 -0
  19. napistu/matching/interactions.py +518 -0
  20. napistu/matching/mount.py +529 -0
  21. napistu/matching/species.py +510 -0
  22. napistu/mcp/__init__.py +7 -4
  23. napistu/mcp/__main__.py +128 -72
  24. napistu/mcp/client.py +16 -25
  25. napistu/mcp/codebase.py +201 -145
  26. napistu/mcp/component_base.py +170 -0
  27. napistu/mcp/config.py +223 -0
  28. napistu/mcp/constants.py +45 -2
  29. napistu/mcp/documentation.py +253 -136
  30. napistu/mcp/documentation_utils.py +13 -48
  31. napistu/mcp/execution.py +372 -305
  32. napistu/mcp/health.py +47 -65
  33. napistu/mcp/profiles.py +10 -6
  34. napistu/mcp/server.py +161 -80
  35. napistu/mcp/tutorials.py +139 -87
  36. napistu/modify/__init__.py +1 -1
  37. napistu/modify/gaps.py +1 -1
  38. napistu/network/__init__.py +1 -1
  39. napistu/network/constants.py +101 -34
  40. napistu/network/data_handling.py +388 -0
  41. napistu/network/ig_utils.py +351 -0
  42. napistu/network/napistu_graph_core.py +354 -0
  43. napistu/network/neighborhoods.py +40 -40
  44. napistu/network/net_create.py +373 -309
  45. napistu/network/net_propagation.py +47 -19
  46. napistu/network/{net_utils.py → ng_utils.py} +124 -272
  47. napistu/network/paths.py +67 -51
  48. napistu/network/precompute.py +11 -11
  49. napistu/ontologies/__init__.py +10 -0
  50. napistu/ontologies/constants.py +129 -0
  51. napistu/ontologies/dogma.py +243 -0
  52. napistu/ontologies/genodexito.py +649 -0
  53. napistu/ontologies/mygene.py +369 -0
  54. napistu/ontologies/renaming.py +198 -0
  55. napistu/rpy2/__init__.py +229 -86
  56. napistu/rpy2/callr.py +47 -77
  57. napistu/rpy2/constants.py +24 -23
  58. napistu/rpy2/rids.py +61 -648
  59. napistu/sbml_dfs_core.py +587 -222
  60. napistu/scverse/__init__.py +15 -0
  61. napistu/scverse/constants.py +28 -0
  62. napistu/scverse/loading.py +727 -0
  63. napistu/utils.py +118 -10
  64. {napistu-0.2.5.dev7.dist-info → napistu-0.3.1.dev1.dist-info}/METADATA +8 -3
  65. napistu-0.3.1.dev1.dist-info/RECORD +133 -0
  66. tests/conftest.py +22 -0
  67. tests/test_context_discretize.py +56 -0
  68. tests/test_context_filtering.py +267 -0
  69. tests/test_identifiers.py +100 -0
  70. tests/test_indices.py +65 -0
  71. tests/{test_edgelist.py → test_ingestion_napistu_edgelist.py} +2 -2
  72. tests/test_matching_interactions.py +108 -0
  73. tests/test_matching_mount.py +305 -0
  74. tests/test_matching_species.py +394 -0
  75. tests/test_mcp_config.py +193 -0
  76. tests/test_mcp_documentation_utils.py +12 -3
  77. tests/test_mcp_server.py +156 -19
  78. tests/test_network_data_handling.py +397 -0
  79. tests/test_network_ig_utils.py +23 -0
  80. tests/test_network_neighborhoods.py +19 -0
  81. tests/test_network_net_create.py +459 -0
  82. tests/test_network_ng_utils.py +30 -0
  83. tests/test_network_paths.py +56 -0
  84. tests/{test_precomputed_distances.py → test_network_precompute.py} +8 -6
  85. tests/test_ontologies_genodexito.py +58 -0
  86. tests/test_ontologies_mygene.py +39 -0
  87. tests/test_ontologies_renaming.py +110 -0
  88. tests/test_rpy2_callr.py +79 -0
  89. tests/test_rpy2_init.py +151 -0
  90. tests/test_sbml.py +0 -31
  91. tests/test_sbml_dfs_core.py +134 -10
  92. tests/test_scverse_loading.py +778 -0
  93. tests/test_set_coverage.py +2 -2
  94. tests/test_utils.py +121 -1
  95. napistu/mechanism_matching.py +0 -1353
  96. napistu/rpy2/netcontextr.py +0 -467
  97. napistu-0.2.5.dev7.dist-info/RECORD +0 -98
  98. tests/test_igraph.py +0 -367
  99. tests/test_mechanism_matching.py +0 -784
  100. tests/test_net_utils.py +0 -149
  101. tests/test_netcontextr.py +0 -105
  102. tests/test_rpy2.py +0 -61
  103. /napistu/ingestion/{cpr_edgelist.py → napistu_edgelist.py} +0 -0
  104. {napistu-0.2.5.dev7.dist-info → napistu-0.3.1.dev1.dist-info}/WHEEL +0 -0
  105. {napistu-0.2.5.dev7.dist-info → napistu-0.3.1.dev1.dist-info}/entry_points.txt +0 -0
  106. {napistu-0.2.5.dev7.dist-info → napistu-0.3.1.dev1.dist-info}/licenses/LICENSE +0 -0
  107. {napistu-0.2.5.dev7.dist-info → napistu-0.3.1.dev1.dist-info}/top_level.txt +0 -0
  108. /tests/{test_obo.py → test_ingestion_obo.py} +0 -0
@@ -0,0 +1,727 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import Optional, List, Union, Set, Dict
5
+
6
+ import anndata
7
+ import pandas as pd
8
+ import mudata
9
+ import numpy as np
10
+ from pydantic import BaseModel, Field, RootModel
11
+
12
+ from napistu.matching import species
13
+ from napistu.constants import ONTOLOGIES_LIST
14
+ from napistu.scverse.constants import (
15
+ ADATA,
16
+ ADATA_DICTLIKE_ATTRS,
17
+ ADATA_IDENTITY_ATTRS,
18
+ ADATA_FEATURELEVEL_ATTRS,
19
+ ADATA_ARRAY_ATTRS,
20
+ SCVERSE_DEFS,
21
+ VALID_MUDATA_LEVELS,
22
+ )
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
def prepare_anndata_results_df(
    adata: Union[anndata.AnnData, mudata.MuData],
    table_type: str = ADATA.VAR,
    table_name: Optional[str] = None,
    results_attrs: Optional[List[str]] = None,
    ontologies: Optional[Union[Set[str], Dict[str, str]]] = None,
    index_which_ontology: Optional[str] = None,
    table_colnames: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Format a results table from an AnnData object for use in Napistu.

    The returned DataFrame combines systematic identifiers taken from the var
    table with the requested results data; its index matches the object's
    var_names.

    Parameters
    ----------
    adata : anndata.AnnData or mudata.MuData
        Object holding the results to be formatted.
    table_type : str, optional
        Which feature-level table to read. Must be one of: "var", "varm", or "X".
    table_name : str, optional
        Name of the table to read (relevant for dict-like attributes such as varm).
    results_attrs : list of str, optional
        The attributes to extract from the table.
    ontologies : set of str or dict, optional
        Either:
        - Set of columns to treat as ontologies (entries in ONTOLOGIES_LIST)
        - Dict mapping wide column names to ontology names in the ONTOLOGIES_LIST
          controlled vocabulary
        - None to automatically detect valid ontology columns based on ONTOLOGIES_LIST
        If index_which_ontology is defined, it should be represented in these ontologies.
    index_which_ontology : str, optional
        If given, expose the var index as a new column named after this ontology.
        Must not already exist as a column in the var table.
    table_colnames : list of str, optional
        Column names for varm tables. Required when table_type is "varm"; ignored
        otherwise.

    Returns
    -------
    pd.DataFrame
        Ontology identifier columns plus the selected results, indexed to match
        the var_names of the AnnData object.

    Raises
    ------
    ValueError
        If table_type is not one of: "var", "varm", or "X"
        If index_which_ontology already exists in the var table
    """
    if table_type not in ADATA_FEATURELEVEL_ATTRS:
        raise ValueError(
            f"table_type must be one of {ADATA_FEATURELEVEL_ATTRS}, got {table_type}"
        )

    # Raw table -> tidy DataFrame with vars as rows and requested attrs as columns.
    raw_table = _load_raw_table(adata, table_type, table_name)
    attr_values = _select_results_attrs(
        adata, raw_table, table_type, results_attrs, table_colnames
    )

    # Systematic identifiers pulled from the var table; shares the var index,
    # so a column-wise concat aligns row-for-row.
    identifier_columns = _extract_ontologies(
        adata.var, ontologies, index_which_ontology
    )

    return pd.concat([identifier_columns, attr_values], axis=1)
102
+
103
+
104
def prepare_mudata_results_df(
    mdata: mudata.MuData,
    mudata_ontologies: Union[
        "MultiModalityOntologyConfig",
        Dict[
            str,
            Dict[str, Union[Optional[Union[Set[str], Dict[str, str]]], Optional[str]]],
        ],
    ],
    table_type: str = ADATA.VAR,
    table_name: Optional[str] = None,
    results_attrs: Optional[List[str]] = None,
    table_colnames: Optional[List[str]] = None,
    level: str = "mdata",
) -> Dict[str, pd.DataFrame]:
    """
    Prepare results tables from a MuData object for use in Napistu, with adata-specific ontology handling.

    This function extracts tables from each adata in a MuData object and formats them for use in Napistu.
    Each adata's table will include systematic identifiers from its var table along with the requested results data.
    Ontology handling is configured per-adata using MultiModalityOntologyConfig.

    Parameters
    ----------
    mdata : mudata.MuData
        The MuData object containing the results to be formatted.
    mudata_ontologies : MultiModalityOntologyConfig or dict
        Configuration for ontology handling per modality (each with a separate AnnData object).
        Must include an entry for every modality in ``mdata.mod``. Can be either:
        - A MultiModalityOntologyConfig object
        - A dictionary that can be converted to MultiModalityOntologyConfig using from_dict()
        Each modality's 'ontologies' field can be:
        - None to automatically detect valid ontology columns
        - Set of columns to treat as ontologies
        - Dict mapping wide column names to ontology names
        The 'index_which_ontology' field is optional.
    table_type : str, optional
        The type of table to extract from each modality. Must be one of: "var", "varm", or "X".
    table_name : str, optional
        The name of the table to extract from each modality.
    results_attrs : list of str, optional
        The attributes to extract from the table.
    table_colnames : list of str, optional
        Column names for varm tables. Required when table_type is "varm". Ignored otherwise.
    level : str, optional
        Whether to extract data from "mdata" (MuData-level) or "adata" (individual AnnData-level) tables.
        Default is "mdata".

    Returns
    -------
    Dict[str, pd.DataFrame]
        Dictionary mapping modality names to their formatted results DataFrames.
        Each DataFrame contains the modality's results with systematic identifiers.
        The index of each DataFrame will match the var_names of that modality.

    Raises
    ------
    ValueError
        If table_type is not one of: "var", "varm", or "X"
        If level is not one of the valid MuData levels ("mdata" or "adata")
        If mudata_ontologies contains invalid configuration
        If modality-specific ontology extraction fails
        If any modality is missing from mudata_ontologies
    """
    if table_type not in ADATA_FEATURELEVEL_ATTRS:
        raise ValueError(
            f"table_type must be one of {ADATA_FEATURELEVEL_ATTRS}, got {table_type}"
        )

    if level not in VALID_MUDATA_LEVELS:
        raise ValueError(
            f"level must be one of {sorted(VALID_MUDATA_LEVELS)}, got {level}"
        )

    # Convert dict config to MultiModalityOntologyConfig if needed
    if isinstance(mudata_ontologies, dict):
        mudata_ontologies = MultiModalityOntologyConfig.from_dict(mudata_ontologies)

    # Validate that all modalities have configurations
    missing_modalities = set(mdata.mod.keys()) - set(mudata_ontologies.root.keys())
    if missing_modalities:
        raise ValueError(
            f"Missing ontology configurations for modalities: {missing_modalities}. "
            "Each modality must have at least the 'ontologies' field specified."
        )

    if level == SCVERSE_DEFS.MDATA:
        # Use MuData-level tables: extract a single global table, then split its
        # rows by modality so each modality gets the slice matching its vars.
        raw_results_table = _load_raw_table(mdata, table_type, table_name)

        results_data_table = _select_results_attrs(
            mdata, raw_results_table, table_type, results_attrs, table_colnames
        )

        split_results_data_tables = _split_mdata_results_by_modality(
            mdata, results_data_table
        )
    else:
        # Use modality-level tables: load and select attrs from each AnnData separately.
        split_results_data_tables = {}
        for modality in mdata.mod.keys():
            raw_results_table = _load_raw_table(
                mdata.mod[modality], table_type, table_name
            )

            results_data_table = _select_results_attrs(
                mdata.mod[modality],
                raw_results_table,
                table_type,
                results_attrs,
                table_colnames,
            )

            split_results_data_tables[modality] = results_data_table

    # Extract each modality's ontology table and then merge it with
    # the modality's data table (both share the modality's var index).
    split_results_tables = {}
    for modality in mdata.mod.keys():
        # Get ontology config for this modality
        modality_ontology_spec = mudata_ontologies[modality]

        # Extract ontologies according to the modality's specification
        ontology_table = _extract_ontologies(
            mdata.mod[modality].var,
            modality_ontology_spec.ontologies,
            modality_ontology_spec.index_which_ontology,
        )

        # Combine ontologies with results
        split_results_tables[modality] = pd.concat(
            [ontology_table, split_results_data_tables[modality]], axis=1
        )

    return split_results_tables
244
+
245
+
246
def _load_raw_table(
    adata: Union[anndata.AnnData, mudata.MuData],
    table_type: str,
    table_name: Optional[str] = None,
) -> Union[pd.DataFrame, np.ndarray]:
    """
    Fetch a table from an AnnData/MuData object by attribute name.

    Identity attributes are returned directly via getattr; dict-like
    attributes (e.g. varm, layers) are resolved through
    _get_table_from_dict_attr.

    Parameters
    ----------
    adata : anndata.AnnData or mudata.MuData
        Object to read the table from.
    table_type : str
        Attribute name of the table to read.
    table_name : str, optional
        Entry name within a dict-like attribute; ignored for identity attributes.

    Returns
    -------
    pd.DataFrame or np.ndarray
        The requested table.
    """
    recognized = ADATA_DICTLIKE_ATTRS | ADATA_IDENTITY_ATTRS
    if table_type not in recognized:
        raise ValueError(
            f"table_type {table_type} is not a valid AnnData attribute. Valid attributes are: {recognized}"
        )

    if table_type in ADATA_IDENTITY_ATTRS:
        # Identity attributes hold a single table, so a table_name has no effect.
        if table_name is not None:
            logger.debug(
                f"table_name {table_name} is not None, but table_type is in IDENTITY_TABLES. "
                f"table_name will be ignored."
            )
        return getattr(adata, table_type)

    # Dict-like attribute: resolve the named (or sole) entry.
    return _get_table_from_dict_attr(adata, table_type, table_name)
287
+
288
+
289
def _get_table_from_dict_attr(
    adata: Union[anndata.AnnData, mudata.MuData],
    attr_name: str,
    table_name: Optional[str] = None,
) -> Union[pd.DataFrame, np.ndarray]:
    """
    Get a table from a dict-like AnnData attribute (varm, layers, etc.).

    Parameters
    ----------
    adata : anndata.AnnData or mudata.MuData
        Object to read the table from.
    attr_name : str
        Name of the dict-like attribute ('varm', 'layers', etc.).
    table_name : str, optional
        Specific entry to retrieve. When None, the attribute must contain
        exactly one entry, which is returned.

    Returns
    -------
    Union[pd.DataFrame, np.ndarray]
        The table data stored under the resolved entry.

    Raises
    ------
    ValueError
        If attr_name is not a valid dict-like attribute
        If the attribute holds no tables
        If it holds several tables and table_name is not given
        If the requested table_name is absent
    """
    if attr_name not in ADATA_DICTLIKE_ATTRS:
        raise ValueError(
            f"attr_name {attr_name} is not a dict-like AnnData attribute. Valid attributes are: {ADATA_DICTLIKE_ATTRS}"
        )

    tables = getattr(adata, attr_name)
    available_tables = list(tables.keys())

    if not available_tables:
        raise ValueError(f"No tables found in adata.{attr_name}")

    if table_name is None:
        # Unambiguous only when a single table exists.
        if len(available_tables) == 1:
            return tables[available_tables[0]]
        raise ValueError(
            f"Multiple tables found in adata.{attr_name} and table_name is not specified. "
            f"Available: {available_tables}"
        )

    if table_name not in available_tables:
        raise ValueError(
            f"table_name '{table_name}' not found in adata.{attr_name}. "
            f"Available: {available_tables}"
        )

    return tables[table_name]
347
+
348
+
349
def _select_results_attrs(
    adata: anndata.AnnData,
    raw_results_table: Union[pd.DataFrame, np.ndarray],
    table_type: str,
    results_attrs: Optional[List[str]] = None,
    table_colnames: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Select results attributes from an AnnData object.

    This function selects results attributes from raw_results_table derived
    from an AnnData object and converts them if needed to a pd.DataFrame
    with appropriate indices (vars as rows).

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object containing the results to be formatted.
    raw_results_table : pd.DataFrame or np.ndarray
        The raw results table to be formatted. DataFrames are filtered by
        column; arrays are sliced by position and wrapped via _create_results_df.
    table_type : str
        The type of table `raw_results_table` refers to.
    results_attrs : list of str, optional
        The attributes to extract from the raw_results_table. When None, all
        available attributes are kept.
    table_colnames : list of str, optional
        If `table_type` is `varm`, this is the names of all columns
        (e.g., PC1, PC2, etc.). Ignored otherwise.

    Returns
    -------
    pd.DataFrame
        A DataFrame containing the formatted results.

    Raises
    ------
    ValueError
        If an array-type table arrives as a DataFrame, or if any requested
        attribute is not available.
    """
    logger.debug(
        f"_select_results_attrs called with table_type={table_type}, results_attrs={results_attrs}"
    )

    # Validate that array-type tables are not passed as DataFrames
    if table_type in ADATA_ARRAY_ATTRS and isinstance(raw_results_table, pd.DataFrame):
        raise ValueError(
            f"Table type {table_type} must be a numpy array, not a DataFrame. Got {type(raw_results_table)}"
        )

    # DataFrame path: straightforward column selection, already indexed by vars.
    if isinstance(raw_results_table, pd.DataFrame):
        if results_attrs is not None:
            # Get available columns for better error message
            available_attrs = raw_results_table.columns.tolist()
            missing_attrs = [
                attr for attr in results_attrs if attr not in available_attrs
            ]
            if missing_attrs:
                raise ValueError(
                    f"The following results attributes were not found: {missing_attrs}\n"
                    f"Available attributes are: {available_attrs}"
                )
            results_table_data = raw_results_table.loc[:, results_attrs]
        else:
            results_table_data = raw_results_table
        return results_table_data

    # Array path from here on.
    # Convert sparse matrix to dense if needed (duck-typed via toarray)
    if hasattr(raw_results_table, "toarray"):
        raw_results_table = raw_results_table.toarray()

    # The set of addressable attributes depends on the table type
    # (obs names for X/layers, var names for varp, table_colnames for varm).
    valid_attrs = _get_valid_attrs_for_feature_level_array(
        adata, table_type, raw_results_table, table_colnames
    )

    if results_attrs is not None:
        invalid_results_attrs = [x for x in results_attrs if x not in valid_attrs]
        if len(invalid_results_attrs) > 0:
            raise ValueError(
                f"The following results attributes were not found: {invalid_results_attrs}\n"
                f"Available attributes are: {valid_attrs}"
            )

        # Get positions based on table type; note the slicing axis differs:
        # varm/varp select columns, X/layers select rows (observations).
        if table_type == ADATA.VARM:
            positions = [table_colnames.index(attr) for attr in results_attrs]
            selected_array = raw_results_table[:, positions]
        elif table_type == ADATA.VARP:
            positions = [adata.var.index.get_loc(attr) for attr in results_attrs]
            selected_array = raw_results_table[:, positions]
        else:  # X or layers
            positions = [adata.obs.index.get_loc(attr) for attr in results_attrs]
            selected_array = raw_results_table[positions, :]

        results_table_data = _create_results_df(
            selected_array, results_attrs, adata.var.index, table_type
        )
    else:
        # No selection requested: wrap the full array with all valid attrs.
        results_table_data = _create_results_df(
            raw_results_table, valid_attrs, adata.var.index, table_type
        )

    return results_table_data
444
+
445
+
446
def _get_valid_attrs_for_feature_level_array(
    adata: anndata.AnnData,
    table_type: str,
    raw_results_table: np.ndarray,
    table_colnames: Optional[List[str]] = None,
) -> list[str]:
    """
    Get valid attributes for a feature-level array.

    The addressable attribute names depend on the table type: obs names for
    X/layers, var names for varp, and caller-supplied table_colnames for varm.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object.
    table_type : str
        The type of table.
    raw_results_table : np.ndarray
        The raw results table, used to validate table_colnames dimensions.
    table_colnames : Optional[List[str]]
        Column names for varm tables.

    Returns
    -------
    list[str]
        List of valid attributes for this table type.

    Raises
    ------
    ValueError
        If table_type is invalid or if table_colnames validation fails for varm tables.
    """
    if table_type not in ADATA_ARRAY_ATTRS:
        raise ValueError(
            f"table_type {table_type} is not a valid AnnData array attribute. Valid attributes are: {ADATA_ARRAY_ATTRS}"
        )

    if table_type in [ADATA.X, ADATA.LAYERS]:
        # Rows of X/layers are observations.
        return adata.obs.index.tolist()

    if table_type == ADATA.VARP:
        # varp is var-by-var.
        return adata.var.index.tolist()

    # varm: column names come from the caller and must match the array width.
    if table_colnames is None:
        raise ValueError("table_colnames is required for varm tables")
    if len(table_colnames) != raw_results_table.shape[1]:
        raise ValueError(
            f"table_colnames must have length {raw_results_table.shape[1]}"
        )
    return table_colnames
495
+
496
+
497
def _create_results_df(
    array: np.ndarray, attrs: List[str], var_index: pd.Index, table_type: str
) -> pd.DataFrame:
    """Wrap an array as a DataFrame with vars as rows.

    varm/varp arrays are already var-major, so attrs become the columns.
    X/layers arrays are obs-major (attrs are selected observations), so the
    frame is built attrs-by-vars and then transposed to put vars on the rows.
    """
    var_major = table_type in [ADATA.VARM, ADATA.VARP]
    if var_major:
        return pd.DataFrame(array, index=var_index, columns=attrs)
    return pd.DataFrame(array, index=attrs, columns=var_index).T
514
+
515
+
516
def _split_mdata_results_by_modality(
    mdata: mudata.MuData,
    results_data_table: pd.DataFrame,
) -> Dict[str, pd.DataFrame]:
    """
    Split a results table by modality and verify compatibility with var tables.

    Parameters
    ----------
    mdata : mudata.MuData
        MuData object containing multiple modalities.
    results_data_table : pd.DataFrame
        Results table with vars as rows, typically from prepare_anndata_results_df().

    Returns
    -------
    Dict[str, pd.DataFrame]
        Modality name -> that modality's slice of the results. Each slice's
        index is verified to match the modality's var table for later merging.

    Raises
    ------
    ValueError
        If any modality's vars are not found in the results table
        If any modality's results have different indices than its var table
    """
    per_modality: Dict[str, pd.DataFrame] = {}

    for modality, mod_adata in mdata.mod.items():
        mod_vars = mod_adata.var_names

        # Every var of this modality must appear in the global results table.
        missing_vars = set(mod_vars) - set(results_data_table.index)
        if missing_vars:
            raise ValueError(
                f"Index mismatch in {modality}: vars {missing_vars} not found in results table"
            )

        subset = results_data_table.loc[mod_vars]

        # Guard against ordering/duplication mismatches with the var table.
        if not subset.index.equals(mod_adata.var.index):
            raise ValueError(
                f"Index mismatch in {modality}: var table and results subset have different indices"
            )

        per_modality[modality] = subset

    return per_modality
572
+
573
+
574
def _extract_ontologies(
    var_table: pd.DataFrame,
    ontologies: Optional[Union[Set[str], Dict[str, str]]] = None,
    index_which_ontology: Optional[str] = None,
) -> pd.DataFrame:
    """
    Extract ontology columns from a var table, optionally including the index as an ontology.

    Parameters
    ----------
    var_table : pd.DataFrame
        The var table containing systematic identifiers.
    ontologies : set of str or dict, optional
        Either a set of columns to treat as ontologies, a dict renaming var
        columns to ONTOLOGIES_LIST names, or None to auto-detect columns
        present in ONTOLOGIES_LIST.
    index_which_ontology : str, optional
        If provided, expose the index as a column with this ontology name.
        Must not already exist in the var table.

    Returns
    -------
    pd.DataFrame
        Only the ontology columns, sharing var_table's index.

    Raises
    ------
    ValueError
        If index_which_ontology already exists in var table
        If any renamed ontology column already exists in var table
        If any rename values are duplicated
        If any final column names are not in ONTOLOGIES_LIST
    """
    # Work on a copy so the caller's frame is untouched.
    var_table = var_table.copy()

    if index_which_ontology is not None:
        if index_which_ontology in var_table.columns:
            raise ValueError(
                f"Cannot use '{index_which_ontology}' as index_which_ontology - "
                f"column already exists in var table"
            )
        # Materialize the index as an ordinary column.
        var_table[index_which_ontology] = var_table.index

    if isinstance(ontologies, dict):
        # Rename targets must be unique...
        targets = list(ontologies.values())
        if len(targets) != len(set(targets)):
            duplicates = [val for val in targets if targets.count(val) > 1]
            raise ValueError(
                f"Duplicate rename values found in ontologies mapping: {duplicates}. "
                "Each ontology must be renamed to a unique value."
            )

        # ...and must not collide with existing columns (self-renames are fine).
        existing_rename_cols = set(targets) & set(var_table.columns)
        if existing_rename_cols:
            actual_conflicts = {
                rename_val
                for src, rename_val in ontologies.items()
                if rename_val in existing_rename_cols and src != rename_val
            }
            if actual_conflicts:
                raise ValueError(
                    f"Cannot rename ontologies - columns already exist in var table: {actual_conflicts}"
                )

    # Delegate structural validation and auto-detection to the matching module.
    matching_ontologies = species._validate_wide_ontologies(var_table, ontologies)
    if isinstance(ontologies, dict):
        selected = var_table.loc[:, ontologies.keys()].rename(columns=ontologies)
    else:
        selected = var_table.loc[:, list(matching_ontologies)]

    # Every surviving column name must belong to the controlled vocabulary.
    invalid_cols = set(selected.columns) - set(ONTOLOGIES_LIST)
    if invalid_cols:
        raise ValueError(
            f"The following column names are not in ONTOLOGIES_LIST: {invalid_cols}. "
            f"All column names must be one of: {ONTOLOGIES_LIST}"
        )

    return selected
663
+
664
+
665
class ModalityOntologyConfig(BaseModel):
    """Configuration for ontology handling in a single modality."""

    # Required field (no default): mirrors the `ontologies` argument of
    # _extract_ontologies — None (auto-detect), a set of var columns, or a
    # {wide column -> ontology name} rename mapping.
    ontologies: Optional[Union[Set[str], Dict[str, str]]] = Field(
        description="Ontology configuration. Can be either:\n"
        "- None to automatically detect valid ontology columns\n"
        "- Set of columns to treat as ontologies\n"
        "- Dict mapping wide column names to ontology names"
    )
    # Optional: expose the modality's var index as this ontology.
    index_which_ontology: Optional[str] = Field(
        default=None, description="If provided, extract the index as this ontology"
    )
677
+
678
+
679
class MultiModalityOntologyConfig(RootModel):
    """Per-modality ontology configuration, keyed by modality name."""

    root: Dict[str, ModalityOntologyConfig]

    @classmethod
    def from_dict(
        cls,
        data: Dict[
            str,
            Dict[str, Union[Optional[Union[Set[str], Dict[str, str]]], Optional[str]]],
        ],
    ) -> "MultiModalityOntologyConfig":
        """
        Create a MultiModalityOntologyConfig from a dictionary.

        Parameters
        ----------
        data : dict
            Maps modality names to their ontology configurations. Each config
            should have 'ontologies' (None to auto-detect, a set of columns,
            or a column-rename dict) and optionally 'index_which_ontology'.

        Returns
        -------
        MultiModalityOntologyConfig
            Validated ontology configuration.
        """
        parsed = {
            modality: ModalityOntologyConfig(**config)
            for modality, config in data.items()
        }
        return cls(root=parsed)

    def __getitem__(self, key: str) -> ModalityOntologyConfig:
        """Return the configuration for a single modality."""
        return self.root[key]

    def items(self):
        """Iterate over (modality, config) pairs."""
        return self.root.items()

    def __iter__(self):
        """Iterate over modality names."""
        return iter(self.root)

    def __len__(self):
        """Return the number of configured modalities."""
        return len(self.root)