tikzplot42 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tikzplot/figure.py ADDED
@@ -0,0 +1,285 @@
1
+ import numpy as np
2
+
3
+ from .axes import Axes
4
+ from .axes3d import Axes3
5
+ from .config import TikzConfig
6
+
7
class Figure:
    """Figure holding one or more subplot axes, rendered as TikZ/pgfplots code.

    Mirrors a small part of matplotlib's Figure API (add_subplot,
    set_size_inches, clear); _to_tex() turns the collected axes into LaTeX
    source and _save() writes it to disk.
    """

    # Multiplier applied to the (w, h) passed to set_size_inches(). Presumably
    # an inch -> cm conversion (the exact factor would be 2.54) -- TODO confirm
    # against the unit Axes._update_size() expects before changing it.
    _SIZE_FACTOR = 2.5

    def __init__(self):
        # Subplot axes in creation order.
        self._axes = []
        # Figure size; None unless configured or set via set_size_inches().
        self._width = TikzConfig.DEFAULT_WIDTH or None
        self._height = TikzConfig.DEFAULT_HEIGHT or None

        # Range-sharing modes: True/"all", "row", "col", "none" or None.
        self._sharex = None
        self._sharey = None

        # Grid shape of the most recently added subplot.
        self._nrows = 0
        self._ncols = 0

        # Cached (row_spacing, col_spacing) from _compute_group_spacing().
        self._spacings = None

        # Counter behind _get_free_path_name().
        self._last_path_num = 0

        # Colour name -> (r, g, b), emitted as \definecolor lines.
        self._col_dict = {}

        # Extra TeX lines emitted before \begin{tikzpicture}.
        self._globals = set()

    @staticmethod
    def _parse_subplot_spec(nrows, ncols, index):
        """Expand matplotlib's three-digit subplot code (e.g. 211) to (nrows, ncols, index).

        Only applies when ``ncols`` and ``index`` still have their default of 1,
        every digit is non-zero and the index fits the grid; otherwise the
        arguments are returned unchanged.
        """
        if isinstance(nrows, (int, np.integer)) and 100 <= nrows <= 999 and ncols == 1 and index == 1:
            code = int(nrows)
            rows, cols, idx = code // 100, code // 10 % 10, code % 10
            if rows and cols and 1 <= idx <= rows * cols:
                return rows, cols, idx
        return nrows, ncols, index

    def add_subplot(self, nrows=1, ncols=1, index=1, sharex=None, sharey=None, projection=None, polar=False):
        """Add one subplot at 1-based ``index`` of an ``nrows`` x ``ncols`` grid.

        The compact form ``add_subplot(211)`` is accepted as well.
        ``projection="3d"`` creates an Axes3; ``projection="polar"`` or
        ``polar=True`` creates a polar Axes. Truthy ``sharex``/``sharey`` values
        replace the figure-wide sharing mode. Returns the new axes.
        """
        nrows, ncols, index = self._parse_subplot_spec(nrows, ncols, index)
        if projection == "3d":
            ax = Axes3(nrows, ncols, index, self)
        else:
            ax = Axes(nrows, ncols, index, self, projection == "polar" or polar)
        self._nrows = nrows
        self._ncols = ncols
        self._axes.append(ax)
        # The grid changed, so any cached spacing is stale.
        self._spacings = None
        if sharex:
            self._sharex = sharex
        if sharey:
            self._sharey = sharey
        return ax

    def _add_subplots(self, nrows, ncols, sharex=None, sharey=None, subplot_kw=None):
        """Add a full ``nrows`` x ``ncols`` grid of subplots and return them row-major.

        ``subplot_kw`` may contain "projection" and/or "polar" (both forwarded
        to add_subplot); other keys are ignored.
        """
        if sharex:
            self._sharex = sharex
        if sharey:
            self._sharey = sharey
        options = subplot_kw or {}
        projection = options.get("projection")
        polar = options.get("polar", False)
        return [
            self.add_subplot(nrows, ncols, i, sharex, sharey, projection=projection, polar=polar)
            for i in range(1, nrows * ncols + 1)
        ]

    def set_size_inches(self, *args):
        """Set the figure size from ``(w, h)`` given in inches.

        Accepts ``set_size_inches(w, h)`` or ``set_size_inches((w, h))``; a list
        or NumPy array also works for the single-argument form. Both values are
        multiplied by ``_SIZE_FACTOR`` and every existing axes is asked to
        update its size. Missing or malformed sizes (wrong length, non-numeric
        values) are ignored and leave the size unchanged; errors raised by the
        axes while resizing propagate.
        """
        if not args:
            return
        size = args[0] if isinstance(args[0], (tuple, list, np.ndarray)) else args
        try:
            w, h = size
            width = w * self._SIZE_FACTOR
            height = h * self._SIZE_FACTOR
        except (TypeError, ValueError):
            return
        # Assign both only once both are valid, so a bad height cannot leave a
        # half-updated size behind.
        self._width = width
        self._height = height
        for ax in self._axes:
            ax._update_size()

    def _compute_group_spacing(self):
        """Compute and cache the gaps between neighbouring grid rows and columns.

        Each axes' _margins() is stored as 4 values (left, right, top, bottom).
        A row gap is the largest bottom margin of the row above plus the largest
        top margin of the row below; column gaps likewise use right/left
        margins. Stores ``(row_spacing, col_spacing)``; a single row/column
        yields ``[0]``.
        """
        margins = np.zeros((self._nrows, self._ncols, 4))
        for ax in self._axes:
            margins[ax._get_row(), ax._get_col()] = np.array(ax._margins())
        left = margins[:, :, 0]
        right = margins[:, :, 1]
        top = margins[:, :, 2]
        bottom = margins[:, :, 3]
        if self._nrows > 1:
            row_spacing = np.max(bottom[:-1, :], axis=1) + np.max(top[1:, :], axis=1)
        else:
            row_spacing = [0]
        if self._ncols > 1:
            col_spacing = np.max(right[:, :-1], axis=0) + np.max(left[:, 1:], axis=0)
        else:
            col_spacing = [0]
        self._spacings = row_spacing, col_spacing

    def _get_spacing(self, row, col):
        """Return the gap before grid cell (row, col); computed lazily.

        Cells in column 0 get the row gap above them (0 for the top-left cell);
        all other cells get the column gap to their left.
        """
        if self._spacings is None:
            self._compute_group_spacing()
        row_spacing, col_spacing = self._spacings
        if col == 0:
            return 0 if row == 0 else row_spacing[row - 1]
        return col_spacing[col - 1]

    def _share_groups(self, mode):
        """Return the lists of axes whose ranges are shared under ``mode``.

        True/"all" -> one group with every axes; "row"/"col" -> one group per
        grid row/column; None, False, "none" or anything else -> no groups.
        """
        if not mode or mode == "none":
            return []
        if mode in ("all", True):
            return [self._axes]
        if mode == "row":
            groups = [[] for _ in range(self._nrows)]
            for ax in self._axes:
                groups[ax._get_row()].append(ax)
            return groups
        if mode == "col":
            groups = [[] for _ in range(self._ncols)]
            for ax in self._axes:
                groups[ax._get_col()].append(ax)
            return groups
        return []

    @staticmethod
    def _apply_shared_range(which, group):
        """Give every axes in ``group`` one common ``which`` ("x" or "y") range.

        If any axes has an explicitly fixed ("hard") bound, the smallest hard
        minimum and/or the largest hard maximum is applied to the whole group
        and data ranges are ignored. Otherwise the union of the data ranges is
        used, widened by TikzConfig.SHARED_AXIS_REL_MARGIN (additively in
        linear mode, multiplicatively in log mode). A degenerate union
        (min >= max) leaves the axes untouched.
        """
        if not group:
            # A grid row/column can be empty when not every cell has a subplot.
            return

        mode = "lin"
        hard_mins = []
        hard_maxes = []
        for ax in group:
            for bound, found in (("min", hard_mins), ("max", hard_maxes)):
                value, scale = ax._get_hard_range(which + bound)
                if scale == "log":
                    mode = "log"
                if value is not None:
                    found.append(value)

        if hard_mins or hard_maxes:
            if hard_mins:
                lowest = min(hard_mins)
                for ax in group:
                    ax._set_range(which + "min", lowest)
            if hard_maxes:
                highest = max(hard_maxes)
                for ax in group:
                    ax._set_range(which + "max", highest)
            return

        mins = [ax._get_range(which + "min") for ax in group]
        maxes = [ax._get_range(which + "max") for ax in group]
        min_val = min(r[0] for r in mins)
        max_val = max(r[0] for r in maxes)
        if any(r[2] == "log" for r in mins):
            mode = "log"

        if not min_val < max_val:
            return

        margin = TikzConfig.SHARED_AXIS_REL_MARGIN
        if mode == "lin":
            pad = (max_val - min_val) * margin
            min_val -= pad
            max_val += pad
        elif min_val > 0:
            # Symmetric padding in log space; skipped for a zero or negative
            # minimum, where the ratio is undefined or the power complex.
            factor = (max_val / min_val) ** margin
            min_val /= factor
            max_val *= factor

        for ax in group:
            ax._set_range(which + "min", min_val)
            ax._set_range(which + "max", max_val)

    def _shared_ranges(self):
        """Synchronise axis ranges between subplots according to sharex/sharey."""
        for group in self._share_groups(self._sharex):
            self._apply_shared_range("x", group)
        for group in self._share_groups(self._sharey):
            self._apply_shared_range("y", group)

    def _reduce_points(self):
        """Cap the number of plotted points per element so the figure fits its budget.

        Every count from the axes' _num_points() is first capped at
        TikzConfig.MAX_POINTS_PER_ELEMENT. If the capped total still exceeds
        TikzConfig.MAX_POINTS_PER_FIGURE, a binary search finds the largest
        common per-element limit that keeps the total within budget. The limit
        is then passed to every axes' _reduce_points().
        """
        cap = TikzConfig.MAX_POINTS_PER_ELEMENT
        budget = TikzConfig.MAX_POINTS_PER_FIGURE
        counts = [0]  # the 0 keeps max() valid when no axes report points
        for ax in self._axes:
            counts += ax._num_points()
        counts = [min(c, cap) for c in counts]
        limit = max(counts)
        if sum(counts) > budget:
            lo, hi = 0, limit
            while lo < hi:
                mid = (lo + hi + 1) // 2
                if sum(min(c, mid) for c in counts) <= budget:
                    lo = mid
                else:
                    hi = mid - 1
            limit = lo

        for ax in self._axes:
            ax._reduce_points(limit)

    @staticmethod
    def _preamble():
        """Return the standalone-document preamble, or "" unless TikzConfig.STANDALONE."""
        if not TikzConfig.STANDALONE:
            return ""
        parts = [
            "\\documentclass[tikz,border=2pt]{standalone}",
            "\\usepackage{tikz}",
            "\\usepackage{pgfplots}",
        ]
        if TikzConfig.USE_GROUPPLOTS:
            parts.append("\\usepgfplotslibrary{groupplots}")
        parts.append("\\usepgfplotslibrary{fillbetween}")
        parts.append(f"\\pgfplotsset{{compat={TikzConfig.TIKZ_COMPAT}}}")
        if TikzConfig.USE_XCOLOR:
            parts.append("\\usepackage{xcolor}")
        parts.append("\\begin{document}")
        return "".join(part + "\n" for part in parts)

    def _to_tex(self, filename):
        """Return the complete LaTeX source for this figure ("" if it has no axes).

        Before rendering, shared ranges are synchronised and, when
        TikzConfig.REDUCE_NUM_POINTS is set, point counts are reduced.
        ``filename`` is passed on to every Axes._to_tex, whose primary lines
        form the (group)plot body and whose secondary lines follow it.
        """
        if not self._axes:
            return ""
        self._shared_ranges()
        if TikzConfig.REDUCE_NUM_POINTS:
            self._reduce_points()

        nrows = self._axes[0]._get_nrows()
        ncols = self._axes[0]._get_ncols()
        group_begin = None
        if TikzConfig.USE_GROUPPLOTS:
            self._compute_group_spacing()
            row_spacing, col_spacing = self._spacings
            if len(row_spacing) > 0 and len(col_spacing) > 0:
                group_begin = (
                    f"\\begin{{groupplot}}[group style={{group size={ncols} by {nrows}, "
                    f"horizontal sep={max(col_spacing)}cm, vertical sep={max(row_spacing)}cm}}]"
                )
            else:
                group_begin = f"\\begin{{groupplot}}[group style={{group size={ncols} by {nrows}}}]"

        body = []
        secondary = []
        for ax in self._axes:
            prim, sec = ax._to_tex(filename)
            body += prim
            if sec:
                secondary += sec

        # Globals and colours are gathered after rendering so that anything
        # registered through _add_global()/_add_col() while rendering is
        # included. Sorting makes the output reproducible: set iteration order
        # of strings varies between runs (hash randomisation).
        lines = sorted(self._globals)
        lines.append("\\begin{tikzpicture}")
        lines += [
            f"\\definecolor{{{name}}}{{rgb}}{{{r:.3f}, {g:.3f}, {b:.3f}}}"
            for name, (r, g, b) in self._col_dict.items()
        ]
        if group_begin is not None:
            lines.append(group_begin)
        lines += body
        if TikzConfig.USE_GROUPPLOTS:
            lines.append("\\end{groupplot}")
        lines += secondary
        lines.append("\\end{tikzpicture}")

        end = "\\end{document}" if TikzConfig.STANDALONE else ""
        return self._preamble() + "\n" + "\n".join(lines) + "\n" + end

    def _save(self, filename):
        """Render the figure and write it to ``filename`` (UTF-8).

        _to_tex() always runs; the .tex file itself is not rewritten when both
        TikzConfig.SAVE_DATAPOINTS and TikzConfig.UPDATE_DATA_ONLY are set.
        """
        content = self._to_tex(filename)
        if TikzConfig.SAVE_DATAPOINTS and TikzConfig.UPDATE_DATA_ONLY:
            return
        with open(filename, "w", encoding="utf-8") as f:
            f.write(content)

    def _get_width(self):
        """Return the figure width (None if unset)."""
        return self._width

    def _get_height(self):
        """Return the figure height (None if unset)."""
        return self._height

    def clear(self):
        """Reset the figure to the state of a freshly constructed one."""
        self.__init__()

    def _get_free_path_name(self):
        """Return a new unique path name: "path1", "path2", ..."""
        self._last_path_num += 1
        return f"path{self._last_path_num}"

    def _add_col(self, r, g, b):
        """Register an RGB colour for a \\definecolor line and return its name.

        The name is "c" followed by each component formatted with 3 decimals,
        dots removed (e.g. (1.0, 0.5, 0.0) -> "c100005000000").
        """
        code = f"c{r:.3f}{g:.3f}{b:.3f}".replace(".", "")
        self._col_dict[code] = (r, g, b)
        return code

    def _add_global(self, setting):
        """Register a line emitted once before \\begin{tikzpicture} (duplicates collapse)."""
        self._globals.add(setting)
tikzplot/figure.pyi ADDED
@@ -0,0 +1,27 @@
1
+ from .axes import Axes as Axes
2
+ from .config import TikzConfig as TikzConfig
3
+
4
+ from typing import Optional
5
+
6
class Figure:
    def __init__(self) -> None: ...
    def add_subplot(
        self,
        nrows: int = 1,
        ncols: int = 1,
        index: int = 1,
        sharex: bool | str | None = None,
        sharey: bool | str | None = None,
        projection: str | None = None,
        polar: bool = False,
    ) -> Axes:
        """
        Add subplot axis.

        Parameters
        ----------
        nrows, ncols, index: int, optional
            Grid shape and 1-based position of the new subplot (default 1, 1, 1).
        sharex, sharey: bool or str, optional
            Share x/y ranges between subplots: True or "all", "row", "col", or "none".
        projection: None, "polar", "3d", optional
            "3d" returns a 3D axes object instead of a plain Axes.
        polar: bool, optional
            Use polar projection for axis (no additional features implemented yet).
        """
        ...
    def set_size_inches(self, *args: float | tuple[float, float]) -> None:
        """
        Set figure size in inches, as set_size_inches(w, h) or set_size_inches((w, h)).
        Malformed sizes are ignored.
        """
        ...
    def clear(self) -> None:
        """
        Clear figure: reset it to the state of a freshly created one.
        """
        ...
@@ -0,0 +1,62 @@
1
+ import re
2
+
3
# Replacements for characters that are special in LaTeX text mode.
TEXT_MAP = str.maketrans({
    '%': r'\%',
    '$': r'\$',
    '&': r'\&',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
    '\\': r'\textbackslash{}',
})

# Lighter table for math mode, where _, ^, { and } keep their meaning.
MATH_MAP = str.maketrans({
    "%": r"\%",
    "#": r"\#",
    "&": r"\&",
})


def tex_text(sa: str) -> str:
    """Escape a matplotlib-style label string for use in LaTeX.

    Text outside ``$...$`` / ``$$...$$`` is escaped with TEXT_MAP; text inside
    math is escaped only with MATH_MAP. A backslash followed by any character
    is copied verbatim together with that character, so commands such as
    ``\\alpha`` and already-escaped characters such as ``\\%`` pass through.

    Like matplotlib, the string only contains math when it has a positive,
    even number of unescaped dollar signs. Otherwise every ``$`` is a literal
    dollar and is emitted as ``\\$``; an unpaired ``$`` would open a math shift
    that is never closed and break the LaTeX run.
    """
    # Same rule as matplotlib.cbook.is_math_text().
    dollars = sa.count("$") - sa.count(r"\$")
    math_allowed = dollars > 0 and dollars % 2 == 0

    out = []
    math_mode = False
    i = 0
    n = len(sa)
    while i < n:
        ch = sa[i]

        if ch == "$":
            if not math_allowed:
                out.append(r"\$")
                i += 1
            elif sa.startswith("$$", i):
                # Display-math delimiter toggles math mode as one unit.
                math_mode = not math_mode
                out.append("$$")
                i += 2
            else:
                math_mode = not math_mode
                out.append("$")
                i += 1
            continue

        if ch == "\\" and i + 1 < n:
            # Keep commands and already-escaped characters (e.g. \%, \_) intact.
            out.append(sa[i:i + 2])
            i += 2
            continue

        out.append(ch.translate(MATH_MAP if math_mode else TEXT_MAP))
        i += 1

    return "".join(out)
tikzplot/plots.py ADDED
@@ -0,0 +1,175 @@
1
+ import numpy as np
2
+
3
+ from .figure import Figure
4
+ from .state import main_name, next_show_num
5
+ from .config import TikzConfig
6
+
7
# Figure and axes targeted by the pyplot-style functions of this module; both
# start unset and are created on demand (see figure() and _ensure_axes()).
_current_figure, _current_axes = None, None
9
+
10
def figure(**kwargs):
    """Start a new current figure (pyplot-style) and return it.

    Only ``figsize=(w, h)`` (inches) is honoured; other keyword arguments are
    accepted and ignored. The current axes is reset, so the next plotting call
    creates a fresh 1x1 subplot on this figure.
    """
    global _current_figure, _current_axes
    fig = Figure()
    _current_figure, _current_axes = fig, None
    if "figsize" in kwargs:
        fig.set_size_inches(kwargs["figsize"])
    return fig
17
+
18
def _ensure_axes():
    """Make sure a current figure and a current axes exist.

    Creates a Figure if none is active and, if no axes is current, adds a
    1x1 subplot to the current figure and makes it current.
    """
    global _current_figure, _current_axes

    if _current_figure is None:
        # Figure comes from the module-level import; no local re-import needed.
        _current_figure = Figure()

    if _current_axes is None:
        _current_axes = _current_figure.add_subplot(1, 1, 1)
27
+
28
def xlabel(label):
    """Set the x-axis label of the current axes (created on demand)."""
    gca().set_xlabel(label)
31
+
32
def ylabel(label):
    """Set the y-axis label of the current axes (created on demand)."""
    gca().set_ylabel(label)
35
+
36
def title(text):
    """Set the title of the current axes (created on demand)."""
    gca().set_title(text)
39
+
40
def grid(*args, **kwargs):
    """Forward all arguments to the current axes' grid() (axes created on demand)."""
    gca().grid(*args, **kwargs)
43
+
44
def minorticks_num(num):
    """Set the number of minor ticks on the current axes (created on demand)."""
    gca().set_minorticks_num(num)
47
+
48
def xlim(*args, **kwargs):
    """Forward all arguments to the current axes' set_xlim() (axes created on demand)."""
    gca().set_xlim(*args, **kwargs)
51
+
52
def ylim(*args, **kwargs):
    """Forward all arguments to the current axes' set_ylim() (axes created on demand)."""
    gca().set_ylim(*args, **kwargs)
55
+
56
def legend(*args, **kwargs):
    """Forward all arguments to the current axes' legend() (axes created on demand)."""
    gca().legend(*args, **kwargs)
59
+
60
def subplot(nrows, ncols=None, index=None, sharex=None, sharey=None, projection=None, polar=False):
    """Add a subplot to the current figure (created if needed) and make it current.

    Accepts ``subplot(nrows, ncols, index, ...)`` as well as matplotlib's
    compact three-digit form, ``subplot(211)`` == ``subplot(2, 1, 1)``. The
    remaining arguments are forwarded to Figure.add_subplot. Returns the axes.

    Raises:
        TypeError: if only one of ``ncols`` / ``index`` is given.
        ValueError: if the compact form is not a three-digit integer.
    """
    global _current_axes

    if ncols is None and index is None:
        code = int(nrows)
        if code != nrows or not 100 <= code <= 999:
            raise ValueError(
                f"Integer subplot specification must be a three-digit number, not {nrows!r}"
            )
        nrows, ncols, index = code // 100, code // 10 % 10, code % 10
    elif ncols is None or index is None:
        raise TypeError("subplot() takes either a three-digit code or nrows, ncols and index")

    if _current_figure is None:
        figure()

    _current_axes = _current_figure.add_subplot(nrows, ncols, index, sharex, sharey, projection, polar)
    return _current_axes
68
+
69
def subplots(nrows=1, ncols=1, sharex=None, sharey=None, subplot_kw=None, **kwargs):
    """Create a new current figure with an ``nrows`` x ``ncols`` grid of axes.

    Returns ``(figure, axes)``. With the default ``squeeze=True`` (as in
    matplotlib), ``axes`` is a single Axes for a 1x1 grid, a 1-D object array
    for a single row or column, and a 2-D object array otherwise;
    ``squeeze=False`` always returns the 2-D array. ``figsize=(w, h)`` (inches)
    is applied to the figure; other keyword arguments are ignored. The first
    axes becomes the current axes.
    """
    global _current_figure, _current_axes

    _current_figure = Figure()
    axes = _current_figure._add_subplots(nrows, ncols, sharex, sharey, subplot_kw)
    _current_axes = axes[0]

    # Apply the size before any return, so it also takes effect for 1x1 grids.
    if "figsize" in kwargs:
        _current_figure.set_size_inches(kwargs["figsize"])

    rows = [axes[r * ncols:(r + 1) * ncols] for r in range(nrows)]
    grid = np.asarray(rows, dtype=object)

    if not kwargs.get("squeeze", True):
        return _current_figure, grid
    if nrows * ncols == 1:
        return _current_figure, axes[0]
    if grid.shape[0] == 1:
        grid = grid[0]
    elif grid.shape[1] == 1:
        grid = grid[:, 0]
    return _current_figure, grid
99
+
100
def plot(*args, **kwargs):
    """Plot on the current axes (created on demand).

    Returns whatever Axes.plot returns, as imshow() does and as
    matplotlib.pyplot.plot does.
    """
    _ensure_axes()
    return _current_axes.plot(*args, **kwargs)
103
+
104
def scatter(x, y, *args, **kwargs):
    """Scatter-plot ``y`` against ``x`` on the current axes (created on demand).

    Returns whatever Axes.scatter returns.
    """
    _ensure_axes()
    return _current_axes.scatter(x, y, *args, **kwargs)
107
+
108
+
109
def loglog(x, y, *args, **kwargs):
    """Plot with both axes logarithmic on the current axes (created on demand).

    Returns whatever Axes.loglog returns.
    """
    _ensure_axes()
    return _current_axes.loglog(x, y, *args, **kwargs)
112
+
113
def semilogx(x, y, *args, **kwargs):
    """Plot with a logarithmic x-axis on the current axes (created on demand).

    Returns whatever Axes.semilogx returns.
    """
    _ensure_axes()
    return _current_axes.semilogx(x, y, *args, **kwargs)
116
+
117
def semilogy(x, y, *args, **kwargs):
    """Plot with a logarithmic y-axis on the current axes (created on demand).

    Returns whatever Axes.semilogy returns.
    """
    _ensure_axes()
    return _current_axes.semilogy(x, y, *args, **kwargs)
120
+
121
def errorbar(x, y, *args, **kwargs):
    """Plot ``y`` against ``x`` with error bars on the current axes (created on demand).

    Returns whatever Axes.errorbar returns.
    """
    _ensure_axes()
    return _current_axes.errorbar(x, y, *args, **kwargs)
124
+
125
def stem(*args, **kwargs):
    """Draw a stem plot on the current axes (created on demand).

    Returns whatever Axes.stem returns.
    """
    _ensure_axes()
    return _current_axes.stem(*args, **kwargs)
128
+
129
def fill_between(*args, **kwargs):
    """Fill the area between two curves on the current axes (created on demand).

    Returns whatever Axes.fill_between returns.
    """
    _ensure_axes()
    return _current_axes.fill_between(*args, **kwargs)
132
+
133
def hlines(*args, **kwargs):
    """Draw horizontal lines on the current axes (created on demand).

    Returns whatever Axes.hlines returns.
    """
    _ensure_axes()
    return _current_axes.hlines(*args, **kwargs)
136
+
137
def vlines(*args, **kwargs):
    """Draw vertical lines on the current axes (created on demand).

    Returns whatever Axes.vlines returns.
    """
    _ensure_axes()
    return _current_axes.vlines(*args, **kwargs)
140
+
141
def imshow(*args, **kwargs):
    """Show an image on the current axes (created on demand) and return the result."""
    return gca().imshow(*args, **kwargs)
144
+
145
def xticks(*args, **kwargs):
    """Forward all arguments to the current axes' set_xticks() (axes created on demand)."""
    gca().set_xticks(*args, **kwargs)
148
+
149
def yticks(*args, **kwargs):
    """Forward all arguments to the current axes' set_yticks() (axes created on demand)."""
    gca().set_yticks(*args, **kwargs)
152
+
153
def xscale(*args, **kwargs):
    """Forward all arguments to the current axes' set_xscale() (axes created on demand)."""
    gca().set_xscale(*args, **kwargs)
156
+
157
def yscale(*args, **kwargs):
    """Forward all arguments to the current axes' set_yscale() (axes created on demand)."""
    gca().set_yscale(*args, **kwargs)
160
+
161
def savefig(filename, **kwargs):
    """Save the current figure as TikZ/LaTeX source.

    ``filename`` may be a str or an os.PathLike; ".tex" is appended unless it
    already ends in ".tex" or ".tikz". Keyword arguments that matplotlib's
    savefig accepts (dpi, bbox_inches, ...) are accepted for drop-in
    compatibility and ignored.
    """
    import os

    filename = os.fspath(filename)
    if not filename.endswith((".tex", ".tikz")):
        filename += ".tex"
    _current_figure._save(filename)
165
+
166
def show():
    """Save the current figure to an auto-numbered .tex file, then clear it.

    The file is named "<main script without .py>_<TikzConfig.SHOW_SAVENAME><n>.tex",
    with n taken from next_show_num(). Does nothing when no figure exists,
    matching matplotlib's show() with no open figures.
    """
    if _current_figure is None:
        return
    script = str(main_name()[1]).removesuffix('.py')
    _current_figure._save(f"{script}_{TikzConfig.SHOW_SAVENAME}{next_show_num()}.tex")
    clf()
169
+
170
def clf():
    """Clear the current figure and drop its current axes.

    The figure object stays current; the next plotting call creates a new
    1x1 subplot on it via _ensure_axes(). Does nothing if no figure exists.
    """
    global _current_axes
    if _current_figure is None:
        return
    _current_figure.clear()
    # Figure.clear() empties its axes list; keeping the old axes current would
    # send later plot() calls to an axes that is never rendered or saved.
    _current_axes = None
172
+
173
def gca():
    """Return the current axes, creating a figure and a 1x1 subplot if needed."""
    if _current_figure is None or _current_axes is None:
        _ensure_axes()
    return _current_axes