scmcp-shared 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scmcp_shared/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
 
2
- __version__ = "0.3.5"
2
+ __version__ = "0.3.7"
3
3
 
scmcp_shared/schema/io.py CHANGED
@@ -7,7 +7,7 @@ from typing import Optional, Literal
7
7
 
8
8
 
9
9
 
10
- class ReadModel(BaseModel):
10
+ class ReadParams(BaseModel):
11
11
  """Input schema for the read tool."""
12
12
  filename: str = Field(
13
13
  ...,
@@ -30,9 +30,9 @@ class ReadModel(BaseModel):
30
30
  default=False,
31
31
  description="Assume the first column stores row names. This is only necessary if these are not strings: strings in the first column are automatically assumed to be row names."
32
32
  )
33
- first_column_obs: bool = Field(
34
- default=True,
35
- description="If True, assume the first column stores observations (cell or barcode) names when provide text file. If False, the data will be transposed."
33
+ transpose: bool = Field(
34
+ default=False,
35
+ description="If True, the data will be transposed."
36
36
  )
37
37
  backup_url: str = Field(
38
38
  default=None,
@@ -82,7 +82,7 @@ class ReadModel(BaseModel):
82
82
  return v
83
83
 
84
84
 
85
- class WriteModel(BaseModel):
85
+ class WriteParams(BaseModel):
86
86
  """Input schema for the write tool."""
87
87
  filename: str = Field(
88
88
  description="Path to save the file. If no extension is provided, the default format will be used."
scmcp_shared/schema/pl.py CHANGED
@@ -72,12 +72,12 @@ class FigureSizeMixin:
72
72
 
73
73
 
74
74
  # 基础可视化模型,包含所有可视化工具共享的字段
75
- class BaseVisualizationModel(BaseModel, LegendMixin, ColorMappingMixin, FigureSizeMixin):
75
+ class BaseVisualizationParams(BaseModel, LegendMixin, ColorMappingMixin, FigureSizeMixin):
76
76
  """基础可视化模型,包含所有可视化工具共享的字段"""
77
77
  pass
78
78
 
79
79
  # 基础嵌入可视化模型,包含所有嵌入可视化工具共享的字段
80
- class BaseEmbeddingModel(BaseVisualizationModel):
80
+ class BaseEmbeddingParams(BaseVisualizationParams):
81
81
  """基础嵌入可视化模型,包含所有嵌入可视化工具共享的字段"""
82
82
 
83
83
  color: Optional[Union[str, List[str]]] = Field(
@@ -176,8 +176,8 @@ class BaseEmbeddingModel(BaseVisualizationModel):
176
176
  )
177
177
 
178
178
 
179
- # 重构 ScatterModel 作为基础散点图模型
180
- class BaseScatterModel(BaseVisualizationModel):
179
+ # 重构 ScatterParams 作为基础散点图模型
180
+ class BaseScatterParams(BaseVisualizationParams):
181
181
  """基础散点图模型"""
182
182
 
183
183
  x: Optional[str] = Field(
@@ -211,8 +211,8 @@ class BaseScatterModel(BaseVisualizationModel):
211
211
  )
212
212
 
213
213
 
214
- # 使用继承关系重构 EnhancedScatterModel
215
- class EnhancedScatterModel(BaseScatterModel):
214
+ # 使用继承关系重构 EnhancedScatterParams
215
+ class EnhancedScatterParams(BaseScatterParams):
216
216
  """Input schema for the enhanced scatter plotting tool."""
217
217
 
218
218
  sort_order: bool = Field(
@@ -261,7 +261,7 @@ class EnhancedScatterModel(BaseScatterModel):
261
261
 
262
262
 
263
263
  # 创建基础统计可视化模型
264
- class BaseStatPlotModel(BaseVisualizationModel):
264
+ class BaseStatPlotParams(BaseVisualizationParams):
265
265
  """基础统计可视化模型,包含统计图表共享的字段"""
266
266
 
267
267
  groupby: Optional[str] = Field(
@@ -328,8 +328,8 @@ class BaseStatPlotModel(BaseVisualizationModel):
328
328
  raise ValueError("size must be a positive integer")
329
329
  return v
330
330
 
331
- # 添加缺失的 BaseMatrixModel
332
- class BaseMatrixModel(BaseVisualizationModel):
331
+ # 添加缺失的 BaseMatrixParams
332
+ class BaseMatrixParams(BaseVisualizationParams):
333
333
  """基础矩阵可视化模型,包含所有矩阵可视化工具共享的字段"""
334
334
 
335
335
  var_names: Union[List[str], Mapping[str, List[str]]] = Field(
@@ -374,8 +374,8 @@ class BaseMatrixModel(BaseVisualizationModel):
374
374
  )
375
375
 
376
376
 
377
- # 重构 HeatmapModel
378
- class HeatmapModel(BaseMatrixModel):
377
+ # 重构 HeatmapParams
378
+ class HeatmapParams(BaseMatrixParams):
379
379
  """Input schema for the heatmap plotting tool."""
380
380
 
381
381
  num_categories: int = Field(
@@ -412,14 +412,14 @@ class HeatmapModel(BaseMatrixModel):
412
412
  return v
413
413
 
414
414
 
415
- # 重构 TracksplotModel
416
- class TracksplotModel(BaseMatrixModel):
415
+ # 重构 TracksplotParams
416
+ class TracksplotParams(BaseMatrixParams):
417
417
  """Input schema for the tracksplot plotting tool."""
418
- # 所有需要的字段已经在 BaseMatrixModel 中定义
418
+ # 所有需要的字段已经在 BaseMatrixParams 中定义
419
419
 
420
420
 
421
- # 重构 ViolinModel
422
- class ViolinModel(BaseStatPlotModel):
421
+ # 重构 ViolinParams
422
+ class ViolinParams(BaseStatPlotParams):
423
423
  """Input schema for the violin plotting tool."""
424
424
 
425
425
  keys: Union[str, List[str]] = Field(
@@ -477,8 +477,8 @@ class ViolinModel(BaseStatPlotModel):
477
477
  return v
478
478
 
479
479
 
480
- # 重构 MatrixplotModel
481
- class MatrixplotModel(BaseMatrixModel):
480
+ # 重构 MatrixplotParams
481
+ class MatrixplotParams(BaseMatrixParams):
482
482
  """Input schema for the matrixplot plotting tool."""
483
483
 
484
484
  num_categories: int = Field(
@@ -523,8 +523,8 @@ class MatrixplotModel(BaseMatrixModel):
523
523
  return v
524
524
 
525
525
 
526
- # 重构 DotplotModel
527
- class DotplotModel(BaseMatrixModel):
526
+ # 重构 DotplotParams
527
+ class DotplotParams(BaseMatrixParams):
528
528
  """Input schema for the dotplot plotting tool."""
529
529
 
530
530
  expression_cutoff: float = Field(
@@ -577,8 +577,8 @@ class DotplotModel(BaseMatrixModel):
577
577
  )
578
578
 
579
579
 
580
- # 重构 RankGenesGroupsModel
581
- class RankGenesGroupsModel(BaseVisualizationModel):
580
+ # 重构 RankGenesGroupsParams
581
+ class RankGenesGroupsParams(BaseVisualizationParams):
582
582
  """Input schema for the rank_genes_groups plotting tool."""
583
583
 
584
584
  n_genes: int = Field(
@@ -631,7 +631,7 @@ class RankGenesGroupsModel(BaseVisualizationModel):
631
631
 
632
632
 
633
633
  # 重构 ClusterMapModel
634
- class ClusterMapModel(BaseModel):
634
+ class ClusterMapParams(BaseModel):
635
635
  """Input schema for the clustermap plotting tool."""
636
636
 
637
637
  obs_keys: Optional[str] = Field(
@@ -645,8 +645,8 @@ class ClusterMapModel(BaseModel):
645
645
 
646
646
 
647
647
 
648
- # 重构 StackedViolinModel
649
- class StackedViolinModel(BaseStatPlotModel):
648
+ # 重构 StackedViolinParams
649
+ class StackedViolinParams(BaseStatPlotParams):
650
650
  """Input schema for the stacked_violin plotting tool."""
651
651
 
652
652
  stripplot: bool = Field(
@@ -688,8 +688,8 @@ class StackedViolinModel(BaseStatPlotModel):
688
688
  return v
689
689
 
690
690
 
691
- # 重构 TrackingModel
692
- class TrackingModel(BaseVisualizationModel):
691
+ # 重构 TrackingParams
692
+ class TrackingParams(BaseVisualizationParams):
693
693
  """Input schema for the tracking plotting tool."""
694
694
 
695
695
  groupby: str = Field(
@@ -717,8 +717,8 @@ class TrackingModel(BaseVisualizationModel):
717
717
  return v
718
718
 
719
719
 
720
- # 重构 EmbeddingDensityModel
721
- class EmbeddingDensityModel(BaseEmbeddingModel):
720
+ # 重构 EmbeddingDensityParams
721
+ class EmbeddingDensityParams(BaseEmbeddingParams):
722
722
  """Input schema for the embedding_density plotting tool."""
723
723
 
724
724
  basis: str = Field(
@@ -751,7 +751,7 @@ class EmbeddingDensityModel(BaseEmbeddingModel):
751
751
  return v
752
752
 
753
753
 
754
- class PCAModel(BaseEmbeddingModel):
754
+ class PCAParams(BaseEmbeddingParams):
755
755
  """Input schema for the PCA plotting tool."""
756
756
 
757
757
  annotate_var_explained: bool = Field(
@@ -761,22 +761,22 @@ class PCAModel(BaseEmbeddingModel):
761
761
 
762
762
 
763
763
  # 重构 UMAP 模型
764
- class UMAPModel(BaseEmbeddingModel):
764
+ class UMAPParams(BaseEmbeddingParams):
765
765
  """Input schema for the UMAP plotting tool."""
766
- # 所有需要的字段已经在 BaseEmbeddingModel 中定义
766
+ # 所有需要的字段已经在 BaseEmbeddingParams 中定义
767
767
 
768
768
 
769
769
  # 重构 TSNE 模型
770
- class TSNEModel(BaseEmbeddingModel):
770
+ class TSNEParams(BaseEmbeddingParams):
771
771
  """Input schema for the TSNE plotting tool."""
772
- # 所有需要的字段已经在 BaseEmbeddingModel 中定义
772
+ # 所有需要的字段已经在 BaseEmbeddingParams 中定义
773
773
 
774
- # 重构 DiffusionMapModel
775
- class DiffusionMapModel(BaseEmbeddingModel):
774
+ # 重构 DiffusionMapParams
775
+ class DiffusionMapParams(BaseEmbeddingParams):
776
776
  """Input schema for the diffusion map plotting tool."""
777
- # 所有需要的字段已经在 BaseEmbeddingModel 中定义
777
+ # 所有需要的字段已经在 BaseEmbeddingParams 中定义
778
778
 
779
- class HighestExprGenesModel(BaseVisualizationModel):
779
+ class HighestExprGenesParams(BaseVisualizationParams):
780
780
  """Input schema for the highest_expr_genes plotting tool."""
781
781
 
782
782
  n_top: int = Field(
@@ -803,7 +803,7 @@ class HighestExprGenesModel(BaseVisualizationModel):
803
803
  return v
804
804
 
805
805
 
806
- class HighlyVariableGenesModel(BaseVisualizationModel):
806
+ class HighlyVariableGenesParams(BaseVisualizationParams):
807
807
  """Input schema for the highly_variable_genes plotting tool."""
808
808
 
809
809
  log: bool = Field(
@@ -816,7 +816,7 @@ class HighlyVariableGenesModel(BaseVisualizationModel):
816
816
  description="Whether to plot highly variable genes or all genes."
817
817
  )
818
818
 
819
- class PCAVarianceRatioModel(BaseVisualizationModel):
819
+ class PCAVarianceRatioParams(BaseVisualizationParams):
820
820
  """Input schema for the pca_variance_ratio plotting tool."""
821
821
 
822
822
  n_pcs: int = Field(
@@ -839,7 +839,7 @@ class PCAVarianceRatioModel(BaseVisualizationModel):
839
839
 
840
840
  # ... existing code ...
841
841
 
842
- class RankGenesGroupsDotplotModel(BaseMatrixModel):
842
+ class RankGenesGroupsDotplotParams(BaseMatrixParams):
843
843
  """Input schema for the rank_genes_groups_dotplot plotting tool."""
844
844
 
845
845
  groups: Optional[Union[str, List[str]]] = Field(
@@ -874,7 +874,7 @@ class RankGenesGroupsDotplotModel(BaseMatrixModel):
874
874
 
875
875
 
876
876
 
877
- class EmbeddingModel(BaseEmbeddingModel):
877
+ class EmbeddingParams(BaseEmbeddingParams):
878
878
  """Input schema for the embedding plotting tool."""
879
879
 
880
880
  basis: str = Field(
scmcp_shared/schema/pp.py CHANGED
@@ -73,7 +73,7 @@ class FilterGenes(BaseModel):
73
73
  return v
74
74
 
75
75
 
76
- class SubsetCellModel(BaseModel):
76
+ class SubsetCellParams(BaseModel):
77
77
  """Input schema for subsetting AnnData objects based on various criteria."""
78
78
  obs_key: Optional[str] = Field(
79
79
  default=None,
@@ -109,7 +109,7 @@ class SubsetCellModel(BaseModel):
109
109
  )
110
110
 
111
111
 
112
- class SubsetGeneModel(BaseModel):
112
+ class SubsetGeneParams(BaseModel):
113
113
  """Input schema for subsetting AnnData objects based on various criteria."""
114
114
  min_counts: Optional[int] = Field(
115
115
  default=None,
@@ -196,7 +196,7 @@ class CalculateQCMetrics(BaseModel):
196
196
 
197
197
 
198
198
 
199
- class Log1PModel(BaseModel):
199
+ class Log1PParams(BaseModel):
200
200
  """Input schema for the log1p preprocessing tool."""
201
201
 
202
202
  base: Optional[Union[int, float]] = Field(
@@ -233,7 +233,7 @@ class Log1PModel(BaseModel):
233
233
 
234
234
 
235
235
 
236
- class HighlyVariableGenesModel(BaseModel):
236
+ class HighlyVariableGenesParams(BaseModel):
237
237
  """Input schema for the highly_variable_genes preprocessing tool."""
238
238
 
239
239
  layer: Optional[str] = Field(
@@ -307,7 +307,7 @@ class HighlyVariableGenesModel(BaseModel):
307
307
  return v
308
308
 
309
309
 
310
- class RegressOutModel(BaseModel):
310
+ class RegressOutParams(BaseModel):
311
311
  """Input schema for the regress_out preprocessing tool."""
312
312
 
313
313
  keys: Union[str, List[str]] = Field(
@@ -340,7 +340,7 @@ class RegressOutModel(BaseModel):
340
340
  raise ValueError("keys must be a string or list of strings")
341
341
 
342
342
 
343
- class ScaleModel(BaseModel):
343
+ class ScaleParams(BaseModel):
344
344
  """Input schema for the scale preprocessing tool."""
345
345
 
346
346
  zero_center: bool = Field(
@@ -376,7 +376,7 @@ class ScaleModel(BaseModel):
376
376
  return v
377
377
 
378
378
 
379
- class CombatModel(BaseModel):
379
+ class CombatParams(BaseModel):
380
380
  """Input schema for the combat batch effect correction tool."""
381
381
 
382
382
  key: str = Field(
@@ -405,7 +405,7 @@ class CombatModel(BaseModel):
405
405
  return v
406
406
 
407
407
 
408
- class ScrubletModel(BaseModel):
408
+ class ScrubletParams(BaseModel):
409
409
  """Input schema for the scrublet doublet prediction tool."""
410
410
 
411
411
  adata_sim: Optional[str] = Field(
@@ -511,7 +511,7 @@ class ScrubletModel(BaseModel):
511
511
  return v.lower()
512
512
 
513
513
 
514
- class NeighborsModel(BaseModel):
514
+ class NeighborsParams(BaseModel):
515
515
  """Input schema for the neighbors graph construction tool."""
516
516
 
517
517
  n_neighbors: int = Field(
@@ -589,7 +589,7 @@ class NeighborsModel(BaseModel):
589
589
  return v
590
590
 
591
591
 
592
- class NormalizeTotalModel(BaseModel):
592
+ class NormalizeTotalParams(BaseModel):
593
593
  """Input schema for the normalize_total preprocessing tool."""
594
594
 
595
595
  target_sum: Optional[float] = Field(
scmcp_shared/schema/tl.py CHANGED
@@ -2,7 +2,7 @@ from pydantic import Field, field_validator, ValidationInfo, BaseModel
2
2
  from typing import Optional, Union, List, Dict, Any, Tuple, Literal, Mapping
3
3
 
4
4
 
5
- class TSNEModel(BaseModel):
5
+ class TSNEParams(BaseModel):
6
6
  """Input schema for the t-SNE dimensionality reduction tool."""
7
7
  n_pcs: Optional[int] = Field(
8
8
  default=None,
@@ -59,7 +59,7 @@ class TSNEModel(BaseModel):
59
59
  return v.lower()
60
60
 
61
61
 
62
- class UMAPModel(BaseModel):
62
+ class UMAPParams(BaseModel):
63
63
  """Input schema for the UMAP dimensionality reduction tool."""
64
64
 
65
65
  min_dist: Optional[float] = Field(
@@ -145,7 +145,7 @@ class UMAPModel(BaseModel):
145
145
  return v.lower()
146
146
 
147
147
 
148
- class DrawGraphModel(BaseModel):
148
+ class DrawGraphParams(BaseModel):
149
149
  """Input schema for the force-directed graph drawing tool."""
150
150
 
151
151
  layout: str = Field(
@@ -199,7 +199,7 @@ class DrawGraphModel(BaseModel):
199
199
  return v
200
200
 
201
201
 
202
- class DiffMapModel(BaseModel):
202
+ class DiffMapParams(BaseModel):
203
203
  """Input schema for the Diffusion Maps dimensionality reduction tool."""
204
204
 
205
205
  n_comps: int = Field(
@@ -229,7 +229,7 @@ class DiffMapModel(BaseModel):
229
229
  return v
230
230
 
231
231
 
232
- class EmbeddingDensityModel(BaseModel):
232
+ class EmbeddingDensityParams(BaseModel):
233
233
  """Input schema for the embedding density calculation tool."""
234
234
 
235
235
  basis: str = Field(
@@ -257,7 +257,7 @@ class EmbeddingDensityModel(BaseModel):
257
257
  return v
258
258
 
259
259
 
260
- class LeidenModel(BaseModel):
260
+ class LeidenParams(BaseModel):
261
261
  """Input schema for the Leiden clustering algorithm."""
262
262
 
263
263
  resolution: Optional[float] = Field(
@@ -324,7 +324,7 @@ class LeidenModel(BaseModel):
324
324
  return v
325
325
 
326
326
 
327
- class LouvainModel(BaseModel):
327
+ class LouvainParams(BaseModel):
328
328
  """Input schema for the Louvain clustering algorithm."""
329
329
 
330
330
  resolution: Optional[float] = Field(
@@ -396,7 +396,7 @@ class LouvainModel(BaseModel):
396
396
  return v
397
397
 
398
398
 
399
- class DendrogramModel(BaseModel):
399
+ class DendrogramParams(BaseModel):
400
400
  """Input schema for the hierarchical clustering dendrogram tool."""
401
401
 
402
402
  groupby: str = Field(
@@ -461,7 +461,7 @@ class DendrogramModel(BaseModel):
461
461
  return v
462
462
 
463
463
 
464
- class DPTModel(BaseModel):
464
+ class DPTParams(BaseModel):
465
465
  """Input schema for the Diffusion Pseudotime (DPT) tool."""
466
466
 
467
467
  n_dcs: int = Field(
@@ -510,7 +510,7 @@ class DPTModel(BaseModel):
510
510
  raise ValueError("min_group_size must be between 0 and 1")
511
511
  return v
512
512
 
513
- class PAGAModel(BaseModel):
513
+ class PAGAParams(BaseModel):
514
514
  """Input schema for the Partition-based Graph Abstraction (PAGA) tool."""
515
515
 
516
516
  groups: Optional[str] = Field(
@@ -531,14 +531,14 @@ class PAGAModel(BaseModel):
531
531
  )
532
532
 
533
533
  @field_validator('model')
534
- def validate_model(cls, v: str) -> str:
534
+ def validate_Params(cls, v: str) -> str:
535
535
  """Validate model version is supported"""
536
536
  if v not in ['v1.2', 'v1.0']:
537
537
  raise ValueError("model must be either 'v1.2' or 'v1.0'")
538
538
  return v
539
539
 
540
540
 
541
- class IngestModel(BaseModel):
541
+ class IngestParams(BaseModel):
542
542
  """Input schema for the ingest tool that maps labels and embeddings from reference data to new data."""
543
543
 
544
544
  obs: Optional[Union[str, List[str]]] = Field(
@@ -587,7 +587,7 @@ class IngestModel(BaseModel):
587
587
  return v.lower()
588
588
 
589
589
 
590
- class RankGenesGroupsModel(BaseModel):
590
+ class RankGenesGroupsParams(BaseModel):
591
591
  """Input schema for the rank_genes_groups tool."""
592
592
 
593
593
  groupby: str = Field(
@@ -669,7 +669,7 @@ class RankGenesGroupsModel(BaseModel):
669
669
  return v
670
670
 
671
671
 
672
- class FilterRankGenesGroupsModel(BaseModel):
672
+ class FilterRankGenesGroupsParams(BaseModel):
673
673
  """Input schema for filtering ranked genes groups."""
674
674
 
675
675
  key: Optional[str] = Field(
@@ -732,7 +732,7 @@ class FilterRankGenesGroupsModel(BaseModel):
732
732
  return v
733
733
 
734
734
 
735
- class MarkerGeneOverlapModel(BaseModel):
735
+ class MarkerGeneOverlapParams(BaseModel):
736
736
  """Input schema for the marker gene overlap tool."""
737
737
 
738
738
  key: str = Field(
@@ -803,7 +803,7 @@ class MarkerGeneOverlapModel(BaseModel):
803
803
  return v
804
804
 
805
805
 
806
- class ScoreGenesModel(BaseModel):
806
+ class ScoreGenesParams(BaseModel):
807
807
  """Input schema for the score_genes tool that calculates gene scores based on average expression."""
808
808
 
809
809
  ctrl_size: int = Field(
@@ -846,7 +846,7 @@ class ScoreGenesModel(BaseModel):
846
846
  return v
847
847
 
848
848
 
849
- class ScoreGenesCellCycleModel(BaseModel):
849
+ class ScoreGenesCellCycleParams(BaseModel):
850
850
  """Input schema for the score_genes_cell_cycle tool that scores cell cycle genes."""
851
851
 
852
852
  s_genes: List[str] = Field(
@@ -896,7 +896,7 @@ class ScoreGenesCellCycleModel(BaseModel):
896
896
 
897
897
 
898
898
 
899
- class PCAModel(BaseModel):
899
+ class PCAParams(BaseModel):
900
900
  """Input schema for the PCA preprocessing tool."""
901
901
 
902
902
  n_comps: Optional[int] = Field(
@@ -10,7 +10,7 @@ from typing import Optional, Union, List, Dict, Any, Callable, Collection, Liter
10
10
 
11
11
 
12
12
 
13
- class MarkVarModel(BaseModel):
13
+ class MarkVarParams(BaseModel):
14
14
  """Determine or mark if each gene meets specific conditions and store results in adata.var as boolean values"""
15
15
 
16
16
  var_name: str = Field(
@@ -32,15 +32,15 @@ class MarkVarModel(BaseModel):
32
32
  )
33
33
 
34
34
 
35
- class ListVarModel(BaseModel):
35
+ class ListVarParams(BaseModel):
36
36
  """ListVarModel"""
37
37
  pass
38
38
 
39
- class ListObsModel(BaseModel):
39
+ class ListObsParams(BaseModel):
40
40
  """ListObsModel"""
41
41
  pass
42
42
 
43
- class VarNamesModel(BaseModel):
43
+ class VarNamesParams(BaseModel):
44
44
  """ListObsModel"""
45
45
  var_names: List[str] = Field(
46
46
  default=None,
@@ -48,7 +48,7 @@ class VarNamesModel(BaseModel):
48
48
  )
49
49
 
50
50
 
51
- class ConcatBaseModel(BaseModel):
51
+ class ConcatBaseParams(BaseModel):
52
52
  """Model for concatenating AnnData objects"""
53
53
 
54
54
  axis: Literal['obs', 0, 'var', 1] = Field(
@@ -89,7 +89,7 @@ class ConcatBaseModel(BaseModel):
89
89
  )
90
90
 
91
91
 
92
- class DPTIROOTModel(BaseModel):
92
+ class DPTIROOTParams(BaseModel):
93
93
  """Input schema for setting the root cell for diffusion pseudotime."""
94
94
  diffmap_key: str = Field(
95
95
  default="X_diffmap",
@@ -103,7 +103,7 @@ class DPTIROOTModel(BaseModel):
103
103
  )
104
104
 
105
105
 
106
- class CelltypeMapCellTypeModel(BaseModel):
106
+ class CelltypeMapCellTypeParams(BaseModel):
107
107
  """Input schema for mapping cluster IDs to cell type names."""
108
108
  cluster_key: str = Field(
109
109
  description="Key in adata.obs containing cluster IDs."
@@ -122,14 +122,14 @@ class CelltypeMapCellTypeModel(BaseModel):
122
122
 
123
123
 
124
124
 
125
- class AddLayerModel(BaseModel):
125
+ class AddLayerParams(BaseModel):
126
126
  """Input schema for adding a layer to AnnData object."""
127
127
  layer_name: str = Field(
128
128
  description="Name of the layer to add to adata.layers."
129
129
  )
130
130
 
131
131
 
132
- class QueryOpLogModel(BaseModel):
132
+ class QueryOpLogParams(BaseModel):
133
133
  """QueryOpLogModel"""
134
134
  n: int = Field(
135
135
  default=10,
@@ -33,25 +33,20 @@ class BaseMCP:
33
33
  methods = inspect.getmembers(self, predicate=inspect.ismethod)
34
34
 
35
35
  # Filter methods that start with _tool_
36
- tool_methods = {
37
- name[6:]: method # Remove '_tool_' prefix
38
- for name, method in methods
39
- if name.startswith('_tool_')
40
- }
36
+ tool_methods = [tl_method() for name, tl_method in methods if name.startswith('_tool_')]
41
37
 
42
38
  # Filter tools based on include/exclude lists
43
39
  if self.include_tools is not None:
44
- tool_methods = {k: v for k, v in tool_methods.items() if k in self.include_tools}
40
+ tool_methods = [tl for tl in tool_methods if tl.name in self.include_tools]
45
41
 
46
42
  if self.exclude_tools is not None:
47
- tool_methods = {k: v for k, v in tool_methods.items() if k not in self.exclude_tools}
43
+ tool_methods = [tl for tl in tool_methods if tl.name not in self.exclude_tools]
48
44
 
49
45
  # Register filtered tools
50
- for tool_name, tool_method in tool_methods.items():
46
+ for tool in tool_methods:
51
47
  # Get the function returned by the tool method
52
- tool_func = tool_method()
53
- if tool_func is not None:
54
- self.mcp.add_tool(tool_func, name=tool_name)
48
+ if tool is not None:
49
+ self.mcp.add_tool(tool)
55
50
 
56
51
 
57
52
  class AdataState:
scmcp_shared/server/io.py CHANGED
@@ -3,6 +3,7 @@ import inspect
3
3
  from pathlib import Path
4
4
  import scanpy as sc
5
5
  from fastmcp import FastMCP, Context
6
+ from fastmcp.tools.tool import Tool
6
7
  from fastmcp.exceptions import ToolError
7
8
  from ..schema import AdataInfo
8
9
  from ..schema.io import *
@@ -16,7 +17,7 @@ class ScanpyIOMCP(BaseMCP):
16
17
  super().__init__("SCMCP-IO-Server", include_tools, exclude_tools, AdataInfo)
17
18
 
18
19
  def _tool_read(self):
19
- def _read(request: ReadModel, adinfo: self.AdataInfo=self.AdataInfo()):
20
+ def _read(request: ReadParams, adinfo: self.AdataInfo=self.AdataInfo()):
20
21
  """
21
22
  Read data from 10X directory or various file formats (h5ad, 10x, text files, etc.).
22
23
  """
@@ -33,7 +34,7 @@ class ScanpyIOMCP(BaseMCP):
33
34
  elif file.is_file():
34
35
  func_kwargs = filter_args(request, sc.read)
35
36
  adata = sc.read(**func_kwargs)
36
- if not kwargs.get("first_column_obs", True):
37
+ if not kwargs.get("transpose", True):
37
38
  adata = adata.T
38
39
  else:
39
40
  raise FileNotFoundError(f"{kwargs['filename']} does not exist")
@@ -47,8 +48,18 @@ class ScanpyIOMCP(BaseMCP):
47
48
  adata.layers["counts"] = adata.X
48
49
  adata.var_names_make_unique()
49
50
  adata.obs_names_make_unique()
51
+ adata.obs["scmcp_sampleid"] = adinfo.sampleid or ads.active_id
50
52
  ads.set_adata(adata, adinfo=adinfo)
51
- return generate_msg(adinfo, adata, ads)
53
+ return [
54
+ {
55
+ "sampleid": adinfo.sampleid or ads.active_id,
56
+ "adtype": adinfo.adtype,
57
+ "adata": adata,
58
+ "adata.obs_names[:10]": adata.obs_names[:10],
59
+ "adata.var_names[:10]": adata.var_names[:10],
60
+ "notice": "check obs_names and var_names. transpose the data if needed"
61
+ }
62
+ ]
52
63
  except ToolError as e:
53
64
  raise ToolError(e)
54
65
  except Exception as e:
@@ -56,10 +67,10 @@ class ScanpyIOMCP(BaseMCP):
56
67
  raise ToolError(e.__context__)
57
68
  else:
58
69
  raise ToolError(e)
59
- return _read
70
+ return Tool.from_function(_read, name="read")
60
71
 
61
72
  def _tool_write(self):
62
- def _write(request: WriteModel, adinfo: self.AdataInfo=self.AdataInfo()):
73
+ def _write(request: WriteParams, adinfo: self.AdataInfo=self.AdataInfo()):
63
74
  """save adata into a file."""
64
75
  try:
65
76
  res = forward_request("io_write", request, adinfo)
@@ -77,7 +88,7 @@ class ScanpyIOMCP(BaseMCP):
77
88
  raise ToolError(e.__context__)
78
89
  else:
79
90
  raise ToolError(e)
80
- return _write
91
+ return Tool.from_function(_write, name="write")
81
92
 
82
93
 
83
94
  # Create an instance of the class