marsilea 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
marsilea/plotter/base.py CHANGED
@@ -34,8 +34,7 @@ class DataLoader:
34
34
  elif isinstance(data, Iterable):
35
35
  return self.from_iterable(data, target=target)
36
36
  else:
37
- raise TypeError(
38
- f"Your input data with type {type(data)} is not supported.")
37
+ raise TypeError(f"Your input data with type {type(data)} is not supported.")
39
38
 
40
39
  def from_dataframe(self, data: pd.DataFrame, target="2d"):
41
40
  ndata = data.to_numpy()
@@ -76,6 +75,26 @@ class DataLoader:
76
75
 
77
76
  @dataclass(repr=False)
78
77
  class RenderSpec:
78
+ """The container class that holds the \
79
+ rendering data and parameters for each axes
80
+
81
+ Attributes
82
+ ----------
83
+ data : Any
84
+ The data to be rendered
85
+ params : List[Dict]
86
+ The parameters for each data
87
+ group_data : Any
88
+ The group data
89
+ group_params : Any
90
+ The group parameters
91
+ current_ix : int
92
+ The current index of the axes
93
+ total : int
94
+ The total number of axes
95
+
96
+
97
+ """
79
98
  ax: Axes
80
99
 
81
100
  data: Any
@@ -103,10 +122,12 @@ class RenderPlanLabel:
103
122
  label_props = {
104
123
  "left": dict(loc="center right", bbox_to_anchor=(0, 0.5)),
105
124
  "right": dict(loc="center left", bbox_to_anchor=(1, 0.5)),
106
- "top": dict(loc="lower center", bbox_to_anchor=(0.5, 1),
107
- prop=dict(rotation=90)),
108
- "bottom": dict(loc="upper center", bbox_to_anchor=(0.5, 0),
109
- prop=dict(rotation=90)),
125
+ "top": dict(
126
+ loc="lower center", bbox_to_anchor=(0.5, 1), prop=dict(rotation=90)
127
+ ),
128
+ "bottom": dict(
129
+ loc="upper center", bbox_to_anchor=(0.5, 0), prop=dict(rotation=90)
130
+ ),
110
131
  }
111
132
 
112
133
  label_loc = {
@@ -123,7 +144,6 @@ class RenderPlanLabel:
123
144
 
124
145
  def add(self, axes, side):
125
146
  if side != "main":
126
-
127
147
  if isinstance(axes, Sequence):
128
148
  if self.loc in ["top", "left"]:
129
149
  label_ax = axes[0]
@@ -138,16 +158,21 @@ class RenderPlanLabel:
138
158
  loc = self.loc
139
159
  label_props = self.label_props[loc]
140
160
  bbox_loc = label_props["loc"]
141
- bbox_to_anchor = label_props['bbox_to_anchor']
142
- prop = label_props.get('prop', {})
161
+ bbox_to_anchor = label_props["bbox_to_anchor"]
162
+ prop = label_props.get("prop", {})
143
163
  if self.props is not None:
144
164
  prop.update(self.props)
145
165
 
146
- title = AnchoredText(self.label, loc=bbox_loc,
147
- bbox_to_anchor=bbox_to_anchor,
148
- prop=prop, pad=0.3, borderpad=0,
149
- bbox_transform=label_ax.transAxes,
150
- frameon=False)
166
+ title = AnchoredText(
167
+ self.label,
168
+ loc=bbox_loc,
169
+ bbox_to_anchor=bbox_to_anchor,
170
+ prop=prop,
171
+ pad=0.3,
172
+ borderpad=0,
173
+ bbox_transform=label_ax.transAxes,
174
+ frameon=False,
175
+ )
151
176
  label_ax.add_artist(title)
152
177
 
153
178
 
@@ -155,17 +180,18 @@ class MetaRenderPlan(type):
155
180
  """Metaclass for RenderPlan"""
156
181
 
157
182
  def __init__(cls, name, bases, attrs):
158
- allow_labeling = attrs.get('allow_labeling', False)
183
+ allow_labeling = attrs.get("allow_labeling", False)
159
184
  if allow_labeling:
185
+
160
186
  def new_render(self, axes):
161
- attrs['render'](self, axes)
187
+ attrs["render"](self, axes)
162
188
  self._plan_label.add(axes, self.side)
163
189
 
164
- setattr(cls, 'render', new_render)
190
+ setattr(cls, "render", new_render)
165
191
 
166
192
 
167
193
  class RenderPlan:
168
- """The base class for every plot in Heatgraphy
194
+ """The base class for every plotter in Marsilea
169
195
 
170
196
  Attributes
171
197
  ----------
@@ -183,6 +209,7 @@ class RenderPlan:
183
209
  This only works if the RenderPlan is rendered on main canvas
184
210
 
185
211
  """
212
+
186
213
  name: str = None
187
214
  size: float = None
188
215
  side: str = "top"
@@ -212,8 +239,7 @@ class RenderPlan:
212
239
  chunks = [side_str, zorder_str]
213
240
  else:
214
241
  chunks = [f"name='{self.name}'", side_str, zorder_str]
215
- return f"{self.__class__.__name__}" \
216
- f"({', '.join(chunks)})"
242
+ return f"{self.__class__.__name__}" f"({', '.join(chunks)})"
217
243
 
218
244
  def set(self, **kwargs):
219
245
  for k, v in kwargs.items():
@@ -245,11 +271,13 @@ class RenderPlan:
245
271
  if group_data is not None:
246
272
  if self.has_deform:
247
273
  group_data = np.asarray(group_data)
248
- if self.deform.is_col_split & self.is_body & \
249
- self.deform.is_col_cluster:
274
+ if self.deform.is_col_split & self.is_body & self.deform.is_col_cluster:
250
275
  return group_data[self.deform.col_chunk_index]
251
- elif self.deform.is_row_split & self.is_flank & \
252
- self.deform.is_row_cluster:
276
+ elif (
277
+ self.deform.is_row_split
278
+ & self.is_flank
279
+ & self.deform.is_row_cluster
280
+ ):
253
281
  return group_data[self.deform.row_chunk_index]
254
282
  return group_data
255
283
 
@@ -332,8 +360,9 @@ class RenderPlan:
332
360
  total = len(axes)
333
361
  dispatch = zip(axes, spec_data, params, group_params)
334
362
  for i, (ax, d, p, gp) in enumerate(dispatch):
335
- spec = RenderSpec(ax=ax, data=d, params=p, group_params=gp,
336
- current_ix=i, total=total)
363
+ spec = RenderSpec(
364
+ ax=ax, data=d, params=p, group_params=gp, current_ix=i, total=total
365
+ )
337
366
  spec_list.append(spec)
338
367
  return spec_list
339
368
 
@@ -350,17 +379,21 @@ class RenderPlan:
350
379
  else:
351
380
  spec_data = [deform_func(d) for d in datasets]
352
381
  if params is not None:
353
- params = deform_func(
354
- np.asarray(params, dtype=object))
382
+ params = deform_func(np.asarray(params, dtype=object))
355
383
  else:
356
384
  spec_data = datasets[0] if len(datasets) == 1 else datasets
357
385
  return RenderSpec(ax=ax, data=spec_data, params=params)
358
386
 
359
387
  def get_render_spec(self, axes):
360
- if self.is_split:
361
- return self._get_split_render_spec(axes)
362
- else:
363
- return self._get_intact_render_spec(axes)
388
+ try:
389
+ if self.is_split:
390
+ return self._get_split_render_spec(axes)
391
+ else:
392
+ return self._get_intact_render_spec(axes)
393
+ except Exception as _:
394
+ raise DataError(
395
+ f"Please check your data input with {self.__class__.__name__}"
396
+ )
364
397
 
365
398
  # def get_render_data(self):
366
399
  # """Define how render data is organized
@@ -511,9 +544,8 @@ class AxisOption:
511
544
 
512
545
 
513
546
  class StatsBase(RenderPlan):
514
- """A base class for rendering statistics plot
547
+ """A base class for rendering statistics plot"""
515
548
 
516
- """
517
549
  render_main = True
518
550
  orient = None
519
551
  axis_options = None
@@ -527,17 +559,16 @@ class StatsBase(RenderPlan):
527
559
  if self.has_deform:
528
560
  orient = self.get_orient()
529
561
  if self.side == "main":
562
+ orient_mapper = {"h": "horizontally", "v": "vertically"}
530
563
 
531
- orient_mapper = {
532
- "h": "horizontally",
533
- "v": "vertically"
534
- }
535
-
536
- if (((orient == "v") & self.deform.is_row_split) or
537
- ((orient == "h") & self.deform.is_col_split)):
564
+ if ((orient == "v") & self.deform.is_row_split) or (
565
+ (orient == "h") & self.deform.is_col_split
566
+ ):
538
567
  plot_dir = orient_mapper[self.get_orient()]
539
- msg = f"{self.__class__.__name__} is oriented " \
540
- f"{plot_dir} should only be split {plot_dir}"
568
+ msg = (
569
+ f"{self.__class__.__name__} is oriented "
570
+ f"{plot_dir} should only be split {plot_dir}"
571
+ )
541
572
  raise SplitConflict(msg)
542
573
 
543
574
  if self.get_orient() == "v":
@@ -548,12 +579,10 @@ class StatsBase(RenderPlan):
548
579
  def _setup_axis(self, ax):
549
580
  if self.get_orient() == "h":
550
581
  despine(ax=ax, left=True)
551
- ax.tick_params(left=False, labelleft=False,
552
- bottom=True, labelbottom=True)
582
+ ax.tick_params(left=False, labelleft=False, bottom=True, labelbottom=True)
553
583
  else:
554
584
  despine(ax=ax, bottom=True)
555
- ax.tick_params(left=True, labelleft=True,
556
- bottom=False, labelbottom=False)
585
+ ax.tick_params(left=True, labelleft=True, bottom=False, labelbottom=False)
557
586
 
558
587
  def align_lim(self, axes):
559
588
  is_inverted = False
@@ -581,9 +610,7 @@ class StatsBase(RenderPlan):
581
610
  ax.set_xlim(*lims) if is_h else ax.set_ylim(*lims)
582
611
 
583
612
  def render(self, axes):
584
-
585
613
  if self.is_split:
586
-
587
614
  for spec in self.get_render_spec(axes):
588
615
  self.render_ax(spec)
589
616
  self.align_lim(axes)
marsilea/plotter/bio.py CHANGED
@@ -8,18 +8,19 @@ from .base import StatsBase
8
8
  from ..utils import pairwise, ECHARTS16
9
9
 
10
10
 
11
- def path_char(pos,
12
- extend,
13
- t,
14
- ax,
15
- width=1.,
16
- flip=False,
17
- mirror=False,
18
- direction="h",
19
- prop=None,
20
- usetex=False,
21
- **kwargs
22
- ):
11
+ def path_char(
12
+ pos,
13
+ extend,
14
+ t,
15
+ ax,
16
+ width=1.0,
17
+ flip=False,
18
+ mirror=False,
19
+ direction="h",
20
+ prop=None,
21
+ usetex=False,
22
+ **kwargs,
23
+ ):
23
24
  w = width
24
25
  h = extend[1] - extend[0]
25
26
 
@@ -63,10 +64,12 @@ def path_char(pos,
63
64
  tx = bbox.xmin
64
65
  ty = bbox.ymin + char_shift
65
66
 
66
- transformation = (Affine2D()
67
- .translate(tx=-tmp_bbox.xmin, ty=-tmp_bbox.ymin)
68
- .scale(sx=hs, sy=vs)
69
- .translate(tx=tx, ty=ty))
67
+ transformation = (
68
+ Affine2D()
69
+ .translate(tx=-tmp_bbox.xmin, ty=-tmp_bbox.ymin)
70
+ .scale(sx=hs, sy=vs)
71
+ .translate(tx=tx, ty=ty)
72
+ )
70
73
 
71
74
  char_path = transformation.transform_path(tmp_path)
72
75
 
@@ -109,12 +112,14 @@ class SeqLogo(StatsBase):
109
112
 
110
113
  render_main = False
111
114
 
112
- def __init__(self,
113
- matrix: pd.DataFrame,
114
- width=.9,
115
- color_encode=None,
116
- stack="descending", # "descending", "ascending", "normal"
117
- **kwargs):
115
+ def __init__(
116
+ self,
117
+ matrix: pd.DataFrame,
118
+ width=0.9,
119
+ color_encode=None,
120
+ stack="descending", # "descending", "ascending", "normal"
121
+ **kwargs,
122
+ ):
118
123
  self.matrix = matrix
119
124
  self.letters = matrix.index.to_numpy()
120
125
  self.set_data(matrix.to_numpy())
@@ -143,15 +148,21 @@ class SeqLogo(StatsBase):
143
148
  col = col[ix]
144
149
 
145
150
  extends = [0] + list(np.cumsum(col))
146
- pos = i + .5
151
+ pos = i + 0.5
147
152
  for t, extend in zip(letters, pairwise(extends)):
148
153
  facecolor = self.color_encode[t]
149
- options = {"facecolor": facecolor,
150
- "edgecolor": "none",
151
- **self.options}
152
- path_char(pos, extend, t, ax, flip=flip, mirror=mirror,
153
- width=self.width, direction=direction,
154
- **options)
154
+ options = {"facecolor": facecolor, "edgecolor": "none", **self.options}
155
+ path_char(
156
+ pos,
157
+ extend,
158
+ t,
159
+ ax,
160
+ flip=flip,
161
+ mirror=mirror,
162
+ width=self.width,
163
+ direction=direction,
164
+ **options,
165
+ )
155
166
  if self.is_body:
156
167
  ax.set_xlim(0, data.shape[1])
157
168
  ax.set_ylim(0, lim)
@@ -165,4 +176,3 @@ class SeqLogo(StatsBase):
165
176
  if self.side == "bottom":
166
177
  ax.invert_yaxis()
167
178
  ax.set_axis_off()
168
-