marsilea 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
marsilea/__init__.py CHANGED
@@ -1,9 +1,10 @@
1
1
  """Create x-layout visualization"""
2
2
 
3
- __version__ = "0.3.1"
3
+ __version__ = "0.3.3"
4
4
 
5
+ import marsilea.plotter as plotter
5
6
  from ._deform import Deformation
6
- from .base import WhiteBoard, ClusterBoard
7
+ from .base import WhiteBoard, ClusterBoard, ZeroWidth, ZeroHeight
7
8
  from .dataset import load_data
8
9
  from .dendrogram import Dendrogram, GroupDendrogram
9
10
  from .heatmap import Heatmap, SizedHeatmap, CatHeatmap
marsilea/_api.py CHANGED
@@ -6,5 +6,6 @@ Utilities functions to check user input
6
6
  def check_in_list(options, **kwargs):
7
7
  for name, value in kwargs.items():
8
8
  if value not in options:
9
- raise ValueError(f"You input unknown {name}={value}, "
10
- f"options are {options}")
9
+ raise ValueError(
10
+ f"You input unknown {name}={value}, " f"options are {options}"
11
+ )
marsilea/_deform.py CHANGED
@@ -15,6 +15,7 @@ class Deformation:
15
15
  #. Compute the ratio to split axes that match with data
16
16
 
17
17
  """
18
+
18
19
  is_row_split = False
19
20
  is_col_split = False
20
21
  is_row_cluster = False
@@ -65,22 +66,25 @@ class Deformation:
65
66
 
66
67
  def set_data_row_reindex(self, reindex):
67
68
  if len(reindex) != self._nrow:
68
- msg = f"Length of reindex ({len(reindex)}) should match " \
69
- f"data row with {self._nrow} elements"
69
+ msg = (
70
+ f"Length of reindex ({len(reindex)}) should match "
71
+ f"data row with {self._nrow} elements"
72
+ )
70
73
  raise ValueError(msg)
71
74
  self.data_row_reindex = reindex
72
75
  self._row_clustered = False
73
76
 
74
77
  def set_data_col_reindex(self, reindex):
75
78
  if len(reindex) != self._ncol:
76
- msg = f"Length of reindex ({len(reindex)}) should match " \
77
- f"data col with {self._ncol} elements"
79
+ msg = (
80
+ f"Length of reindex ({len(reindex)}) should match "
81
+ f"data col with {self._ncol} elements"
82
+ )
78
83
  raise ValueError(msg)
79
84
  self.data_col_reindex = reindex
80
85
  self._col_clustered = False
81
86
 
82
- def set_cluster(self, col=None, row=None, use_meta=True,
83
- linkage=None, **kwargs):
87
+ def set_cluster(self, col=None, row=None, use_meta=True, linkage=None, **kwargs):
84
88
  if col is not None:
85
89
  self.is_col_cluster = col
86
90
  self.col_cluster_kws = kwargs
@@ -105,8 +109,7 @@ class Deformation:
105
109
  def set_split_row(self, breakpoints=None, order=None):
106
110
  if breakpoints is not None:
107
111
  self.is_row_split = True
108
- self.row_breakpoints = [0, *np.sort(np.asarray(breakpoints)),
109
- self._nrow]
112
+ self.row_breakpoints = [0, *np.sort(np.asarray(breakpoints)), self._nrow]
110
113
  if order is None:
111
114
  order = np.arange(len(breakpoints) + 1)
112
115
  self.row_split_order = order
@@ -114,8 +117,7 @@ class Deformation:
114
117
  def set_split_col(self, breakpoints=None, order=None):
115
118
  if breakpoints is not None:
116
119
  self.is_col_split = True
117
- self.col_breakpoints = [0, *np.sort(np.asarray(breakpoints)),
118
- self._ncol]
120
+ self.col_breakpoints = [0, *np.sort(np.asarray(breakpoints)), self._ncol]
119
121
  if order is None:
120
122
  order = np.arange(len(breakpoints) + 1)
121
123
  self.col_split_order = order
@@ -125,8 +127,7 @@ class Deformation:
125
127
  self._run_cluster()
126
128
  if self.row_breakpoints is None:
127
129
  return None
128
- ratios = np.array([
129
- ix2 - ix1 for ix1, ix2 in pairwise(self.row_breakpoints)])
130
+ ratios = np.array([ix2 - ix1 for ix1, ix2 in pairwise(self.row_breakpoints)])
130
131
 
131
132
  if self.row_chunk_index is not None:
132
133
  return ratios[self.row_chunk_index]
@@ -138,8 +139,7 @@ class Deformation:
138
139
  self._run_cluster()
139
140
  if self.col_breakpoints is None:
140
141
  return None
141
- ratios = np.array([
142
- ix2 - ix1 for ix1, ix2 in pairwise(self.col_breakpoints)])
142
+ ratios = np.array([ix2 - ix1 for ix1, ix2 in pairwise(self.col_breakpoints)])
143
143
 
144
144
  if self.col_chunk_index is not None:
145
145
  return ratios[self.col_chunk_index]
@@ -161,11 +161,9 @@ class Deformation:
161
161
  if not self.is_col_split:
162
162
  return data
163
163
  if data.ndim == 1:
164
- return [data[ix1:ix2] for ix1, ix2 in pairwise(
165
- self.col_breakpoints)]
164
+ return [data[ix1:ix2] for ix1, ix2 in pairwise(self.col_breakpoints)]
166
165
  else:
167
- return [data[:, ix1:ix2] for ix1, ix2 in pairwise(
168
- self.col_breakpoints)]
166
+ return [data[:, ix1:ix2] for ix1, ix2 in pairwise(self.col_breakpoints)]
169
167
 
170
168
  def split_cross(self, data: np.ndarray):
171
169
  if self.is_col_split & self.is_row_split:
@@ -173,9 +171,7 @@ class Deformation:
173
171
  for ix1, ix2 in pairwise(self.row_breakpoints):
174
172
  row = []
175
173
  for iy1, iy2 in pairwise(self.col_breakpoints):
176
- row.append(
177
- data[ix1:ix2, iy1:iy2]
178
- )
174
+ row.append(data[ix1:ix2, iy1:iy2])
179
175
  split_data.append(row)
180
176
  return split_data
181
177
  if self.is_row_split:
@@ -184,14 +180,18 @@ class Deformation:
184
180
  return self.split_by_col(data)
185
181
  return data
186
182
 
187
- _linkage_check_msg = ("If you want to specific linkage when splitting, "
188
- "it must be a dict-like object, "
189
- "with keys as group names and values as linkage")
183
+ _linkage_check_msg = (
184
+ "If you want to specific linkage when splitting, "
185
+ "it must be a dict-like object, "
186
+ "with keys as group names and values as linkage"
187
+ )
190
188
 
191
189
  def cluster_row(self):
192
190
  row_data = self.split_by_row(self.get_data())
193
191
  if self.is_row_split:
194
- if not (isinstance(self.row_linkage, Mapping) or (self.row_linkage is None)):
192
+ if not (
193
+ isinstance(self.row_linkage, Mapping) or (self.row_linkage is None)
194
+ ):
195
195
  raise TypeError(self._linkage_check_msg)
196
196
  dens = []
197
197
  for chunk, k in zip(row_data, self.row_split_order):
@@ -200,7 +200,9 @@ class Deformation:
200
200
  linkage = self.row_linkage.get(k)
201
201
  if linkage is None:
202
202
  raise KeyError(f"Linkage for group {k} is not specified")
203
- dens.append(Dendrogram(chunk, linkage=linkage, key=k, **self.row_cluster_kws))
203
+ dens.append(
204
+ Dendrogram(chunk, linkage=linkage, key=k, **self.row_cluster_kws)
205
+ )
204
206
 
205
207
  dg = GroupDendrogram(dens, **self.row_cluster_kws)
206
208
  if self._use_row_meta:
@@ -216,7 +218,9 @@ class Deformation:
216
218
  def cluster_col(self):
217
219
  col_data = self.split_by_col(self.get_data())
218
220
  if self.is_col_split:
219
- if not (isinstance(self.col_linkage, Mapping) or (self.col_linkage is None)):
221
+ if not (
222
+ isinstance(self.col_linkage, Mapping) or (self.col_linkage is None)
223
+ ):
220
224
  raise TypeError(self._linkage_check_msg)
221
225
  dens = []
222
226
  for chunk, k in zip(col_data, self.col_split_order):
@@ -225,7 +229,9 @@ class Deformation:
225
229
  linkage = self.col_linkage.get(k)
226
230
  if linkage is None:
227
231
  raise KeyError(f"Linkage for group {k} is not specified")
228
- dens.append(Dendrogram(chunk.T, linkage=linkage, key=k, **self.col_cluster_kws))
232
+ dens.append(
233
+ Dendrogram(chunk.T, linkage=linkage, key=k, **self.col_cluster_kws)
234
+ )
229
235
  dg = GroupDendrogram(dens, **self.col_cluster_kws)
230
236
  if self._use_col_meta:
231
237
  self.col_chunk_index = dg.reorder_index
@@ -233,7 +239,9 @@ class Deformation:
233
239
  self.col_chunk_index = np.arange(len(dens))
234
240
  self.col_reorder_index = [d.reorder_index for d in dens]
235
241
  else:
236
- dg = Dendrogram(col_data.T, linkage=self.col_linkage, **self.col_cluster_kws)
242
+ dg = Dendrogram(
243
+ col_data.T, linkage=self.col_linkage, **self.col_cluster_kws
244
+ )
237
245
  self.col_reorder_index = dg.reorder_index
238
246
  self.col_dendrogram = dg
239
247
 
@@ -279,19 +287,15 @@ class Deformation:
279
287
  if self.is_row_split & self.is_col_split:
280
288
  final_data = []
281
289
  for row in data:
282
- for ix, order in zip(range(len(row)),
283
- self.col_reorder_index):
290
+ for ix, order in zip(range(len(row)), self.col_reorder_index):
284
291
  if row[ix].ndim == 2:
285
292
  row[ix] = row[ix][:, order]
286
293
  else:
287
294
  row[ix] = row[ix][order]
288
- final_data.append(
289
- [row[ix] for ix in self.col_chunk_index]
290
- )
295
+ final_data.append([row[ix] for ix in self.col_chunk_index])
291
296
  return final_data
292
297
  elif self.is_col_split:
293
- for ix, order in zip(range(len(data)),
294
- self.col_reorder_index):
298
+ for ix, order in zip(range(len(data)), self.col_reorder_index):
295
299
  data[ix] = data[ix][:, order]
296
300
 
297
301
  return [data[ix] for ix in self.col_chunk_index]
@@ -317,8 +321,10 @@ class Deformation:
317
321
  def transform(self, data: np.ndarray):
318
322
  """data must be 2d array with the same shape as cluster data"""
319
323
  if not data.shape == (self._nrow, self._ncol):
320
- msg = f"The shape of input data {data.shape} does not align with" \
321
- f" the shape of cluster data {(self._nrow, self._ncol)}"
324
+ msg = (
325
+ f"The shape of input data {data.shape} does not align with"
326
+ f" the shape of cluster data {(self._nrow, self._ncol)}"
327
+ )
322
328
  raise ValueError(msg)
323
329
  if self.data_row_reindex is not None:
324
330
  data = data[self.data_row_reindex]