lf-pollywog 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pollywog/display.py ADDED
@@ -0,0 +1,149 @@
1
# Module-wide default theme used by display_calcset when no theme is passed.
_DISPLAY_THEME = "light"


def set_theme(theme):
    """
    Choose the theme display_calcset uses by default.

    Args:
        theme (str): Either "light" or "dark".

    Raises:
        ValueError: If ``theme`` is not one of the supported names; the current
            setting is left unchanged in that case.
    """
    global _DISPLAY_THEME
    supported = ("light", "dark")
    if theme in supported:
        _DISPLAY_THEME = theme
        return
    raise ValueError("Theme must be 'light' or 'dark'.")
11
+
12
def display_calcset(calcset, theme=None, colors=None, display_output=True):
    """
    Render a CalcSet as HTML with a visual style similar to Leapfrog.

    Items are grouped into collapsible "Variables", "Calculations" and "Filters"
    sections by their ``item_type`` attribute, and equations and if/otherwise logic
    blocks are rendered visually. Every colour is taken from the active palette, so
    the dark theme and custom ``colors`` apply to the whole output.

    Args:
        calcset: Object with an ``items`` iterable. Each item is rendered from
            ``item.to_dict()`` when available, otherwise it is used directly as a dict.
        theme (str, optional): "light" or "dark". If None, the global setting from
            ``set_theme()`` is used. Unrecognised names fall back to "light".
        colors (dict, optional): Palette overrides. Keys: "background", "text",
            "variable", "label", "if", "arrow", "comment", "var_ref".
        display_output (bool): If True (default), display the HTML in the notebook
            (requires IPython) and return None. If False, return the HTML string;
            IPython is not needed in that case.

    Returns:
        str or None: The HTML string when ``display_output`` is False, else None.
    """
    import html
    import re

    # Default color palettes
    default_colors = {
        "light": {
            "background": "#eee",
            "text": "#222",
            "variable": "#0057b7",
            "label": "#222",
            "if": "#0057b7",
            "arrow": "#222",
            "comment": "#999",
            "var_ref": "#b77",
        },
        "dark": {
            "background": "#222",
            "text": "#eee",
            "variable": "#7abaff",
            "label": "#eee",
            "if": "#7abaff",
            "arrow": "#eee",
            "comment": "#bbb",
            "var_ref": "#ffb77a",
        },
    }
    use_theme = theme if theme is not None else _DISPLAY_THEME
    palette = default_colors.get(use_theme, default_colors["light"]).copy()
    if colors:
        palette.update(colors)

    # A complete "[name]" reference (no nested brackets). Matching whole references,
    # instead of replacing "[" and "]" independently, keeps the <span> tags balanced
    # even when an expression contains a stray bracket.
    var_ref_pattern = re.compile(r"\[[^\[\]]*\]")
    arrow = f'<span style="color:{palette["arrow"]};">&rarr;</span> '

    def highlight_refs(escaped_text):
        """Colour every [variable] reference in already-escaped text."""
        return var_ref_pattern.sub(
            lambda m: f'<span style="color:{palette["var_ref"]};">{m.group(0)}</span>',
            escaped_text,
        )

    def children_of(node):
        """Return node["children"], treating a missing/None node as empty."""
        return (node or {}).get("children", [])

    def logic_row(keyword, body_html, indent):
        """Wrap one if/otherwise row in an indented block with a coloured keyword."""
        return (
            f'<div style="margin-left:{indent*24}px;border-left:2px solid {palette["if"]};padding-left:8px;">'
            f'<span style="color:{palette["if"]};">{keyword}</span> '
            f"{body_html}"
            f"</div>"
        )

    def render_if_row(row, indent):
        """Render an {"test": ..., "result": ...} row as 'if <cond> -> <result>'."""
        cond_html = render_expression(children_of(row.get("test")), 0)
        res_html = render_expression(children_of(row.get("result")), indent + 1)
        return logic_row("if", f"{cond_html} {arrow}{res_html}", indent)

    def render_expression(expr, indent=0):
        """Render a string, list, or typed dict node ("if", "if_row", "list") as HTML."""
        pad = " " * (indent * 4)
        if isinstance(expr, str):
            text = highlight_refs(html.escape(expr))
            return pad + f'<span style="color:{palette["text"]};">{text}</span>'
        if isinstance(expr, list):
            return "<br>".join(render_expression(e, indent) for e in expr)
        if isinstance(expr, dict):
            node_type = expr.get("type")
            if node_type == "if":
                parts = [render_if_row(row, indent) for row in expr.get("rows", [])]
                otherwise = children_of(expr.get("otherwise"))
                if otherwise:
                    parts.append(
                        logic_row(
                            "otherwise",
                            arrow + render_expression(otherwise, indent + 1),
                            indent,
                        )
                    )
                return "".join(parts)
            if node_type == "if_row":
                return render_if_row(expr, indent)
            if node_type == "list":
                return render_expression(expr.get("children", []), indent)
        # Unknown node types and other values are shown as escaped text.
        return pad + f'<span style="color:{palette["text"]};">{html.escape(str(expr))}</span>'

    def render_equation(eq):
        """Render an {"type": "equation", ...} dict; other values are shown as text."""
        if isinstance(eq, dict) and eq.get("type") == "equation":
            expr_html = render_expression(eq.get("statement", ""))
            comment = eq.get("comment", "")
            comment_html = (
                f'<span style="color:{palette["comment"]};">{html.escape(str(comment))}</span>'
                if comment
                else ""
            )
            return (
                f'<div style="margin-left:1em;color:{palette["text"]};">'
                f"{expr_html} {comment_html}</div>"
            )
        return html.escape(str(eq))

    def render_item(item):
        """Render one item: bold name, type badge, optional equation and comment."""
        d = item.to_dict() if hasattr(item, "to_dict") else item
        name = d.get("name", "")
        item_type = d.get("type", "")
        calc_type = d.get("calculation_type")
        # Calculation items show their calculation_type in the badge when present.
        label = calc_type if item_type == "calculation" and calc_type else item_type
        parts = [
            '<div style="margin-bottom:0.5em;">',
            f'<b style="color:{palette["variable"]};">{html.escape(str(name))}</b> ',
            f'<span style="background:{palette["background"]};border-radius:4px;'
            f'padding:2px 6px;color:{palette["label"]};">{html.escape(str(label))}</span>',
        ]
        eq = d.get("equation", None)
        if eq:
            parts.append(render_equation(eq))
        comment = d.get("comment", "")
        if comment:
            parts.append(
                f'<div style="color:{palette["comment"]};margin-left:1em;">'
                f"{html.escape(str(comment))}</div>"
            )
        parts.append("</div>")
        return "".join(parts)

    def section(title, items):
        """Render a collapsible section, or nothing when there are no items."""
        if not items:
            return ""
        html_items = "".join(render_item(item) for item in items)
        return (
            f'<details open><summary style="font-size:1.1em;font-weight:bold;'
            f'color:{palette["variable"]};">{html.escape(title)}</summary>'
            f"{html_items}</details>"
        )

    grouped = {"variable": [], "calculation": [], "filter": []}
    for item in calcset.items:
        kind = getattr(item, "item_type", None)
        if kind in grouped:
            grouped[kind].append(item)

    html_out = (
        f'<div style="font-family:sans-serif;max-width:900px;color:{palette["text"]};">'
        + section("Variables", grouped["variable"])
        + section("Calculations", grouped["calculation"])
        + section("Filters", grouped["filter"])
        + "</div>"
    )
    if not display_output:
        return html_out
    # Imported lazily so building the HTML string does not require IPython.
    from IPython.display import HTML, display

    display(HTML(html_out))
    return None
pollywog/helpers.py ADDED
@@ -0,0 +1,246 @@
1
+ from .core import If, IfRow, Number, Category
2
+ from .utils import ensure_variables
3
+
4
+
5
def Sum(*variables, name=None, comment=None):
    """
    Create a Number representing the sum of the given variables.

    Args:
        *variables: Variable names (as strings) to sum, e.g. "Au", "Ag", or a single
            list/tuple of variable names.
        name (str, optional): Name for the output variable. If None, defaults to
            "sum_<var1>_<var2>_...".
        comment (str, optional): Comment for the calculation. If falsy, defaults to
            "Sum of <var1>, <var2>, ...".

    Returns:
        Number: A pollywog Number whose equation is "([var1] + [var2] + ...)".

    Raises:
        ValueError: If no variables are provided, including an empty list/tuple.

    Example:
        >>> Sum("Au", "Ag", name="sum_Au_Ag")
    """
    if len(variables) == 1 and isinstance(variables[0], (list, tuple)):
        variables = variables[0]
    # Checked after unpacking so Sum([]) is rejected instead of emitting "()".
    if not variables:
        raise ValueError("At least one variable must be provided.")
    if name is None:
        name = "sum_" + "_".join(variables)
    expr = f"({' + '.join(f'[{v}]' for v in variables)})"
    return Number(
        name, [expr], comment_equation=comment or f"Sum of {', '.join(variables)}"
    )
30
+
31
+
32
def Product(*variables, name=None, comment=None):
    """
    Create a Number representing the product of the given variables.

    Args:
        *variables: Variable names (as strings) to multiply, e.g. "Au", "Ag", or a
            single list/tuple of variable names.
        name (str, optional): Name for the output variable. If None, defaults to
            "prod_<var1>_<var2>_...".
        comment (str, optional): Comment for the calculation. If falsy, defaults to
            "Product of <var1>, <var2>, ...".

    Returns:
        Number: A pollywog Number whose equation is "([var1] * [var2] * ...)".

    Raises:
        ValueError: If no variables are provided, including an empty list/tuple.

    Example:
        >>> Product("Au", "Ag", name="prod_Au_Ag")
    """
    if len(variables) == 1 and isinstance(variables[0], (list, tuple)):
        variables = variables[0]
    # Checked after unpacking so Product([]) is rejected instead of emitting "()".
    if not variables:
        raise ValueError("At least one variable must be provided.")
    if name is None:
        name = "prod_" + "_".join(variables)
    expr = f"({' * '.join(f'[{v}]' for v in variables)})"
    return Number(
        name, [expr], comment_equation=comment or f"Product of {', '.join(variables)}"
    )
57
+
58
+
59
def Normalize(variable, min_value, max_value, name=None, comment=None):
    """
    Create a Number that normalizes a variable to [0, 1] given min and max values.

    The generated equation is "([variable] - min) / (max - min)".

    Args:
        variable (str): Variable name to normalize.
        min_value (float): Minimum value for normalization.
        max_value (float): Maximum value for normalization.
        name (str, optional): Name for the output variable. If None, defaults to
            "norm_<variable>".
        comment (str, optional): Comment for the calculation. If falsy, a default
            describing the min/max range is used.

    Returns:
        Number: A pollywog Number representing the normalization calculation.

    Raises:
        ValueError: If ``min_value`` equals ``max_value``, which would make the
            generated equation divide by zero for every record.

    Example:
        >>> Normalize("Au", 0, 10, name="norm_Au")
    """
    if min_value == max_value:
        raise ValueError(
            "max_value must differ from min_value (zero normalization range)."
        )
    if name is None:
        name = f"norm_{variable}"
    expr = f"([{variable}] - {min_value}) / ({max_value} - {min_value})"
    return Number(
        name,
        [expr],
        comment_equation=comment
        or f"Normalize {variable} to [0, 1] using min={min_value}, max={max_value}",
    )
85
+
86
+
87
def Average(*variables, name=None, comment=None):
    """
    Create a Number representing the average of the given variables.

    Args:
        *variables: Variable names (as strings) to average, e.g. "Au", "Ag", or a
            single list/tuple of variable names (same convention as Sum/Product).
        name (str, optional): Name for the output variable. If None, defaults to
            "avg_<var1>_<var2>_...".
        comment (str, optional): Comment for the calculation. If falsy, defaults to
            "Average of <var1>, <var2>, ...".

    Returns:
        Number: A pollywog Number whose equation is "([var1] + [var2] + ...) / N".

    Raises:
        ValueError: If no variables are provided, including an empty list/tuple
            (which would otherwise produce a division by zero).

    Example:
        >>> Average("Au", "Ag", name="avg_Au_Ag")
    """
    # Accept a single list *or tuple*, consistent with Sum and Product.
    if len(variables) == 1 and isinstance(variables[0], (list, tuple)):
        variables = variables[0]
    if not variables:
        raise ValueError("At least one variable must be provided.")
    if name is None:
        name = "avg_" + "_".join(variables)
    expr = f"({' + '.join(f'[{v}]' for v in variables)}) / {len(variables)}"
    return Number(
        name, [expr], comment_equation=comment or f"Average of {', '.join(variables)}"
    )
112
+
113
+
114
def WeightedAverage(variables, weights, name=None, comment=None):
    """
    Build a Number holding the weighted average of several variables.

    The equation is "([v1] * w1 + [v2] * w2 + ...) / (w1 + w2 + ...)". Weights are
    formatted with ``ensure_variables``: numeric weights stay literal, while names
    are wrapped in brackets.

    Args:
        variables (list of str): Variable names to average, e.g. ["Au", "Ag", "Cu"].
        weights (list of float or str): One weight per variable, either constants
            (e.g. [0.5, 0.3, 0.2]) or variable names (e.g. ["w1", "w2", "w3"]).
        name (str, optional): Output variable name. Defaults to "wavg_<var1>_<var2>_...".
        comment (str, optional): Equation comment. If falsy, a default listing the
            variables and the formatted weights is used.

    Returns:
        Number: The weighted-average calculation.

    Raises:
        ValueError: If either list is empty or their lengths differ.

    Example:
        >>> WeightedAverage(["Au", "Ag"], [0.7, 0.3], name="wavg_Au_Ag")
    """
    if not (variables and weights and len(variables) == len(weights)):
        raise ValueError("variables and weights must be non-empty and of equal length.")
    if name is None:
        name = "wavg_" + "_".join(variables)
    weight_terms = ensure_variables(weights)
    denominator = " + ".join(weight_terms)
    numerator = " + ".join(
        f"[{var}] * {weight}" for var, weight in zip(variables, weight_terms)
    )
    expression = f"({numerator}) / ({denominator})"
    if not comment:
        comment = f"Weighted average of {', '.join(variables)} with weights {weight_terms}"
    return Number(name, [expression], comment_equation=comment)
144
+
145
+
146
def Scale(variable, factor, name=None, comment=None):
    """
    Create a Number that multiplies a variable by a factor.

    Args:
        variable (str): Variable name to scale.
        factor (float or str): Scalar scaling factor: a numeric constant (or numeric
            string), or a variable name with or without surrounding brackets.
        name (str, optional): Name for the output variable. If None, defaults to
            "scale_<variable>".
        comment (str, optional): Comment for the calculation. If falsy, defaults to
            "Scale <variable> by <factor>".

    Returns:
        Number: A pollywog Number representing the scaled variable.

    Example:
        >>> Scale("Au", 2, name="Au_scaled")
        >>> Scale("Au", "[density]")  # -> "[Au] * [density]", not "[[density]]"
    """
    if name is None:
        name = f"scale_{variable}"
    # Format the factor the same way WeightedAverage formats weights: numbers stay
    # literal and names get exactly one pair of brackets (no "[[w]]" double wrap).
    factor_expr = ensure_variables(factor)[0]
    expr = f"[{variable}] * {factor_expr}"
    return Number(
        name, [expr], comment_equation=comment or f"Scale {variable} by {factor}"
    )
169
+
170
+
171
def IfElse(condition, then, else_, name=None, comment=None, output_type=None):
    """
    Build a single-branch conditional (if / otherwise) calculation.

    Numeric ``then``/``else_`` values are converted with ``str()``; any other value
    is placed into the If block unchanged.

    Args:
        condition (str): Condition expression, e.g. "[Au] > 1".
        then: Value or expression used when the condition holds.
        else_: Value or expression used otherwise.
        name (str, optional): Output name. Defaults to "ifelse".
        comment (str, optional): Equation comment. If falsy, defaults to
            "If <condition> then <then> else <else>".
        output_type (type, optional): Class used to build the result, e.g. Number
            or Category. Defaults to Number.

    Returns:
        Number or Category: The conditional calculation built by ``output_type``.

    Example:
        >>> IfElse("[Au] > 1", "High", "Low", name="Au_class", output_type=Category)
    """

    def _as_expr(value):
        # Numbers become their string form; everything else passes through untouched.
        return str(value) if isinstance(value, (int, float)) else value

    factory = Number if output_type is None else output_type
    out_name = "ifelse" if name is None else name
    then_expr = _as_expr(then)
    else_expr = _as_expr(else_)
    logic = If([IfRow([condition], [then_expr])], otherwise=[else_expr])
    equation_comment = comment or f"If {condition} then {then_expr} else {else_expr}"
    return factory(out_name, [logic], comment_equation=equation_comment)
207
+
208
+
209
def CategoryFromThresholds(variable, thresholds, categories, name=None, comment=None):
    """
    Create a Category assigning labels based on value thresholds.

    With thresholds [t0, t1, ...] the generated rows are:
    ``[v] <= t0`` -> categories[0], ``t0 < [v] <= t1`` -> categories[1], ...,
    and any remaining value falls through to ``categories[-1]``.

    Args:
        variable (str): Variable to threshold.
        thresholds (list of float): Threshold values, sorted ascending.
        categories (list of str): Category labels; must contain exactly
            ``len(thresholds) + 1`` entries.
        name (str, optional): Name for the output category. If None, defaults to
            "class_<variable>".
        comment (str, optional): Comment for the calculation. If falsy, a default
            listing the thresholds is used.

    Returns:
        Category: A pollywog Category assigning labels based on thresholds.

    Raises:
        ValueError: If the number of categories does not match, or the thresholds
            are not in ascending order (which would silently misclassify values).

    Example:
        >>> CategoryFromThresholds("Au", [0.5, 1.0], ["Low", "Medium", "High"], name="Au_class")
    """
    if len(categories) != len(thresholds) + 1:
        raise ValueError("categories must have one more element than thresholds")
    bounds = list(thresholds)
    if any(lower > upper for lower, upper in zip(bounds, bounds[1:])):
        raise ValueError("thresholds must be sorted in ascending order")
    if_rows = []
    prev = None
    for threshold, label in zip(bounds, categories):
        if prev is None:
            cond = f"[{variable}] <= {threshold}"
        else:
            cond = f"([{variable}] > {prev}) and ([{variable}] <= {threshold})"
        if_rows.append(IfRow([cond], [label]))
        prev = threshold
    # Values above the last threshold fall through to the final category.
    if_block = If(if_rows, otherwise=[categories[-1]])
    if name is None:
        name = f"class_{variable}"
    return Category(
        name,
        [if_block],
        comment_equation=comment or f"Classify {variable} by thresholds {thresholds}",
    )
pollywog/run.py ADDED
@@ -0,0 +1,116 @@
1
+ import copy
2
+ from pollywog.core import CalcSet, If, IfRow
3
+ import re
4
+
5
+
6
def run_calcset(
    calcset, inputs=None, dataframe=None, assign_results=True, output_variables=False
):
    """
    Evaluate a CalcSet with external inputs or a pandas DataFrame.

    Items are evaluated in the order returned by ``calcset.topological_sort()``.
    Items whose ``item_type`` is "variable" take their value from the inputs (None
    when absent). Every other item is evaluated from its children, and the value of
    the first child is stored under the item's name. ``[name]`` references inside
    expressions are resolved against the inputs merged with earlier results.

    Args:
        calcset: The CalcSet to evaluate.
        inputs (dict, optional): Input values keyed by name; used when ``dataframe``
            is None.
        dataframe (pandas.DataFrame, optional): If given, each row is evaluated and
            results are written into a copy of the frame. Pandas is only required
            for this mode.
        assign_results (bool): Currently unused; accepted for API compatibility.
        output_variables (bool): If False (default, Leapfrog-like), "variable" items
            are omitted from the output, and in DataFrame mode their columns are
            dropped. Set True to include them, e.g. for debugging.

    Returns:
        dict or pandas.DataFrame: A dict of results for ``inputs``, or a DataFrame
        copy with result columns when ``dataframe`` is provided.

    Raises:
        ImportError: If ``dataframe`` is given but pandas is not installed.

    Warning:
        String expressions are executed with Python's ``eval``. Only run CalcSets
        from trusted sources.
    """
    # Compiled once per call: "[name]" -> a lookup into the evaluation context.
    var_ref = re.compile(r"\[([^\]]+)\]")

    def eval_expr(expr, context):
        """Evaluate a string expression or If block; failures evaluate to None."""
        if isinstance(expr, str):
            if not expr.strip():
                return None
            # A reserved global name is used for the context dict so that an input or
            # item literally named "context" cannot shadow it during evaluation.
            code = var_ref.sub(lambda m: f"__pw_ctx__[{m.group(1)!r}]", expr)
            try:
                # SECURITY: eval runs arbitrary Python; expressions must be trusted.
                return eval(code, {"__pw_ctx__": context}, context)
            except Exception:
                # Best effort by design: missing inputs or invalid operations yield None.
                return None
        if isinstance(expr, If):
            for row in expr.rows:
                cond = eval_expr(row.condition[0], context) if row.condition else True
                if cond:
                    return eval_expr(row.value[0], context) if row.value else None
            if expr.otherwise:
                return eval_expr(expr.otherwise[0], context)
            return None
        if isinstance(expr, IfRow):
            # Should not be evaluated directly, only as part of an If.
            return None
        # Non-expression values (e.g. numeric literals) are returned unchanged.
        return expr

    # Dependency resolution
    sorted_items = calcset.topological_sort().items
    # Loop-invariant: computed once rather than once per evaluated row.
    item_type_map = {
        item.name: getattr(item, "item_type", None) for item in sorted_items
    }

    def run_single(context):
        """Evaluate every item for one record (a mapping of input values)."""
        results = {}
        for item in sorted_items:
            if getattr(item, "item_type", None) == "variable":
                results[item.name] = context.get(item.name, None)
                continue
            # NOTE(review): only the first child's value is kept; presumably items hold
            # a single expression — confirm against the CalcSet model.
            child_results = [
                eval_expr(child, {**context, **results}) for child in item.children
            ]
            results[item.name] = child_results[0] if child_results else None
        if output_variables:
            return results
        return {k: v for k, v in results.items() if item_type_map.get(k) != "variable"}

    if dataframe is None:
        return run_single(inputs if inputs is not None else {})

    try:
        import pandas  # noqa: F401  (only checks that pandas is available)
    except ImportError as err:
        raise ImportError(
            "pandas is required for DataFrame input/output. Please install pandas or use dict inputs."
        ) from err
    df = dataframe.copy()
    for idx, row in df.iterrows():
        for key, value in run_single(dict(row)).items():
            df.at[idx, key] = value
    # Remove variable columns if output_variables is False
    if not output_variables:
        variable_names = [
            item.name
            for item in sorted_items
            if getattr(item, "item_type", None) == "variable"
        ]
        df = df.drop(columns=variable_names, errors="ignore")
    return df
96
+
97
+
98
+ # Pandas DataFrame extension accessor
99
# Pandas DataFrame extension accessor: enables ``df.pw.run(calcset)``.
try:
    import pandas as pd

    @pd.api.extensions.register_dataframe_accessor("pw")
    class PollywogAccessor:
        """DataFrame accessor exposing pollywog evaluation as ``df.pw``."""

        def __init__(self, pandas_obj):
            self._obj = pandas_obj

        def run(self, calcset, assign_results=True, output_variables=False):
            """
            Run a CalcSet on this DataFrame, returning a copy with results assigned.

            Args:
                calcset: The CalcSet to evaluate row by row.
                assign_results (bool): Forwarded to ``run_calcset``.
                output_variables (bool): Forwarded to ``run_calcset``. If True,
                    "variable" items are kept in the output instead of dropped.

            Returns:
                pandas.DataFrame: A copy of the frame with result columns.
            """
            return run_calcset(
                calcset,
                dataframe=self._obj,
                assign_results=assign_results,
                output_variables=output_variables,
            )

except ImportError:
    # pandas is optional: without it the ``.pw`` accessor is simply not registered.
    pass
pollywog/utils.py ADDED
@@ -0,0 +1,95 @@
1
def ensure_list(x):
    """
    Wrap a non-list value in a one-element list.

    A list argument is returned as-is (the same object). Every other value,
    including tuples and None, becomes ``[x]``.

    Args:
        x (Any): Value or list.
    Returns:
        list: ``x`` itself if it is a list, otherwise ``[x]``.
    """
    return x if isinstance(x, list) else [x]
13
+
14
+
15
def ensure_str_list(x):
    """
    Ensure the input is a list that starts and ends with a string.

    A non-list input is wrapped in a list. If the first element is not a string,
    ``""`` is prepended; if the last element is not a string, ``""`` is appended.
    The input list itself is never mutated.

    Args:
        x (Any): Input value or list.
    Returns:
        list: The (possibly padded) list. An empty list yields ``[""]`` instead of
        raising IndexError.
    """
    if not isinstance(x, list):
        x = [x]
    if not x:
        # Nothing to inspect; a single empty string satisfies the invariant.
        return [""]
    if not isinstance(x[0], str):
        x = [""] + x
    if not isinstance(x[-1], str):
        x = x + [""]
    return x
32
+
33
+
34
def to_dict(items, guard_strings=False):
    """
    Convert items to their dictionary representations where possible.

    Each item exposing a ``to_dict()`` method is replaced by its result; other items
    are kept as-is. A non-list ``items`` is treated as a single item.

    Args:
        items (list): List of items or single item.
        guard_strings (bool): If True, pad with ``""`` so the result starts and ends
            with a string.
    Returns:
        list: List of dicts or items, possibly padded with empty strings. With
        ``guard_strings`` and no items, ``[""]`` is returned instead of raising
        IndexError.
    """
    out = [
        item.to_dict() if hasattr(item, "to_dict") else item
        for item in ensure_list(items)
    ]
    if guard_strings:
        if not out:
            return [""]
        if not isinstance(out[0], str):
            out = [""] + out
        if not isinstance(out[-1], str):
            out = out + [""]
    return out
54
+
55
def is_number(v):
    """
    Report whether ``float(v)`` succeeds.

    Numbers, numeric strings and bools all count. Note that strings such as
    "nan" or "inf" are also accepted, because ``float`` parses them.

    Args:
        v (Any): Value to test.
    Returns:
        bool: True if ``v`` converts to float, False on ValueError/TypeError.
    """
    try:
        float(v)
    except (ValueError, TypeError):
        return False
    return True
69
+
70
def ensure_brackets(var):
    """
    Return the stripped variable name wrapped in exactly one pair of brackets.

    Surrounding whitespace is removed first. A name that already starts with "["
    and ends with "]" is returned unchanged.

    Args:
        var (str): Variable name, with or without brackets.
    Returns:
        str: The name in the form "[name]".
    """
    stripped = var.strip()
    already_wrapped = stripped.startswith("[") and stripped.endswith("]")
    return stripped if already_wrapped else "[" + stripped + "]"
83
+
84
def ensure_variables(variables):
    """
    Format each value as a literal number or a bracketed variable reference.

    Values accepted by ``is_number`` are converted to their string form; everything
    else goes through ``ensure_brackets``. A non-list input is treated as a single
    value.

    Args:
        variables (Any): A single value or a list of values.
    Returns:
        list: Formatted strings, one per input value.
    """
    formatted = []
    for item in ensure_list(variables):
        if is_number(item):
            formatted.append(f"{item}")
        else:
            formatted.append(ensure_brackets(item))
    return formatted
tests/__init__.py ADDED
File without changes
tests/conftest.py ADDED
@@ -0,0 +1,4 @@
1
import os
import sys

# Put the repository root (the parent of this tests/ directory) first on sys.path so
# the tests can import the local ``pollywog`` package without installing it.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, _REPO_ROOT)
@@ -0,0 +1,72 @@
1
+ import pytest
2
+ from pollywog.conversion.sklearn import convert_tree, convert_linear_model
3
+ from pollywog.core import Number, Category
4
+ from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
5
+ from sklearn.linear_model import LinearRegression
6
+ import numpy as np
7
+
8
+
9
def make_regressor():
    """Fit a small DecisionTreeRegressor on three points where y == x1 == x2."""
    features = np.array([[0, 0], [1, 1], [2, 2]])
    targets = np.array([0, 1, 2])
    model = DecisionTreeRegressor()
    model.fit(features, targets)
    return model
13
+
14
+
15
def make_linear():
    """
    Return a LinearRegression pinned to y = 1 + 2*x1 (x2 has zero weight).

    The training points already satisfy that equation exactly. ``fit`` is still
    called, presumably so the estimator counts as fitted. The intercept and
    coefficients are then overwritten with exact floats so the converted equation
    text is deterministic.
    """
    features = np.array([[0, 0], [1, 0], [0, 1]])
    targets = np.array([1, 3, 1])
    model = LinearRegression().fit(features, targets)
    model.intercept_ = 1.0
    model.coef_ = np.array([2.0, 0.0])
    return model
22
+
23
+
24
class DummyTree:
    """
    Hand-built stand-in for a fitted sklearn tree's ``tree_`` structure.

    It has one root split on feature 0 at 1.5 with two leaves. It is not
    referenced by the tests in this module.
    """

    class tree_:
        # Root node 0 splits to left child 1 and right child 2; nodes 1 and 2 are
        # leaves. The -2 / -1 markers presumably mirror sklearn's TREE_UNDEFINED /
        # TREE_LEAF sentinels — confirm against sklearn.tree._tree.
        children_left = [1, -1, -1]
        children_right = [2, -1, -1]
        feature = [0, -2, -2]
        threshold = [1.5, 0, 0]
        value = [[[0]], [[1]], [[2]]]

    class _tree:
        class Dummy:
            """Empty placeholder class."""
35
+
36
+
37
def test_convert_tree_regressor():
    """A fitted DecisionTreeRegressor converts to a Number named after the target."""
    converted = convert_tree(make_regressor(), ["x1", "x2"], "target")
    assert isinstance(converted, Number)
    assert converted.name == "target"
    assert "Converted from DecisionTreeRegressor" in converted.comment_equation
43
+
44
+
45
def test_convert_tree_classifier():
    """A fitted DecisionTreeClassifier converts to a Category named after the target."""
    features = np.array([[0, 0], [1, 1], [2, 2]])
    labels = np.array([0, 1, 2])
    classifier = DecisionTreeClassifier()
    classifier.fit(features, labels)
    converted = convert_tree(classifier, ["x1", "x2"], "target")
    assert isinstance(converted, Category)
    assert converted.name == "target"
    assert "Converted from DecisionTreeClassifier" in converted.comment_equation
53
+
54
+
55
def test_convert_tree_invalid():
    """convert_tree raises for an object whose class is not a recognised tree model."""
    # An instance of a bare, freshly created class named "UnknownTree".
    unknown_model = type("UnknownTree", (), {})()
    with pytest.raises(Exception):
        convert_tree(unknown_model, ["x1"], "target")
63
+
64
+
65
def test_convert_linear_model():
    """A pinned LinearRegression converts to a Number with the expected equation text."""
    converted = convert_linear_model(make_linear(), ["x1", "x2"], "target")
    assert isinstance(converted, Number)
    assert converted.name == "target"
    assert "Converted from LinearRegression" in converted.comment_equation
    equation = converted.children[0]
    assert "1.000000" in equation
    assert "2.000000 * [x1]" in equation