pepflow 0.1.3a1__py3-none-any.whl → 0.1.4a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pepflow/__init__.py +1 -0
- pepflow/constraint_test.py +71 -0
- pepflow/e2e_test.py +69 -0
- pepflow/expression_manager.py +72 -2
- pepflow/expression_manager_test.py +116 -0
- pepflow/function.py +142 -48
- pepflow/function_test.py +249 -108
- pepflow/interactive_constraint.py +165 -75
- pepflow/pep.py +18 -3
- pepflow/pep_context.py +12 -7
- pepflow/pep_context_test.py +23 -21
- pepflow/pep_test.py +8 -0
- pepflow/point.py +43 -8
- pepflow/point_test.py +106 -308
- pepflow/scalar.py +39 -1
- pepflow/scalar_test.py +207 -0
- pepflow/solver_test.py +7 -7
- pepflow/utils.py +14 -1
- {pepflow-0.1.3a1.dist-info → pepflow-0.1.4a1.dist-info}/METADATA +19 -1
- pepflow-0.1.4a1.dist-info/RECORD +26 -0
- pepflow-0.1.3a1.dist-info/RECORD +0 -22
- {pepflow-0.1.3a1.dist-info → pepflow-0.1.4a1.dist-info}/WHEEL +0 -0
- {pepflow-0.1.3a1.dist-info → pepflow-0.1.4a1.dist-info}/licenses/LICENSE +0 -0
- {pepflow-0.1.3a1.dist-info → pepflow-0.1.4a1.dist-info}/top_level.txt +0 -0
@@ -22,6 +22,7 @@ from __future__ import annotations
|
|
22
22
|
import json
|
23
23
|
from typing import TYPE_CHECKING
|
24
24
|
|
25
|
+
import attrs
|
25
26
|
import dash
|
26
27
|
import dash_bootstrap_components as dbc
|
27
28
|
import numpy as np
|
@@ -29,11 +30,12 @@ import pandas as pd
|
|
29
30
|
import plotly
|
30
31
|
import plotly.express as px
|
31
32
|
import plotly.graph_objects as go
|
32
|
-
from dash import Dash, Input, Output, State, dcc, html
|
33
|
+
from dash import ALL, MATCH, Dash, Input, Output, State, dcc, html
|
33
34
|
|
34
35
|
from pepflow.constants import PSD_CONSTRAINT
|
35
36
|
|
36
37
|
if TYPE_CHECKING:
|
38
|
+
from pepflow.function import Function
|
37
39
|
from pepflow.pep import PEPBuilder, PEPResult
|
38
40
|
from pepflow.pep_context import PEPContext
|
39
41
|
|
@@ -42,26 +44,84 @@ plotly.io.renderers.default = "colab+vscode"
|
|
42
44
|
plotly.io.templates.default = "plotly_white"
|
43
45
|
|
44
46
|
|
47
|
+
@attrs.frozen
|
48
|
+
class PlotData:
|
49
|
+
dataframe: pd.DataFrame
|
50
|
+
figure: go.Figure
|
51
|
+
function: Function
|
52
|
+
|
53
|
+
def dual_matrix_to_tab(self) -> html.Pre:
|
54
|
+
def get_matrix_of_dual_value(df: pd.DataFrame) -> np.ndarray:
|
55
|
+
# Check if we need to update the order.
|
56
|
+
return (
|
57
|
+
pd.pivot_table(
|
58
|
+
df, values="dual_value", index="row", columns="col", dropna=False
|
59
|
+
)
|
60
|
+
.fillna(0.0)
|
61
|
+
.to_numpy()
|
62
|
+
.T
|
63
|
+
)
|
64
|
+
|
65
|
+
with np.printoptions(precision=3, linewidth=100, suppress=True):
|
66
|
+
dual_value_tab = html.Pre(
|
67
|
+
str(get_matrix_of_dual_value(self.dataframe)),
|
68
|
+
id={"type": "dual-value-display", "index": self.function.tag},
|
69
|
+
style={
|
70
|
+
"border": "1px solid lightgrey",
|
71
|
+
"padding": "10px",
|
72
|
+
"height": "60vh",
|
73
|
+
"overflowY": "auto",
|
74
|
+
},
|
75
|
+
)
|
76
|
+
return dual_value_tab
|
77
|
+
|
78
|
+
def plot_data_to_tab(self) -> dbc.Tab:
|
79
|
+
tab = dbc.Tab(
|
80
|
+
html.Div(
|
81
|
+
[
|
82
|
+
html.P("Interactive Heat Map:"),
|
83
|
+
dcc.Graph(
|
84
|
+
id={
|
85
|
+
"type": "interactive-scatter",
|
86
|
+
"index": self.function.tag,
|
87
|
+
},
|
88
|
+
figure=self.figure,
|
89
|
+
),
|
90
|
+
html.P("Dual Value Matrix:"),
|
91
|
+
self.dual_matrix_to_tab(),
|
92
|
+
]
|
93
|
+
),
|
94
|
+
label=f"{self.function.tag}-Interpolation Conditions",
|
95
|
+
tab_id=f"{self.function.tag}-interactive-scatter-tab",
|
96
|
+
)
|
97
|
+
return tab
|
98
|
+
|
99
|
+
|
45
100
|
def solve_prob_and_get_figure(
|
46
101
|
pep_builder: PEPBuilder, context: PEPContext
|
47
|
-
) -> tuple[
|
48
|
-
|
102
|
+
) -> tuple[list[PlotData], PEPResult]:
|
103
|
+
plot_data_list: list[PlotData] = []
|
49
104
|
|
50
105
|
result = pep_builder.solve(context=context)
|
51
106
|
|
52
107
|
df_dict, order_dict = context.triplets_to_df_and_order()
|
53
|
-
f = pep_builder.functions[0]
|
54
108
|
|
55
|
-
|
56
|
-
|
109
|
+
for f in context.triplets.keys():
|
110
|
+
df = df_dict[f]
|
111
|
+
order = order_dict[f]
|
57
112
|
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
113
|
+
df["constraint"] = df.constraint_name.map(
|
114
|
+
lambda x: "inactive" if x in pep_builder.relaxed_constraints else "active"
|
115
|
+
)
|
116
|
+
df["dual_value"] = df.constraint_name.map(
|
117
|
+
lambda x: result.dual_var_manager.dual_value(x)
|
118
|
+
)
|
119
|
+
|
120
|
+
plot_data_list.append(
|
121
|
+
PlotData(dataframe=df, figure=processed_df_to_fig(df, order), function=f)
|
122
|
+
)
|
123
|
+
|
124
|
+
return plot_data_list, result
|
65
125
|
|
66
126
|
|
67
127
|
def processed_df_to_fig(df: pd.DataFrame, order: list[str]):
|
@@ -99,19 +159,7 @@ def get_matrix_of_dual_value(df: pd.DataFrame) -> np.ndarray:
|
|
99
159
|
|
100
160
|
def launch(pep_builder: PEPBuilder, context: PEPContext):
|
101
161
|
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
|
102
|
-
|
103
|
-
|
104
|
-
with np.printoptions(precision=3, linewidth=100, suppress=True):
|
105
|
-
dual_value_tab = html.Pre(
|
106
|
-
str(get_matrix_of_dual_value(df)),
|
107
|
-
id="dual-value-display",
|
108
|
-
style={
|
109
|
-
"border": "1px solid lightgrey",
|
110
|
-
"padding": "10px",
|
111
|
-
"height": "60vh",
|
112
|
-
"overflowY": "auto",
|
113
|
-
},
|
114
|
-
)
|
162
|
+
plot_data_list, result = solve_prob_and_get_figure(pep_builder, context)
|
115
163
|
# Think how can we manipulate the pep_builder here.
|
116
164
|
display_row = dbc.Row(
|
117
165
|
[
|
@@ -129,21 +177,8 @@ def launch(pep_builder: PEPBuilder, context: PEPContext):
|
|
129
177
|
color="success",
|
130
178
|
),
|
131
179
|
dbc.Tabs(
|
132
|
-
[
|
133
|
-
|
134
|
-
html.Div(
|
135
|
-
dcc.Graph(id="interactive-scatter", figure=fig),
|
136
|
-
),
|
137
|
-
label="Interactive Heatmap",
|
138
|
-
tab_id="interactive-scatter-tab",
|
139
|
-
),
|
140
|
-
dbc.Tab(
|
141
|
-
dual_value_tab,
|
142
|
-
label="Dual Value Matrix",
|
143
|
-
tab_id="dual_value_tab",
|
144
|
-
),
|
145
|
-
],
|
146
|
-
active_tab="interactive-scatter-tab",
|
180
|
+
[plot_data.plot_data_to_tab() for plot_data in plot_data_list],
|
181
|
+
active_tab=f"{plot_data_list[0].function.tag}-interactive-scatter-tab",
|
147
182
|
),
|
148
183
|
],
|
149
184
|
width=5,
|
@@ -175,25 +210,37 @@ def launch(pep_builder: PEPBuilder, context: PEPContext):
|
|
175
210
|
[
|
176
211
|
html.H2("PEPFlow"),
|
177
212
|
display_row,
|
178
|
-
#
|
179
|
-
|
213
|
+
# For each function, store the corresponding DataFrame as a dictionary in dcc.Store
|
214
|
+
*[
|
215
|
+
dcc.Store(
|
216
|
+
id={
|
217
|
+
"type": "dataframe-store",
|
218
|
+
"index": plot_data.function.tag,
|
219
|
+
},
|
220
|
+
data=(
|
221
|
+
plot_data.function.tag,
|
222
|
+
plot_data.dataframe.to_dict("records"),
|
223
|
+
),
|
224
|
+
)
|
225
|
+
for plot_data in plot_data_list
|
226
|
+
],
|
180
227
|
]
|
181
228
|
)
|
182
229
|
|
183
230
|
@dash.callback(
|
184
231
|
Output("result-card", "children"),
|
185
|
-
Output("dual-value-display", "children"),
|
186
|
-
Output("interactive-scatter", "figure"),
|
187
|
-
Output("dataframe-store", "data"),
|
232
|
+
Output({"type": "dual-value-display", "index": ALL}, "children"),
|
233
|
+
Output({"type": "interactive-scatter", "index": ALL}, "figure"),
|
234
|
+
Output({"type": "dataframe-store", "index": ALL}, "data"),
|
188
235
|
Input("solve-button", "n_clicks"),
|
189
236
|
)
|
190
237
|
def solve(_):
|
191
|
-
|
238
|
+
plot_data_list, result = solve_prob_and_get_figure(pep_builder, context)
|
192
239
|
with np.printoptions(precision=3, linewidth=100, suppress=True):
|
193
240
|
psd_dual_value = np.array(
|
194
241
|
result.dual_var_manager.dual_value(PSD_CONSTRAINT)
|
195
242
|
)
|
196
|
-
|
243
|
+
result_card = dbc.CardBody(
|
197
244
|
[
|
198
245
|
html.H2(f"Optimal Value {result.primal_opt_value:.4g}"),
|
199
246
|
html.H3(f"Solver Status: {result.solver_status}"),
|
@@ -203,50 +250,89 @@ def launch(pep_builder: PEPBuilder, context: PEPContext):
|
|
203
250
|
html.Pre(json.dumps(pep_builder.relaxed_constraints, indent=2)),
|
204
251
|
]
|
205
252
|
)
|
206
|
-
|
207
|
-
|
253
|
+
dual_value_displays = [
|
254
|
+
str(get_matrix_of_dual_value(plot_data.dataframe))
|
255
|
+
for plot_data in plot_data_list
|
256
|
+
]
|
257
|
+
figs = [plot_data.figure for plot_data in plot_data_list]
|
258
|
+
|
259
|
+
df_data = [
|
260
|
+
(plot_data.function.tag, plot_data.dataframe.to_dict("records"))
|
261
|
+
for plot_data in plot_data_list
|
262
|
+
]
|
263
|
+
|
264
|
+
return result_card, dual_value_displays, figs, df_data
|
208
265
|
|
209
266
|
@dash.callback(
|
210
|
-
Output(
|
211
|
-
|
267
|
+
Output(
|
268
|
+
{"type": "interactive-scatter", "index": ALL},
|
269
|
+
"figure",
|
270
|
+
allow_duplicate=True,
|
271
|
+
),
|
272
|
+
Output({"type": "dataframe-store", "index": ALL}, "data", allow_duplicate=True),
|
212
273
|
Input("restore-all-constraints-button", "n_clicks"),
|
213
|
-
State("dataframe-store", "data"),
|
274
|
+
State({"type": "dataframe-store", "index": ALL}, "data"),
|
214
275
|
prevent_initial_call=True,
|
215
276
|
)
|
216
|
-
def restore_all_constraints(_,
|
277
|
+
def restore_all_constraints(_, list_previous_df_tuples):
|
217
278
|
nonlocal pep_builder
|
218
|
-
df_updated = pd.DataFrame(previous_df)
|
219
279
|
pep_builder.relaxed_constraints = []
|
220
|
-
|
221
|
-
|
222
|
-
|
280
|
+
updated_figs = []
|
281
|
+
df_data = []
|
282
|
+
for previous_df_tuple in list_previous_df_tuples:
|
283
|
+
tag, previous_df = previous_df_tuple
|
284
|
+
df_updated = pd.DataFrame(previous_df)
|
285
|
+
df_updated["constraint"] = "active"
|
286
|
+
order = context.order_of_point(pep_builder.get_func_by_tag(tag))
|
287
|
+
updated_figs.append(processed_df_to_fig(df_updated, order))
|
288
|
+
df_data.append((tag, df_updated.to_dict("records")))
|
289
|
+
return updated_figs, df_data
|
223
290
|
|
224
291
|
@dash.callback(
|
225
|
-
Output(
|
226
|
-
|
292
|
+
Output(
|
293
|
+
{"type": "interactive-scatter", "index": ALL},
|
294
|
+
"figure",
|
295
|
+
allow_duplicate=True,
|
296
|
+
),
|
297
|
+
Output({"type": "dataframe-store", "index": ALL}, "data", allow_duplicate=True),
|
227
298
|
Input("relax-all-constraints-button", "n_clicks"),
|
228
|
-
State("dataframe-store", "data"),
|
299
|
+
State({"type": "dataframe-store", "index": ALL}, "data"),
|
229
300
|
prevent_initial_call=True,
|
230
301
|
)
|
231
|
-
def relax_all_constraints(_,
|
302
|
+
def relax_all_constraints(_, list_previous_df_tuples):
|
232
303
|
nonlocal pep_builder
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
304
|
+
pep_builder.relaxed_constraints = []
|
305
|
+
updated_figs = []
|
306
|
+
df_data = []
|
307
|
+
for previous_df_tuple in list_previous_df_tuples:
|
308
|
+
tag, previous_df = previous_df_tuple
|
309
|
+
df_updated = pd.DataFrame(previous_df)
|
310
|
+
pep_builder.relaxed_constraints.extend(
|
311
|
+
df_updated["constraint_name"].to_list()
|
312
|
+
)
|
313
|
+
df_updated["constraint"] = "inactive"
|
314
|
+
order = context.order_of_point(pep_builder.get_func_by_tag(tag))
|
315
|
+
updated_figs.append(processed_df_to_fig(df_updated, order))
|
316
|
+
df_data.append((tag, df_updated.to_dict("records")))
|
317
|
+
return updated_figs, df_data
|
238
318
|
|
239
319
|
@dash.callback(
|
240
|
-
Output(
|
241
|
-
|
242
|
-
|
243
|
-
|
320
|
+
Output(
|
321
|
+
{"type": "interactive-scatter", "index": MATCH},
|
322
|
+
"figure",
|
323
|
+
allow_duplicate=True,
|
324
|
+
),
|
325
|
+
Output(
|
326
|
+
{"type": "dataframe-store", "index": MATCH}, "data", allow_duplicate=True
|
327
|
+
),
|
328
|
+
Input({"type": "interactive-scatter", "index": MATCH}, "clickData"),
|
329
|
+
State({"type": "dataframe-store", "index": MATCH}, "data"),
|
244
330
|
prevent_initial_call=True,
|
245
331
|
)
|
246
|
-
def update_df_and_redraw(clickData,
|
332
|
+
def update_df_and_redraw(clickData, previous_df_tuple):
|
247
333
|
nonlocal pep_builder
|
248
334
|
if not clickData["points"][0]["customdata"]:
|
249
|
-
return dash.no_update, dash.no_update
|
335
|
+
return dash.no_update, dash.no_update
|
250
336
|
|
251
337
|
clicked_name = clickData["points"][0]["customdata"][0]
|
252
338
|
if clicked_name not in pep_builder.relaxed_constraints:
|
@@ -254,11 +340,15 @@ def launch(pep_builder: PEPBuilder, context: PEPContext):
|
|
254
340
|
else:
|
255
341
|
pep_builder.relaxed_constraints.remove(clicked_name)
|
256
342
|
|
343
|
+
tag, previous_df = previous_df_tuple
|
257
344
|
df_updated = pd.DataFrame(previous_df)
|
258
345
|
df_updated["constraint"] = df_updated.constraint_name.map(
|
259
346
|
lambda x: "inactive" if x in pep_builder.relaxed_constraints else "active"
|
260
347
|
)
|
261
|
-
order = context.order_of_point(pep_builder.
|
262
|
-
return processed_df_to_fig(df_updated, order),
|
348
|
+
order = context.order_of_point(pep_builder.get_func_by_tag(tag))
|
349
|
+
return processed_df_to_fig(df_updated, order), (
|
350
|
+
tag,
|
351
|
+
df_updated.to_dict("records"),
|
352
|
+
)
|
263
353
|
|
264
354
|
app.run(debug=True)
|
pepflow/pep.py
CHANGED
@@ -87,6 +87,13 @@ class PEPBuilder:
|
|
87
87
|
# We should think about a better choice like manager.
|
88
88
|
self.relaxed_constraints = []
|
89
89
|
|
90
|
+
def clear_setup(self):
|
91
|
+
self.init_conditions.clear()
|
92
|
+
self.functions.clear()
|
93
|
+
self.interpolation_constraints.clear()
|
94
|
+
self.performance_metric = None
|
95
|
+
self.relaxed_constraints.clear()
|
96
|
+
|
90
97
|
@contextlib.contextmanager
|
91
98
|
def make_context(
|
92
99
|
self, name: str, override: bool = False
|
@@ -94,7 +101,8 @@ class PEPBuilder:
|
|
94
101
|
if not override and name in self.pep_context_dict:
|
95
102
|
raise KeyError(f"There is already a context {name} in the builder")
|
96
103
|
try:
|
97
|
-
|
104
|
+
self.clear_setup()
|
105
|
+
ctx = pc.PEPContext(name)
|
98
106
|
self.pep_context_dict[name] = ctx
|
99
107
|
pc.set_current_context(ctx)
|
100
108
|
yield ctx
|
@@ -130,11 +138,19 @@ class PEPBuilder:
|
|
130
138
|
def set_relaxed_constraints(self, relaxed_constraints: list[str]):
|
131
139
|
self.relaxed_constraints.extend(relaxed_constraints)
|
132
140
|
|
133
|
-
def declare_func(self, function_class, **kwargs):
|
141
|
+
def declare_func(self, function_class: type[Function], tag: str, **kwargs):
|
134
142
|
func = function_class(is_basis=True, composition=None, **kwargs)
|
143
|
+
func.add_tag(tag)
|
135
144
|
self.functions.append(func)
|
136
145
|
return func
|
137
146
|
|
147
|
+
def get_func_by_tag(self, tag: str):
|
148
|
+
# TODO: Add support to return composite functions as well. Right now we can only return base functions
|
149
|
+
for f in self.functions:
|
150
|
+
if tag in f.tags:
|
151
|
+
return f
|
152
|
+
raise ValueError("Cannot find the function of given tag.")
|
153
|
+
|
138
154
|
def solve(self, context: pc.PEPContext | None = None, **kwargs):
|
139
155
|
if context is None:
|
140
156
|
context = pc.get_current_context()
|
@@ -144,7 +160,6 @@ class PEPBuilder:
|
|
144
160
|
all_constraints: list[Constraint] = [*self.init_conditions]
|
145
161
|
for f in self.functions:
|
146
162
|
all_constraints.extend(f.get_interpolation_constraints())
|
147
|
-
all_constraints.extend(context.opt_conditions[f])
|
148
163
|
|
149
164
|
# for now, we heavily rely on the CVX. We can make a wrapper class to avoid
|
150
165
|
# direct dependency in the future.
|
pepflow/pep_context.py
CHANGED
@@ -26,13 +26,14 @@ import natsort
|
|
26
26
|
import pandas as pd
|
27
27
|
|
28
28
|
if TYPE_CHECKING:
|
29
|
-
from pepflow.constraint import Constraint
|
30
29
|
from pepflow.function import Function, Triplet
|
31
30
|
from pepflow.point import Point
|
32
31
|
from pepflow.scalar import Scalar
|
33
32
|
|
34
33
|
# A global variable for storing the current context that is used for points or scalars.
|
35
34
|
CURRENT_CONTEXT: PEPContext | None = None
|
35
|
+
# Keep the track of all previous created context
|
36
|
+
GLOBAL_CONTEXT_DICT: dict[str, PEPContext] = {}
|
36
37
|
|
37
38
|
|
38
39
|
def get_current_context() -> PEPContext | None:
|
@@ -46,11 +47,15 @@ def set_current_context(ctx: PEPContext | None):
|
|
46
47
|
|
47
48
|
|
48
49
|
class PEPContext:
|
49
|
-
def __init__(self):
|
50
|
+
def __init__(self, name: str):
|
51
|
+
self.name = name
|
50
52
|
self.points: list[Point] = []
|
51
53
|
self.scalars: list[Scalar] = []
|
52
54
|
self.triplets: dict[Function, list[Triplet]] = defaultdict(list)
|
53
|
-
self.
|
55
|
+
# self.triplets will contain all stationary_triplets. They are not mutually exclusive.
|
56
|
+
self.stationary_triplets: dict[Function, list[Triplet]] = defaultdict(list)
|
57
|
+
|
58
|
+
GLOBAL_CONTEXT_DICT[name] = self
|
54
59
|
|
55
60
|
def set_as_current(self) -> PEPContext:
|
56
61
|
set_current_context(self)
|
@@ -65,8 +70,8 @@ class PEPContext:
|
|
65
70
|
def add_triplet(self, function: Function, triplet: Triplet):
|
66
71
|
self.triplets[function].append(triplet)
|
67
72
|
|
68
|
-
def
|
69
|
-
self.
|
73
|
+
def add_stationary_triplet(self, function: Function, stationary_triplet: Triplet):
|
74
|
+
self.stationary_triplets[function].append(stationary_triplet)
|
70
75
|
|
71
76
|
def get_by_tag(self, tag: str) -> Point | Scalar:
|
72
77
|
for p in self.points:
|
@@ -75,13 +80,13 @@ class PEPContext:
|
|
75
80
|
for s in self.scalars:
|
76
81
|
if tag in s.tags:
|
77
82
|
return s
|
78
|
-
raise ValueError("Cannot find the point or
|
83
|
+
raise ValueError("Cannot find the point, scalar, or function of given tag.")
|
79
84
|
|
80
85
|
def clear(self):
|
81
86
|
self.points.clear()
|
82
87
|
self.scalars.clear()
|
83
88
|
self.triplets.clear()
|
84
|
-
self.
|
89
|
+
self.stationary_triplets.clear()
|
85
90
|
|
86
91
|
def tracked_point(self, func: Function) -> list[Point]:
|
87
92
|
return natsort.natsorted(
|
pepflow/pep_context_test.py
CHANGED
@@ -17,17 +17,25 @@
|
|
17
17
|
# specific language governing permissions and limitations
|
18
18
|
# under the License.
|
19
19
|
|
20
|
+
from typing import Iterator
|
21
|
+
|
20
22
|
import pandas as pd
|
23
|
+
import pytest
|
21
24
|
|
22
25
|
from pepflow import pep_context as pc
|
23
26
|
from pepflow.function import SmoothConvexFunction
|
24
27
|
from pepflow.point import Point
|
25
28
|
|
26
29
|
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
+
@pytest.fixture
|
31
|
+
def pep_context() -> Iterator[pc.PEPContext]:
|
32
|
+
"""Prepare the pep context and reset the context to None at the end."""
|
33
|
+
ctx = pc.PEPContext("test").set_as_current()
|
34
|
+
yield ctx
|
35
|
+
pc.set_current_context(None)
|
36
|
+
|
30
37
|
|
38
|
+
def test_tracked_points(pep_context: pc.PEPContext):
|
31
39
|
f = SmoothConvexFunction(L=1, is_basis=True)
|
32
40
|
f.add_tag("f")
|
33
41
|
|
@@ -41,16 +49,11 @@ def test_tracked_points():
|
|
41
49
|
_ = f.generate_triplet(p3)
|
42
50
|
_ = f.generate_triplet(p_star)
|
43
51
|
|
44
|
-
assert
|
45
|
-
assert
|
46
|
-
|
47
|
-
pc.set_current_context(None)
|
48
|
-
|
52
|
+
assert pep_context.order_of_point(f) == ["x_1", "x_2", "x_3", "x_*"]
|
53
|
+
assert pep_context.tracked_point(f) == [p1, p3, p2, p_star]
|
49
54
|
|
50
|
-
def test_triplets_to_dataframe():
|
51
|
-
ctx = pc.PEPContext()
|
52
|
-
pc.set_current_context(ctx)
|
53
55
|
|
56
|
+
def test_triplets_to_dataframe(pep_context: pc.PEPContext):
|
54
57
|
f = SmoothConvexFunction(L=1, is_basis=True)
|
55
58
|
f.add_tag("f")
|
56
59
|
|
@@ -62,7 +65,7 @@ def test_triplets_to_dataframe():
|
|
62
65
|
_ = f.generate_triplet(p2)
|
63
66
|
_ = f.generate_triplet(p3)
|
64
67
|
|
65
|
-
func_to_df, func_to_order =
|
68
|
+
func_to_df, func_to_order = pep_context.triplets_to_df_and_order()
|
66
69
|
expected_df = pd.DataFrame(
|
67
70
|
{
|
68
71
|
"constraint_name": [
|
@@ -83,20 +86,19 @@ def test_triplets_to_dataframe():
|
|
83
86
|
pd.testing.assert_frame_equal(func_to_df[f], expected_df)
|
84
87
|
assert func_to_order[f] == ["x1", "x2", "x3"]
|
85
88
|
|
86
|
-
pc.set_current_context(None)
|
87
|
-
|
88
|
-
|
89
|
-
def test_get_by_tag():
|
90
|
-
ctx = pc.PEPContext()
|
91
|
-
pc.set_current_context(ctx)
|
92
89
|
|
90
|
+
def test_get_by_tag(pep_context: pc.PEPContext):
|
93
91
|
f = SmoothConvexFunction(L=1, is_basis=True)
|
94
92
|
f.add_tag("f")
|
95
93
|
p1 = Point(is_basis=True, tags=["x1"])
|
94
|
+
p2 = Point(is_basis=True, tags=["x2"])
|
95
|
+
p3 = p1 + p2
|
96
96
|
|
97
97
|
triplet = f.generate_triplet(p1)
|
98
|
+
_ = f.generate_triplet(p2)
|
98
99
|
|
99
|
-
assert
|
100
|
-
assert
|
101
|
-
assert
|
100
|
+
assert pep_context.get_by_tag("x1") == p1
|
101
|
+
assert pep_context.get_by_tag("f(x1)") == triplet.function_value
|
102
|
+
assert pep_context.get_by_tag("gradient_f(x1)") == triplet.gradient
|
103
|
+
assert pep_context.get_by_tag("x1+x2") == p3
|
102
104
|
pc.set_current_context(None)
|
pepflow/pep_test.py
CHANGED
@@ -19,6 +19,7 @@
|
|
19
19
|
|
20
20
|
import pytest
|
21
21
|
|
22
|
+
from pepflow import function as fc
|
22
23
|
from pepflow import pep
|
23
24
|
from pepflow import pep_context as pc
|
24
25
|
|
@@ -75,3 +76,10 @@ class TestPEPBuilder:
|
|
75
76
|
|
76
77
|
with builder.make_context("test", override=True):
|
77
78
|
pass
|
79
|
+
|
80
|
+
def test_get_func_by_tag(self) -> None:
|
81
|
+
builder = pep.PEPBuilder()
|
82
|
+
with builder.make_context("test"):
|
83
|
+
f = builder.declare_func(fc.SmoothConvexFunction, "f", L=1)
|
84
|
+
|
85
|
+
assert builder.get_func_by_tag("f") == f
|
pepflow/point.py
CHANGED
@@ -134,75 +134,110 @@ class Point:
|
|
134
134
|
return self.tag
|
135
135
|
return super().__repr__()
|
136
136
|
|
137
|
+
def _repr_latex_(self):
|
138
|
+
s = repr(self)
|
139
|
+
s = s.replace("star", r"\star")
|
140
|
+
s = s.replace("gradient_", r"\nabla ")
|
141
|
+
s = s.replace("|", r"\|")
|
142
|
+
return rf"$\\displaystyle {s}$"
|
143
|
+
|
137
144
|
# TODO: add a validator that `is_basis` and `eval_expression` are properly setup.
|
138
145
|
def __add__(self, other):
|
139
|
-
assert
|
146
|
+
assert isinstance(other, Point)
|
140
147
|
return Point(
|
141
148
|
is_basis=False,
|
142
149
|
eval_expression=EvalExpressionPoint(utils.Op.ADD, self, other),
|
150
|
+
tags=[f"{self.tag}+{other.tag}"],
|
143
151
|
)
|
144
152
|
|
145
153
|
def __radd__(self, other):
|
146
|
-
|
154
|
+
# TODO: come up with better way to handle this
|
155
|
+
if other == 0:
|
156
|
+
return self
|
157
|
+
assert isinstance(other, Point)
|
147
158
|
return Point(
|
148
159
|
is_basis=False,
|
149
160
|
eval_expression=EvalExpressionPoint(utils.Op.ADD, other, self),
|
161
|
+
tags=[f"{other.tag}+{self.tag}"],
|
150
162
|
)
|
151
163
|
|
152
164
|
def __sub__(self, other):
|
153
|
-
assert
|
165
|
+
assert isinstance(other, Point)
|
166
|
+
tag_other = utils.parenthesize_tag(other)
|
154
167
|
return Point(
|
155
168
|
is_basis=False,
|
156
169
|
eval_expression=EvalExpressionPoint(utils.Op.SUB, self, other),
|
170
|
+
tags=[f"{self.tag}-{tag_other}"],
|
157
171
|
)
|
158
172
|
|
159
173
|
def __rsub__(self, other):
|
160
|
-
assert
|
174
|
+
assert isinstance(other, Point)
|
175
|
+
tag_self = utils.parenthesize_tag(self)
|
161
176
|
return Point(
|
162
177
|
is_basis=False,
|
163
178
|
eval_expression=EvalExpressionPoint(utils.Op.SUB, other, self),
|
179
|
+
tags=[f"{other.tag}-{tag_self}"],
|
164
180
|
)
|
165
181
|
|
166
182
|
def __mul__(self, other):
|
167
183
|
# TODO allow the other to be point so that we return a scalar.
|
168
184
|
assert is_numerical_or_point(other)
|
185
|
+
tag_self = utils.parenthesize_tag(self)
|
169
186
|
if utils.is_numerical(other):
|
170
187
|
return Point(
|
171
188
|
is_basis=False,
|
172
189
|
eval_expression=EvalExpressionPoint(utils.Op.MUL, self, other),
|
190
|
+
tags=[f"{tag_self}*{other:.4g}"],
|
173
191
|
)
|
174
192
|
else:
|
193
|
+
tag_other = utils.parenthesize_tag(other)
|
175
194
|
return Scalar(
|
176
195
|
is_basis=False,
|
177
|
-
eval_expression=EvalExpressionScalar(utils.Op.MUL, self, other),
|
196
|
+
eval_expression=EvalExpressionScalar(utils.Op.MUL, self, other),
|
197
|
+
tags=[f"{tag_self}*{tag_other}"],
|
178
198
|
)
|
179
199
|
|
180
200
|
def __rmul__(self, other):
|
181
201
|
# TODO allow the other to be point so that we return a scalar.
|
182
202
|
assert is_numerical_or_point(other)
|
203
|
+
tag_self = utils.parenthesize_tag(self)
|
183
204
|
if utils.is_numerical(other):
|
184
205
|
return Point(
|
185
206
|
is_basis=False,
|
186
207
|
eval_expression=EvalExpressionPoint(utils.Op.MUL, other, self),
|
208
|
+
tags=[f"{other:.4g}*{tag_self}"],
|
187
209
|
)
|
188
210
|
else:
|
211
|
+
tag_other = utils.parenthesize_tag(other)
|
189
212
|
return Scalar(
|
190
213
|
is_basis=False,
|
191
|
-
eval_expression=EvalExpressionScalar(utils.Op.MUL, other, self),
|
214
|
+
eval_expression=EvalExpressionScalar(utils.Op.MUL, other, self),
|
215
|
+
tags=[f"{tag_other}*{tag_self}"],
|
192
216
|
)
|
193
217
|
|
194
218
|
def __pow__(self, power):
|
195
219
|
assert power == 2
|
196
|
-
return
|
220
|
+
return Scalar(
|
221
|
+
is_basis=False,
|
222
|
+
eval_expression=EvalExpressionScalar(utils.Op.MUL, self, self),
|
223
|
+
tags=[rf"|{self.tag}|^{power}"],
|
224
|
+
)
|
197
225
|
|
198
226
|
def __neg__(self):
|
199
|
-
|
227
|
+
tag_self = utils.parenthesize_tag(self)
|
228
|
+
return Point(
|
229
|
+
is_basis=False,
|
230
|
+
eval_expression=EvalExpressionPoint(utils.Op.MUL, -1, self),
|
231
|
+
tags=[f"-{tag_self}"],
|
232
|
+
)
|
200
233
|
|
201
234
|
def __truediv__(self, other):
|
202
235
|
assert utils.is_numerical(other)
|
236
|
+
tag_self = utils.parenthesize_tag(self)
|
203
237
|
return Point(
|
204
238
|
is_basis=False,
|
205
239
|
eval_expression=EvalExpressionPoint(utils.Op.DIV, self, other),
|
240
|
+
tags=[f"1/{other:.4g}*{tag_self}"],
|
206
241
|
)
|
207
242
|
|
208
243
|
def __hash__(self):
|