e2b-charts 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,25 @@
1
+ Metadata-Version: 2.1
2
+ Name: e2b-charts
3
+ Version: 0.0.1
4
+ Summary: Package for extracting data for E2B Code Interpreter
5
+ Home-page: https://e2b.dev/
6
+ License: Apache-2.0
7
+ Author: e2b
8
+ Author-email: hello@e2b.dev
9
+ Requires-Python: >=3.10,<4.0
10
+ Classifier: License :: OSI Approved :: Apache Software License
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Requires-Dist: matplotlib (>=3.9.2,<4.0.0)
16
+ Requires-Dist: numpy (>=1.26.4,<2.0.0)
17
+ Requires-Dist: pydantic (>=2.8.2,<3.0.0)
18
+ Project-URL: Bug Tracker, https://github.com/e2b-dev/code-interpreter/issues
19
+ Project-URL: Repository, https://github.com/e2b-dev/e2b-code-interpreter/tree/python
20
+ Description-Content-Type: text/markdown
21
+
22
+ # Extracting Data for Code Interpreter SDK
23
+
24
+ This package is a utility used by the Code Interpreter SDK to extract data from objects such as DataFrames and matplotlib plots.
25
+
@@ -0,0 +1,3 @@
1
+ # Extracting Data for Code Interpreter SDK
2
+
3
+ This package is a utility used by the Code Interpreter SDK to extract data from objects such as DataFrames and matplotlib plots.
@@ -0,0 +1 @@
1
+ from .main import graph_figure_to_graph
@@ -0,0 +1,4 @@
1
+ from .base import GraphType, Graph
2
+ from .bars import BarGraph, BoxAndWhiskerGraph
3
+ from .pie import PieGraph
4
+ from .planar import ScatterGraph, LineGraph
@@ -0,0 +1,141 @@
1
+ from typing import Literal, List
2
+
3
+ from matplotlib.axes import Axes
4
+ from pydantic import BaseModel, Field
5
+
6
+ from .base import Graph2D, GraphType
7
+ from ..utils.rounding import dynamic_round
8
+
9
+
10
class BarData(BaseModel):
    """A single bar extracted from a bar chart."""

    # Text of the axis tick label paired with this bar.
    label: str
    # Label of the container the bar belongs to.
    group: str
    # Length of the bar: its height, or its width when the bars are rotated.
    value: float
14
+
15
+
16
class BarGraph(Graph2D):
    """
    Bar chart data extracted from a matplotlib Axes.

    Every bar of every container found on the Axes becomes one BarData
    element, grouped by the label of the container it belongs to.
    """

    type: Literal[GraphType.BAR] = GraphType.BAR

    elements: List[BarData] = Field(default_factory=list)

    def _extract_info(self, ax: Axes) -> None:
        """
        Function to extract information for BarGraph

        Bars drawn horizontally (e.g. with ``barh``) are normalised: the axis
        labels/units are swapped once, bar labels are read from the y axis
        and bar values from the bar widths.
        """
        super()._extract_info(ax)

        # The axis labels must be swapped at most once, no matter how many
        # horizontal containers there are; swapping per container would undo
        # itself for an even number of groups.
        orientation_swapped = False

        for container in ax.containers:
            group_label = container.get_label()
            if group_label.startswith("_container"):
                # Auto-generated container labels look like "_container<N>";
                # anything else starting with that prefix is kept verbatim.
                suffix = group_label[len("_container"):]
                if suffix.isdigit():
                    group_label = f"Group {int(suffix)}"

            if self._is_horizontal(container):
                # horizontal bars: categories live on the y axis
                if not orientation_swapped:
                    self._change_orientation()
                    orientation_swapped = True
                tick_labels = ax.get_yticklabels()
                values = [rect.get_width() for rect in container]
            else:
                # vertical bars: categories live on the x axis
                tick_labels = ax.get_xticklabels()
                values = [rect.get_height() for rect in container]

            for tick_label, value in zip(tick_labels, values):
                bar = BarData(
                    label=tick_label.get_text(), value=value, group=group_label
                )
                self.elements.append(bar)

    @staticmethod
    def _is_horizontal(container) -> bool:
        """Return True when the bars of ``container`` are drawn horizontally."""
        # Prefer the orientation recorded on the container itself: bar
        # geometry alone cannot tell a horizontal chart from a vertical chart
        # whose bars all happen to have the same height (or a single bar).
        orientation = getattr(container, "orientation", None)
        if orientation in ("horizontal", "vertical"):
            return orientation == "horizontal"

        # Fallback heuristic: horizontal bars share one height and differ in
        # width.
        heights = [rect.get_height() for rect in container]
        return bool(heights) and all(height == heights[0] for height in heights)
43
+
44
+
45
class BoxAndWhiskerData(BaseModel):
    """Summary of one box extracted from a box-and-whisker plot."""

    # Text of the axis tick label paired with this box.
    label: str
    # Far end of the lower whisker.
    min: float
    # Lower edge of the box.
    first_quartile: float
    # Position of the line drawn across the inside of the box.
    median: float
    # Upper edge of the box.
    third_quartile: float
    # Far end of the upper whisker.
    max: float
    # Values of the isolated points drawn in line with the box.
    outliers: List[float]
53
+
54
+
55
class BoxAndWhiskerGraph(Graph2D):
    """
    Box-and-whisker chart data extracted from a matplotlib Axes.

    Boxes are read from the Axes patches; medians, whiskers and outliers are
    recovered from the Axes lines by matching their coordinates to the boxes.
    """

    type: Literal[GraphType.BOX_AND_WHISKER] = GraphType.BOX_AND_WHISKER

    elements: List[BoxAndWhiskerData] = Field(default_factory=list)

    def _extract_info(self, ax: Axes) -> None:
        """
        Function to extract information for BoxAndWhiskerGraph

        Horizontally drawn plots are normalised to the vertical layout: axis
        labels/units and all coordinates are swapped, and box labels are read
        from the y axis instead of the x axis.

        Raises KeyError when no median line can be matched to a box.
        """
        super()._extract_info(ax)

        # Bounding rectangle of every box, rounded so that later equality
        # checks against line coordinates are not defeated by float noise.
        boxes = []
        for patch in ax.patches:
            vertices = patch.get_path().vertices
            x_vertices = [dynamic_round(x) for x in vertices[:, 0]]
            y_vertices = [dynamic_round(y) for y in vertices[:, 1]]
            x = min(x_vertices)
            y = min(y_vertices)
            boxes.append(
                {
                    "x": x,
                    "y": y,
                    "width": max(x_vertices) - x,
                    "height": max(y_vertices) - y,
                    "outliers": [],
                }
            )

        # (xdata, ydata, marker_only) for every line on the Axes.
        segments = [
            (
                [dynamic_round(x) for x in line.get_xdata()],
                [dynamic_round(y) for y in line.get_ydata()],
                str(line.get_linestyle()).lower() in ("none", "", " "),
            )
            for line in ax.lines
        ]

        rotated = self._is_rotated(boxes, segments)
        if rotated:
            self._change_orientation()
            for box in boxes:
                box["x"], box["y"] = box["y"], box["x"]
                box["width"], box["height"] = box["height"], box["width"]

        # Category labels sit on the y axis when the plot is drawn sideways.
        tick_labels = ax.get_yticklabels() if rotated else ax.get_xticklabels()
        labels = [item.get_text() for item in tick_labels]
        # zip() drops boxes that have no tick label to pair with.
        boxes = [dict(box, label=label) for label, box in zip(labels, boxes)]

        for xdata, ydata, marker_only in segments:
            if rotated:
                xdata, ydata = ydata, xdata

            if not xdata or not ydata:
                continue

            if marker_only or len(xdata) == 1:
                # Marker-only lines hold the outliers of one box; a single
                # line may carry any number of points.
                # NOTE(review): mean markers (showmeans=True) are presumably
                # also single-point marker lines and end up here - confirm.
                box = self._find_box(boxes, xdata[0], xdata[-1])
                if box is not None:
                    box["outliers"].extend(ydata)
                continue

            if len(xdata) != 2 or len(ydata) != 2:
                continue

            box = self._find_box(boxes, xdata[0], xdata[1])
            if box is None:
                continue

            lower_edge = box["y"]
            upper_edge = box["y"] + box["height"]

            if ydata[0] == ydata[1]:
                if xdata[0] == xdata[1]:
                    # Zero-length segment: a collapsed whisker, not a median.
                    # The whisker defaults applied below cover this case.
                    continue
                if lower_edge <= ydata[0] <= upper_edge:
                    # A flat line inside the box is the median.
                    box["median"] = ydata[0]
                    continue

            lower_value = min(ydata)
            upper_value = max(ydata)
            if upper_value == lower_edge:
                box["whisker_lower"] = lower_value
            elif lower_value == upper_edge:
                box["whisker_upper"] = upper_value

        self.elements = [
            BoxAndWhiskerData(
                label=box["label"],
                # A whisker that was not found has zero length, i.e. it ends
                # on the edge of the box.
                min=box.get("whisker_lower", box["y"]),
                first_quartile=box["y"],
                median=box["median"],
                third_quartile=box["y"] + box["height"],
                max=box.get("whisker_upper", box["y"] + box["height"]),
                outliers=box["outliers"],
            )
            for box in boxes
        ]

    @staticmethod
    def _find_box(boxes, first, last):
        """Return the first box whose x-range contains first <= last, or None."""
        for box in boxes:
            if box["x"] <= first <= last <= box["x"] + box["width"]:
                return box
        return None

    @staticmethod
    def _is_rotated(boxes, segments) -> bool:
        """Return True when the boxes are drawn horizontally."""
        if not boxes:
            return False

        first = boxes[0]
        # Boxes laid out side by side share their extent along the position
        # axis and (normally) differ along the data axis.
        if any(box["height"] != first["height"] for box in boxes):
            return False
        if any(box["width"] != first["width"] for box in boxes):
            return True

        # Ambiguous (single box, or identical boxes): look at a segment lying
        # inside a box. The median runs across the box, so it is flat for a
        # vertical plot and upright for a horizontal one.
        for xdata, ydata, marker_only in segments:
            if marker_only or len(xdata) != 2 or len(ydata) != 2:
                continue
            x_low, x_high = min(xdata), max(xdata)
            y_low, y_high = min(ydata), max(ydata)
            inside_a_box = any(
                box["x"] <= x_low
                and x_high <= box["x"] + box["width"]
                and box["y"] <= y_low
                and y_high <= box["y"] + box["height"]
                for box in boxes
            )
            if not inside_a_box:
                continue
            if y_low == y_high and x_low != x_high:
                return False
            if x_low == x_high and y_low != y_high:
                return True

        # Nothing conclusive: keep the historical answer for equal heights.
        return True
@@ -0,0 +1,75 @@
1
+ import enum
2
+ import re
3
+ from typing import Optional, List, Any
4
+
5
+ from matplotlib.axes import Axes
6
+ from pydantic import BaseModel, Field
7
+
8
+
9
class GraphType(str, enum.Enum):
    """Kinds of graph a model can describe; members are also plain strings."""

    LINE = "line"
    SCATTER = "scatter"
    BAR = "bar"
    PIE = "pie"
    BOX_AND_WHISKER = "box_and_whisker"
    # A figure holding several Axes, each described by its own graph.
    SUPERGRAPH = "supergraph"
    # Used when the artists on an Axes match none of the kinds above.
    UNKNOWN = "unknown"
17
+
18
+
19
class Graph(BaseModel):
    """
    Base model shared by all graphs: a type, an optional title and a list of
    extracted elements.
    """

    type: GraphType
    title: Optional[str] = None

    elements: List[Any] = Field(default_factory=list)

    def __init__(self, ax: Optional[Axes] = None, **kwargs):
        """
        Build the model from keyword fields and, when an Axes is given,
        fill it in from that Axes.
        """
        super().__init__(**kwargs)
        if not ax:
            return
        self._extract_info(ax)

    def _extract_info(self, ax: Axes) -> None:
        """
        Function to extract information for Graph
        """
        # An empty title is stored as None rather than "".
        self.title = ax.get_title() or None
39
+
40
+
41
class Graph2D(Graph):
    """Graph with two labelled axes and optional units parsed from the labels."""

    x_label: Optional[str] = None
    y_label: Optional[str] = None
    x_unit: Optional[str] = None
    y_unit: Optional[str] = None

    def _extract_info(self, ax: Axes) -> None:
        """
        Function to extract information for Graph2D
        """
        super()._extract_info(ax)

        # Empty axis labels are stored as None rather than "".
        self.x_label = ax.get_xlabel() or None
        self.y_label = ax.get_ylabel() or None

        # A unit is whatever sits in " (...)" or "[...]" inside the label.
        unit_pattern = re.compile(r"\s\((.*?)\)|\[(.*?)\]")

        x_found = unit_pattern.search(self.x_label) if self.x_label else None
        if x_found:
            self.x_unit = x_found.group(1) or x_found.group(2)

        y_found = unit_pattern.search(self.y_label) if self.y_label else None
        if y_found:
            self.y_unit = y_found.group(1) or y_found.group(2)

    def _change_orientation(self):
        """Swap the x and y labels and units with each other."""
        previous_x_label, previous_x_unit = self.x_label, self.x_unit
        self.x_label = self.y_label
        self.y_label = previous_x_label
        self.x_unit = self.y_unit
        self.y_unit = previous_x_unit
@@ -0,0 +1,32 @@
1
+ from decimal import Decimal
2
+ from typing import Literal, List
3
+
4
+ from matplotlib.axes import Axes
5
+ from pydantic import BaseModel, Field
6
+
7
+ from .base import Graph, GraphType
8
+ from ..utils.rounding import dynamic_round
9
+
10
+
11
class PieData(BaseModel):
    """A single wedge extracted from a pie chart."""

    # Label of the wedge.
    label: str
    # Absolute difference between the wedge's end and start angles.
    angle: float
    # Radius of the wedge.
    radius: float
15
+
16
+
17
class PieGraph(Graph):
    """Pie chart data: one PieData element per patch on the Axes."""

    type: Literal[GraphType.PIE] = GraphType.PIE

    elements: List[PieData] = Field(default_factory=list)

    def _extract_info(self, ax: Axes) -> None:
        """
        Function to extract information for PieGraph
        """
        super()._extract_info(ax)

        for wedge in ax.patches:
            # Subtract as Decimals, then round away the float noise.
            sweep = Decimal(wedge.theta2) - Decimal(wedge.theta1)
            self.elements.append(
                PieData(
                    label=wedge.get_label(),
                    angle=abs(dynamic_round(sweep)),
                    radius=wedge.r,
                )
            )
@@ -0,0 +1,132 @@
1
+ from datetime import date
2
+ from typing import List, Tuple, Union, Sequence, Any, Literal
3
+
4
+ import matplotlib
5
+ import numpy
6
+ from matplotlib.axes import Axes
7
+ from matplotlib.dates import _SwitchableDateConverter
8
+ from pydantic import BaseModel, field_validator, Field
9
+
10
+ from .base import Graph2D, GraphType
11
+ from ..utils import is_grid_line
12
+
13
+
14
class PointData(BaseModel):
    """A labelled series of (x, y) points."""

    label: str
    points: List[Tuple[Union[str, float], Union[str, float]]]

    @field_validator("points", mode="before")
    @classmethod
    def transform_points(
        cls, value
    ) -> List[Tuple[Union[str, float], Union[str, float]]]:
        """Convert date-like coordinates to strings before validation."""
        return [
            (cls._parse_point(x_value), cls._parse_point(y_value))
            for x_value, y_value in value
        ]

    @staticmethod
    def _parse_point(point):
        """Return dates as ISO strings; pass anything else through unchanged."""
        if isinstance(point, date):
            return point.isoformat()
        if isinstance(point, numpy.datetime64):
            # Truncate to whole seconds before rendering as text.
            seconds = point.astype("datetime64[s]")
            return seconds.astype(str)
        return point
37
+
38
+
39
class PointGraph(Graph2D):
    """
    Base for graphs made of points: adds the ticks, tick labels and scale of
    both axes to the information kept by Graph2D.
    """

    x_ticks: List[Union[str, float]] = Field(default_factory=list)
    x_tick_labels: List[str] = Field(default_factory=list)
    x_scale: str = Field(default="linear")

    y_ticks: List[Union[str, float]] = Field(default_factory=list)
    y_tick_labels: List[str] = Field(default_factory=list)
    y_scale: str = Field(default="linear")

    elements: List[PointData] = Field(default_factory=list)

    def _extract_info(self, ax: Axes) -> None:
        """
        Function to extract information for PointGraph
        """
        super()._extract_info(ax)

        x_converter = self._axis_converter(ax.xaxis)
        self.x_tick_labels = [label.get_text() for label in ax.get_xticklabels()]
        self.x_ticks = self._extract_ticks_info(x_converter, ax.get_xticks())
        self.x_scale = self._detect_scale(
            x_converter, ax.get_xscale(), self.x_ticks, self.x_tick_labels
        )

        y_converter = self._axis_converter(ax.yaxis)
        self.y_tick_labels = [label.get_text() for label in ax.get_yticklabels()]
        self.y_ticks = self._extract_ticks_info(y_converter, ax.get_yticks())
        self.y_scale = self._detect_scale(
            y_converter, ax.get_yscale(), self.y_ticks, self.y_tick_labels
        )

    @staticmethod
    def _axis_converter(axis) -> Any:
        """Return the unit converter attached to ``axis``."""
        # Newer matplotlib exposes get_converter() and deprecates reading the
        # ``converter`` attribute directly; older releases only have the
        # attribute.
        getter = getattr(axis, "get_converter", None)
        if callable(getter):
            return getter()
        return axis.converter

    @staticmethod
    def _detect_scale(converter, scale: str, ticks: Sequence, labels: Sequence) -> str:
        """
        Classify an axis as "datetime", "categorical", or its matplotlib
        scale name (e.g. "linear", "log").
        """
        # If the converter is a date converter, it's a datetime scale
        if isinstance(converter, _SwitchableDateConverter):
            return "datetime"

        # If the scale is not linear, it can't be categorical
        if scale != "linear":
            return scale

        # Without any tick/label pair there is no evidence of categories.
        pairs = list(zip(ticks, labels))
        if not pairs:
            return "linear"

        # If all the ticks are integers and are in order from 0 to n-1
        # and the labels aren't corresponding to the ticks, it's categorical
        for i, (tick, label) in enumerate(pairs):
            if isinstance(tick, (int, float)) and tick == i and str(i) != label:
                continue
            # Found a tick, which wouldn't be in a categorical scale
            return "linear"

        return "categorical"

    @staticmethod
    def _extract_ticks_info(converter: Any, ticks: Sequence) -> list:
        """
        Convert tick positions to plain values: ISO strings on a date axis,
        floats for numeric ticks, otherwise the ticks unchanged.
        """
        if isinstance(converter, _SwitchableDateConverter):
            return [matplotlib.dates.num2date(tick).isoformat() for tick in ticks]

        # An axis may have no ticks at all (e.g. after set_xticks([])); len()
        # is used because the ticks can be a numpy array, whose truth value is
        # ambiguous.
        if len(ticks) == 0:
            return []

        example_tick = ticks[0]
        if isinstance(example_tick, (int, float)):
            return [float(tick) for tick in ticks]
        return list(ticks)
102
+
103
+
104
class LineGraph(PointGraph):
    """Line chart data: one PointData element per plotted line."""

    type: Literal[GraphType.LINE] = GraphType.LINE

    def _extract_info(self, ax: Axes) -> None:
        """
        Function to extract information for LineGraph
        """
        super()._extract_info(ax)

        auto_prefix = "_child"
        for line in ax.get_lines():
            # Lines recognised by is_grid_line carry no series data.
            if is_grid_line(line):
                continue

            label = line.get_label()
            if label.startswith(auto_prefix):
                # Labels of the form "_child<N>" become "Line <N>".
                label = f"Line {int(label[len(auto_prefix):])}"

            coordinates = list(zip(line.get_xdata(), line.get_ydata()))
            self.elements.append(PointData(label=label, points=coordinates))
121
+
122
+
123
class ScatterGraph(PointGraph):
    """Scatter chart data: one PointData element per collection on the Axes."""

    type: Literal[GraphType.SCATTER] = GraphType.SCATTER

    def _extract_info(self, ax: Axes) -> None:
        """
        Function to extract information for ScatterGraph
        """
        super()._extract_info(ax)

        for collection in ax.collections:
            # Each offset of the collection is one (x, y) point.
            coordinates = [
                (x_position, y_position)
                for x_position, y_position in collection.get_offsets()
            ]
            self.elements.append(
                PointData(label=collection.get_label(), points=coordinates)
            )
@@ -0,0 +1,111 @@
1
+ from typing import Optional, List, Literal
2
+
3
+ from matplotlib.axes import Axes
4
+ from matplotlib.collections import PathCollection
5
+ from matplotlib.lines import Line2D
6
+ from matplotlib.patches import Rectangle, Wedge, PathPatch
7
+ from matplotlib.pyplot import Figure
8
+
9
+ from matplotlib.text import Text
10
+ from pydantic import Field
11
+
12
+ from .graphs import (
13
+ GraphType,
14
+ Graph,
15
+ LineGraph,
16
+ BarGraph,
17
+ BoxAndWhiskerGraph,
18
+ PieGraph,
19
+ ScatterGraph,
20
+ )
21
+ from .utils.filtering import is_grid_line
22
+
23
+
24
class SuperGraph(Graph):
    """
    Graph describing a whole figure that holds several Axes; every Axes of
    the figure becomes one element.
    """

    type: Literal[GraphType.SUPERGRAPH] = GraphType.SUPERGRAPH
    elements: List[
        LineGraph | ScatterGraph | BarGraph | PieGraph | BoxAndWhiskerGraph
    ] = Field(default_factory=list)

    def __init__(self, figure: Figure):
        """Build the model from ``figure`` and one sub-graph per Axes."""
        # Store a missing suptitle as None rather than "", the same way
        # Graph._extract_info treats an empty Axes title.
        super().__init__(title=figure.get_suptitle() or None)

        self.elements = [get_graph_from_ax(ax) for ax in figure.axes]
35
+
36
+
37
def _get_type_of_graph(ax: Axes) -> GraphType:
    """
    Guess the kind of graph drawn on ``ax`` from the classes of its artists.

    Text artists are ignored. Note that ``all()`` is true for an empty
    sequence, so an Axes without any other artist is classified as LINE.
    """
    artists = [child for child in ax._children if not isinstance(child, Text)]

    # Only lines: a line plot.
    if all(isinstance(artist, Line2D) for artist in artists):
        return GraphType.LINE

    # Path patches, optionally mixed with lines: a box-and-whisker plot.
    if all(isinstance(artist, (PathPatch, Line2D)) for artist in artists):
        return GraphType.BOX_AND_WHISKER

    # Lines recognised by is_grid_line don't count for the remaining checks.
    artists = [
        artist
        for artist in artists
        if not (isinstance(artist, Line2D) and is_grid_line(artist))
    ]

    # Checked in this order: scatter, pie, bar.
    candidates = (
        (PathCollection, GraphType.SCATTER),
        (Wedge, GraphType.PIE),
        (Rectangle, GraphType.BAR),
    )
    for artist_class, graph_type in candidates:
        if all(isinstance(artist, artist_class) for artist in artists):
            return graph_type

    return GraphType.UNKNOWN
68
+
69
+
70
def get_graph_from_ax(
    ax: Axes,
) -> LineGraph | ScatterGraph | BarGraph | PieGraph | BoxAndWhiskerGraph | Graph:
    """
    Build the graph model matching the kind of graph drawn on ``ax``.

    Falls back to a plain Graph carrying the detected type when no
    specialised model exists for it.
    """
    specialised_models = {
        GraphType.LINE: LineGraph,
        GraphType.SCATTER: ScatterGraph,
        GraphType.BAR: BarGraph,
        GraphType.PIE: PieGraph,
        GraphType.BOX_AND_WHISKER: BoxAndWhiskerGraph,
    }

    graph_type = _get_type_of_graph(ax)
    model = specialised_models.get(graph_type)
    if model is None:
        return Graph(ax=ax, type=graph_type)
    return model(ax=ax)
89
+
90
+
91
def graph_figure_to_graph(figure: Figure) -> Optional[Graph]:
    """
    This method is used to extract data from the figure object to a graph model

    Returns None for a figure without Axes, the graph of the Axes for a
    figure with exactly one, and a SuperGraph when there are several.
    """
    # Get all Axes objects from the Figure
    axes = figure.get_axes()

    if not axes:
        return None
    if len(axes) == 1:
        return get_graph_from_ax(axes[0])
    return SuperGraph(figure=figure)
105
+
106
+
107
def graph_figure_to_dict(figure: Figure) -> dict:
    """
    Extract data from the figure object to a dictionary; the dictionary is
    empty when no graph could be built from the figure.
    """
    extracted = graph_figure_to_graph(figure)
    return extracted.model_dump() if extracted else {}
@@ -0,0 +1 @@
1
+ from .filtering import is_grid_line
@@ -0,0 +1,16 @@
1
+ from matplotlib.lines import Line2D
2
+
3
+
4
def is_grid_line(line: Line2D) -> bool:
    """
    Tell whether ``line`` is a two-point segment that is exactly vertical
    (both x values equal) or exactly horizontal (both y values equal).
    """
    xs = line.get_xdata()
    if len(xs) != 2:
        return False

    ys = line.get_ydata()
    if len(ys) != 2:
        return False

    x_start, x_end = xs[0], xs[1]
    y_start, y_end = ys[0], ys[1]
    return bool(x_start == x_end or y_start == y_end)
@@ -0,0 +1,13 @@
1
+ from decimal import Decimal, localcontext
2
+
3
+
4
def dynamic_round(number, digits=8):
    """
    Round ``number`` to a Decimal, trimming floating point noise.

    The result keeps at least ``digits`` significant digits whatever the
    magnitude of the number; values smaller than 1 keep one extra digit per
    leading decimal zero.

    :param number: value to round; anything whose ``str()`` is a valid
        Decimal literal (int, float, Decimal, ...)
    :param digits: minimum number of significant digits to keep, at least 1
    :return: the rounded value as a Decimal
    :raises ValueError: if ``digits`` is smaller than 1
    """
    if digits < 1:
        raise ValueError("digits must be at least 1")

    # Go through str() so a float contributes its shortest repr instead of
    # its full binary expansion.
    decimal_number = Decimal(str(number))

    # adjusted() is the exponent of the most significant digit. It must never
    # *reduce* the precision: subtracting it unconditionally rounds large
    # values down to a single digit (12345678 -> 1E+7).
    precision = max(digits, digits - decimal_number.adjusted())

    with localcontext() as ctx:
        ctx.prec = precision  # Set the dynamic precision
        return +decimal_number  # The + operator applies rounding
@@ -0,0 +1,29 @@
1
+ [tool.poetry]
2
+ name = "e2b-charts"
3
+ version = "0.0.1"
4
+ description = "Package for extracting data for E2B Code Interpreter"
5
+ authors = ["e2b <hello@e2b.dev>"]
6
+ license = "Apache-2.0"
7
+ readme = "README.md"
8
+ homepage = "https://e2b.dev/"
9
+ repository = "https://github.com/e2b-dev/e2b-code-interpreter/tree/python"
10
+ packages = [{ include = "e2b_charts" }]
11
+
12
+ [tool.poetry.dependencies]
13
+ python = "^3.10"
14
+
15
+ numpy = "^1.26.4"
16
+ matplotlib = "^3.9.2"
17
+ pydantic = "^2.8.2"
18
+
19
+ [tool.poetry.group.dev.dependencies]
20
+ pytest = "^7.4.0"
21
+ python-dotenv = "^1.0.0"
22
+ pytest-dotenv = "^0.5.2"
23
+
24
+ [build-system]
25
+ requires = ["poetry-core"]
26
+ build-backend = "poetry.core.masonry.api"
27
+
28
+ [tool.poetry.urls]
29
+ "Bug Tracker" = "https://github.com/e2b-dev/code-interpreter/issues"