pepflow 0.1.0__py3-none-any.whl → 0.1.3a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pepflow/__init__.py +50 -0
- pepflow/constants.py +20 -0
- pepflow/constraint.py +19 -0
- pepflow/expression_manager.py +19 -0
- pepflow/function.py +273 -90
- pepflow/function_test.py +134 -0
- pepflow/interactive_constraint.py +264 -0
- pepflow/pep.py +63 -7
- pepflow/pep_context.py +107 -6
- pepflow/pep_context_test.py +102 -0
- pepflow/pep_test.py +19 -0
- pepflow/point.py +42 -1
- pepflow/point_test.py +67 -30
- pepflow/scalar.py +43 -2
- pepflow/solver.py +28 -3
- pepflow/solver_test.py +19 -0
- pepflow/utils.py +19 -18
- {pepflow-0.1.0.dist-info → pepflow-0.1.3a1.dist-info}/METADATA +7 -2
- pepflow-0.1.3a1.dist-info/RECORD +22 -0
- pepflow-0.1.0.dist-info/RECORD +0 -18
- {pepflow-0.1.0.dist-info → pepflow-0.1.3a1.dist-info}/WHEEL +0 -0
- {pepflow-0.1.0.dist-info → pepflow-0.1.3a1.dist-info}/licenses/LICENSE +0 -0
- {pepflow-0.1.0.dist-info → pepflow-0.1.3a1.dist-info}/top_level.txt +0 -0
pepflow/function_test.py
ADDED
@@ -0,0 +1,134 @@
|
|
1
|
+
# Copyright: 2025 The PEPFlow Developers
|
2
|
+
#
|
3
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
4
|
+
# or more contributor license agreements. See the NOTICE file
|
5
|
+
# distributed with this work for additional information
|
6
|
+
# regarding copyright ownership. The ASF licenses this file
|
7
|
+
# to you under the Apache License, Version 2.0 (the
|
8
|
+
# "License"); you may not use this file except in compliance
|
9
|
+
# with the License. You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing,
|
14
|
+
# software distributed under the License is distributed on an
|
15
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
16
|
+
# KIND, either express or implied. See the License for the
|
17
|
+
# specific language governing permissions and limitations
|
18
|
+
# under the License.
|
19
|
+
|
20
|
+
import numpy as np
|
21
|
+
|
22
|
+
from pepflow import expression_manager as exm
|
23
|
+
from pepflow import function as fc
|
24
|
+
from pepflow import pep as pep
|
25
|
+
|
26
|
+
|
27
|
+
def test_function_repr():
    """A Function renders without a tag, and renders as its tag once tagged."""
    builder = pep.PEPBuilder()
    with builder.make_context("test"):
        func = fc.Function(is_basis=True, reuse_gradient=False)
        # Rendering an untagged function must not raise.
        print(func)
        func.add_tag("f")
        rendered = str(func)
        assert rendered == "f"
|
34
|
+
|
35
|
+
|
36
|
+
def test_stationary_point():
    """A stationary point of a basis function yields one triplet with a zero gradient."""
    builder = pep.PEPBuilder()
    with builder.make_context("test") as ctx:
        func = fc.Function(is_basis=True, reuse_gradient=False, tags=["f"])
        func.add_stationary_point("x_star")

        # Exactly one function is tracked, with exactly one triplet.
        assert len(ctx.triplets) == 1
        assert len(ctx.triplets[func]) == 1

        triplet = ctx.triplets[func][0]
        assert triplet.name == "x_star_f(x_star)_gradient_f(x_star)"
        assert triplet.gradient.tag == "gradient_f(x_star)"
        assert triplet.function_value.tag == "f(x_star)"

        manager = exm.ExpressionManager(ctx)
        gradient_vector = manager.eval_point(triplet.gradient).vector
        np.testing.assert_allclose(gradient_vector, np.array([0]))
        point_vector = manager.eval_point(triplet.point).vector
        np.testing.assert_allclose(point_vector, np.array([1]))
|
55
|
+
|
56
|
+
|
57
|
+
def test_stationary_point_scaled():
    """A stationary point of 5 * f is recorded as a triplet of the base function f."""
    builder = pep.PEPBuilder()
    with builder.make_context("test") as ctx:
        base = fc.Function(is_basis=True, reuse_gradient=False, tags=["f"])
        scaled = 5 * base
        scaled.add_stationary_point("x_star")

        assert len(ctx.triplets) == 1
        assert len(ctx.triplets[base]) == 1

        triplet = ctx.triplets[base][0]
        assert triplet.name == "x_star_f(x_star)_gradient_f(x_star)"
        assert triplet.gradient.tag == "gradient_f(x_star)"
        assert triplet.function_value.tag == "f(x_star)"

        manager = exm.ExpressionManager(ctx)
        gradient_vector = manager.eval_point(triplet.gradient).vector
        np.testing.assert_allclose(gradient_vector, np.array([0]))
        point_vector = manager.eval_point(triplet.point).vector
        np.testing.assert_allclose(point_vector, np.array([1]))
|
77
|
+
|
78
|
+
|
79
|
+
def test_stationary_point_additive():
    """A stationary point of f + g records one triplet per summand with opposite gradients."""
    builder = pep.PEPBuilder()
    with builder.make_context("test") as ctx:
        f = fc.Function(is_basis=True, reuse_gradient=False)
        f.add_tag("f")
        g = fc.Function(is_basis=True, reuse_gradient=False)
        g.add_tag("g")
        total = f + g
        total.add_tag("h")

        total.add_stationary_point("x_star")
        assert len(ctx.triplets) == 2
        assert len(ctx.triplets[f]) == 1
        assert len(ctx.triplets[g]) == 1

        f_triplet, g_triplet = ctx.triplets[f][0], ctx.triplets[g][0]
        assert f_triplet.name == "x_star_f(x_star)_gradient_f(x_star)"
        assert g_triplet.name == "x_star_g(x_star)_gradient_g(x_star)"

        # grad f + grad g = 0 at x_star, so the two gradients are opposite.
        manager = exm.ExpressionManager(ctx)
        for triplet, expected in ((f_triplet, [0, 1]), (g_triplet, [0, -1])):
            np.testing.assert_allclose(
                manager.eval_point(triplet.gradient).vector, np.array(expected)
            )
|
106
|
+
|
107
|
+
|
108
|
+
def test_stationary_point_linear_combination():
    """A stationary point of 3f + 2g balances the weighted gradients of f and g."""
    builder = pep.PEPBuilder()
    with builder.make_context("test") as ctx:
        f = fc.Function(is_basis=True, reuse_gradient=False)
        f.add_tag("f")
        g = fc.Function(is_basis=True, reuse_gradient=False)
        g.add_tag("g")
        combo = 3 * f + 2 * g
        combo.add_tag("h")

        combo.add_stationary_point("x_star")
        assert len(ctx.triplets) == 2
        assert len(ctx.triplets[f]) == 1
        assert len(ctx.triplets[g]) == 1

        f_triplet, g_triplet = ctx.triplets[f][0], ctx.triplets[g][0]
        assert f_triplet.name == "x_star_f(x_star)_gradient_f(x_star)"
        assert g_triplet.name == "x_star_g(x_star)_gradient_g(x_star)"

        # 3 * grad f + 2 * grad g = 0, hence grad g = -1.5 * grad f.
        manager = exm.ExpressionManager(ctx)
        for triplet, expected in ((f_triplet, [0, 1]), (g_triplet, [0, -1.5])):
            np.testing.assert_allclose(
                manager.eval_point(triplet.gradient).vector, np.array(expected)
            )
|
@@ -0,0 +1,264 @@
|
|
1
|
+
# Copyright: 2025 The PEPFlow Developers
|
2
|
+
#
|
3
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
4
|
+
# or more contributor license agreements. See the NOTICE file
|
5
|
+
# distributed with this work for additional information
|
6
|
+
# regarding copyright ownership. The ASF licenses this file
|
7
|
+
# to you under the Apache License, Version 2.0 (the
|
8
|
+
# "License"); you may not use this file except in compliance
|
9
|
+
# with the License. You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing,
|
14
|
+
# software distributed under the License is distributed on an
|
15
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
16
|
+
# KIND, either express or implied. See the License for the
|
17
|
+
# specific language governing permissions and limitations
|
18
|
+
# under the License.
|
19
|
+
|
20
|
+
from __future__ import annotations
|
21
|
+
|
22
|
+
import json
|
23
|
+
from typing import TYPE_CHECKING
|
24
|
+
|
25
|
+
import dash
|
26
|
+
import dash_bootstrap_components as dbc
|
27
|
+
import numpy as np
|
28
|
+
import pandas as pd
|
29
|
+
import plotly
|
30
|
+
import plotly.express as px
|
31
|
+
import plotly.graph_objects as go
|
32
|
+
from dash import Dash, Input, Output, State, dcc, html
|
33
|
+
|
34
|
+
from pepflow.constants import PSD_CONSTRAINT
|
35
|
+
|
36
|
+
if TYPE_CHECKING:
|
37
|
+
from pepflow.pep import PEPBuilder, PEPResult
|
38
|
+
from pepflow.pep_context import PEPContext
|
39
|
+
|
40
|
+
|
41
|
+
# Module-wide plotly defaults: the white theme, and the Colab / VS Code
# renderers for displaying figures.
plotly.io.templates.default = "plotly_white"
plotly.io.renderers.default = "colab+vscode"
|
43
|
+
|
44
|
+
|
45
|
+
def solve_prob_and_get_figure(
    pep_builder: PEPBuilder, context: PEPContext
) -> tuple[go.Figure, pd.DataFrame, PEPResult]:
    """Solve the PEP held by ``context`` and build the constraint scatter figure.

    Args:
        pep_builder: Builder used to solve the problem; its
            ``relaxed_constraints`` decide which constraints are shown as
            inactive.
        context: Context whose triplets define the constraints. Only a single
            tracked function is supported for now.

    Returns:
        ``(figure, df, result)``. ``df`` is the constraint table of
        ``pep_builder.functions[0]`` from ``context.triplets_to_df_and_order``,
        extended with a ``constraint`` column ("active"/"inactive") and a
        ``dual_value`` column looked up from the solver result.

    Raises:
        AssertionError: if ``context`` does not track exactly one function.
    """
    assert len(context.triplets) == 1, "Support single function only for now"

    result = pep_builder.solve(context=context)

    df_dict, order_dict = context.triplets_to_df_and_order()
    f = pep_builder.functions[0]
    df = df_dict[f]
    order = order_dict[f]

    # Snapshot the relaxed names into a set once; membership is then tested
    # per constraint row in O(1) instead of scanning the list each time.
    relaxed = set(pep_builder.relaxed_constraints)
    df["constraint"] = df.constraint_name.map(
        lambda name: "inactive" if name in relaxed else "active"
    )
    df["dual_value"] = df.constraint_name.map(result.dual_var_manager.dual_value)
    return processed_df_to_fig(df, order), df, result
|
65
|
+
|
66
|
+
|
67
|
+
def processed_df_to_fig(df: pd.DataFrame, order: list[str]):
    """Build the constraint scatter figure from a processed constraint table.

    Each row of ``df`` becomes a marker at ``(row, col)``, colored by
    ``dual_value``. The marker is a circle when ``constraint`` is "active" and
    an open x when it is "inactive". The constraint name is attached as custom
    data so that click events can identify it. Both axes are labelled with the
    point tags in ``order``, and the y axis is reversed.
    """
    marker_symbols = {"inactive": "x-open", "active": "circle"}
    figure = px.scatter(
        df,
        x="row",
        y="col",
        color="dual_value",
        symbol="constraint",
        symbol_map=marker_symbols,
        custom_data="constraint_name",
        color_continuous_scale="Viridis",
    )
    figure.update_layout(yaxis=dict(autorange="reversed"))
    figure.update_traces(marker=dict(size=15))
    figure.update_layout(
        coloraxis_colorbar=dict(yanchor="top", y=1, x=1.3, ticks="outside")
    )
    # Replace the numeric positions on both axes with the point tags.
    positions = list(range(len(order)))
    for update_axes in (figure.update_xaxes, figure.update_yaxes):
        update_axes(tickmode="array", tickvals=positions, ticktext=order)
    return figure
|
86
|
+
|
87
|
+
|
88
|
+
def get_matrix_of_dual_value(df: pd.DataFrame) -> np.ndarray:
    """Arrange the per-constraint dual values of ``df`` into a dense matrix.

    ``df`` needs ``row``, ``col`` and ``dual_value`` columns. The values are
    pivoted with the sorted distinct ``row`` values as pivot rows and the
    sorted distinct ``col`` values as pivot columns (``pivot_table`` averages
    duplicates). Missing cells become 0, and the pivot is returned transposed.
    """
    table = pd.pivot_table(
        df, values="dual_value", index="row", columns="col", dropna=False
    )
    dense = table.fillna(0.0).to_numpy()
    # Transposed: the first axis follows ``col`` and the second ``row``.
    return dense.T
|
98
|
+
|
99
|
+
|
100
|
+
def launch(pep_builder: PEPBuilder, context: PEPContext):
    """Start a Dash app for interactively relaxing and re-solving constraints.

    The app shows the interpolation constraints of the single function
    tracked by ``context`` as a scatter "heatmap" colored by dual value.

    - Clicking a marker toggles the constraint's membership in
      ``pep_builder.relaxed_constraints``.
    - "Relax All" / "Restore All" set that list wholesale.
    - "Solve PEP Problem" re-solves and refreshes every display.

    Args:
        pep_builder: Builder used for solving. The UI edits its
            ``relaxed_constraints`` list.
        context: Context to solve and visualize. It must track exactly one
            function (see ``solve_prob_and_get_figure``).

    The function ends by calling ``app.run(debug=True)``.
    """
    app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
    fig, df, _ = solve_prob_and_get_figure(pep_builder, context)

    with np.printoptions(precision=3, linewidth=100, suppress=True):
        dual_value_tab = html.Pre(
            str(get_matrix_of_dual_value(df)),
            id="dual-value-display",
            style={
                "border": "1px solid lightgrey",
                "padding": "10px",
                "height": "60vh",
                "overflowY": "auto",
            },
        )

    display_row = dbc.Row(
        [
            # Column 1: constraint controls plus the heatmap / matrix tabs
            # (5 of 12 columns).
            dbc.Col(
                [
                    dbc.Button(
                        "Relax All Constraints",
                        id="relax-all-constraints-button",
                        style={"margin-bottom": "5px", "margin-right": "5px"},
                    ),
                    dbc.Button(
                        "Restore All Constraints",
                        id="restore-all-constraints-button",
                        style={"margin-bottom": "5px"},
                        color="success",
                    ),
                    dbc.Tabs(
                        [
                            dbc.Tab(
                                html.Div(
                                    dcc.Graph(id="interactive-scatter", figure=fig),
                                ),
                                label="Interactive Heatmap",
                                tab_id="interactive-scatter-tab",
                            ),
                            dbc.Tab(
                                dual_value_tab,
                                label="Dual Value Matrix",
                                tab_id="dual_value_tab",
                            ),
                        ],
                        active_tab="interactive-scatter-tab",
                    ),
                ],
                width=5,
            ),
            # Column 2: solve button and result summary (7 of 12 columns).
            dbc.Col(
                [
                    dbc.Button(
                        "Solve PEP Problem",
                        id="solve-button",
                        color="primary",
                        className="me-1",
                        style={"margin-bottom": "5px"},
                    ),
                    dcc.Loading(
                        dbc.Card(
                            id="result-card",
                            style={"height": "60vh", "overflow-y": "auto"},
                        )
                    ),
                ],
                width=7,
            ),
        ],
    )

    app.layout = html.Div(
        [
            html.H2("PEPFlow"),
            display_row,
            # The constraint table, one record per constraint, kept client-side
            # so the toggle callbacks can redraw without re-solving.
            dcc.Store(id="dataframe-store", data=df.to_dict("records")),
        ]
    )

    def _redraw(df_updated: pd.DataFrame):
        """Return the (figure, stored records) pair for ``df_updated``."""
        order = context.order_of_point(pep_builder.functions[0])
        return processed_df_to_fig(df_updated, order), df_updated.to_dict("records")

    # Callbacks are bound to this app instance (not the global dash.callback
    # registry) so that calling launch() again does not stack duplicates.
    @app.callback(
        Output("result-card", "children"),
        Output("dual-value-display", "children"),
        Output("interactive-scatter", "figure"),
        Output("dataframe-store", "data"),
        Input("solve-button", "n_clicks"),
    )
    def solve(_):
        fig, df, result = solve_prob_and_get_figure(pep_builder, context)
        with np.printoptions(precision=3, linewidth=100, suppress=True):
            psd_dual_value = np.array(
                result.dual_var_manager.dual_value(PSD_CONSTRAINT)
            )
            result_card = dbc.CardBody(
                [
                    html.H2(f"Optimal Value {result.primal_opt_value:.4g}"),
                    html.H3(f"Solver Status: {result.solver_status}"),
                    html.P("PSD Dual Variable:"),
                    html.Pre(str(psd_dual_value)),
                    html.P("Relaxed Constraints:"),
                    html.Pre(json.dumps(pep_builder.relaxed_constraints, indent=2)),
                ]
            )
            dual_value_display = str(get_matrix_of_dual_value(df))
        return result_card, dual_value_display, fig, df.to_dict("records")

    @app.callback(
        Output("interactive-scatter", "figure", allow_duplicate=True),
        Output("dataframe-store", "data", allow_duplicate=True),
        Input("restore-all-constraints-button", "n_clicks"),
        State("dataframe-store", "data"),
        prevent_initial_call=True,
    )
    def restore_all_constraints(_, previous_df):
        df_updated = pd.DataFrame(previous_df)
        pep_builder.relaxed_constraints = []
        df_updated["constraint"] = "active"
        return _redraw(df_updated)

    @app.callback(
        Output("interactive-scatter", "figure", allow_duplicate=True),
        Output("dataframe-store", "data", allow_duplicate=True),
        Input("relax-all-constraints-button", "n_clicks"),
        State("dataframe-store", "data"),
        prevent_initial_call=True,
    )
    def relax_all_constraints(_, previous_df):
        df_updated = pd.DataFrame(previous_df)
        pep_builder.relaxed_constraints = df_updated["constraint_name"].to_list()
        df_updated["constraint"] = "inactive"
        return _redraw(df_updated)

    @app.callback(
        Output("interactive-scatter", "figure", allow_duplicate=True),
        Output("dataframe-store", "data", allow_duplicate=True),
        Input("interactive-scatter", "clickData"),
        State("dataframe-store", "data"),
        prevent_initial_call=True,
    )
    def update_df_and_redraw(clickData, previous_df):
        points = (clickData or {}).get("points") or []
        customdata = points[0].get("customdata") if points else None
        if not customdata:
            # One no_update per Output (this callback has two).
            return dash.no_update, dash.no_update

        clicked_name = customdata[0]
        relaxed = pep_builder.relaxed_constraints
        if clicked_name in relaxed:
            # Drop every occurrence so a duplicated name cannot stay relaxed.
            pep_builder.relaxed_constraints = [
                name for name in relaxed if name != clicked_name
            ]
        else:
            relaxed.append(clicked_name)

        relaxed_set = set(pep_builder.relaxed_constraints)
        df_updated = pd.DataFrame(previous_df)
        df_updated["constraint"] = df_updated.constraint_name.map(
            lambda name: "inactive" if name in relaxed_set else "active"
        )
        return _redraw(df_updated)

    app.run(debug=True)
|
pepflow/pep.py
CHANGED
@@ -1,16 +1,40 @@
|
|
1
|
+
# Copyright: 2025 The PEPFlow Developers
|
2
|
+
#
|
3
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
4
|
+
# or more contributor license agreements. See the NOTICE file
|
5
|
+
# distributed with this work for additional information
|
6
|
+
# regarding copyright ownership. The ASF licenses this file
|
7
|
+
# to you under the Apache License, Version 2.0 (the
|
8
|
+
# "License"); you may not use this file except in compliance
|
9
|
+
# with the License. You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing,
|
14
|
+
# software distributed under the License is distributed on an
|
15
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
16
|
+
# KIND, either express or implied. See the License for the
|
17
|
+
# specific language governing permissions and limitations
|
18
|
+
# under the License.
|
19
|
+
|
1
20
|
from __future__ import annotations
|
2
21
|
|
3
22
|
import contextlib
|
4
|
-
from typing import TYPE_CHECKING, Any
|
23
|
+
from typing import TYPE_CHECKING, Any, Iterator
|
5
24
|
|
6
25
|
import attrs
|
26
|
+
import numpy as np
|
27
|
+
import pandas as pd
|
7
28
|
|
8
29
|
from pepflow import pep_context as pc
|
9
30
|
from pepflow import point as pt
|
10
31
|
from pepflow import scalar as sc
|
11
32
|
from pepflow import solver as ps
|
33
|
+
from pepflow.constants import PSD_CONSTRAINT
|
12
34
|
|
13
35
|
if TYPE_CHECKING:
|
36
|
+
from pepflow.constraint import Constraint
|
37
|
+
from pepflow.function import Function
|
14
38
|
from pepflow.solver import DualVariableManager
|
15
39
|
|
16
40
|
|
@@ -19,6 +43,33 @@ class PEPResult:
|
|
19
43
|
primal_opt_value: float
|
20
44
|
dual_var_manager: DualVariableManager
|
21
45
|
solver_status: Any
|
46
|
+
context: pc.PEPContext
|
47
|
+
|
48
|
+
def get_function_dual_variables(self) -> dict[Function, np.ndarray]:
    """Return each tracked function's matrix of constraint dual values.

    For every function in ``self.context.triplets_to_df_and_order()``, each
    constraint's dual value is looked up through ``self.dual_var_manager`` and
    stored in the table's ``dual_value`` column. The column is then pivoted
    (rows from ``row``, columns from ``col``, empty cells filled with 0) and
    the pivot is returned transposed.
    """
    frames, _ = self.context.triplets_to_df_and_order()
    lookup_dual = self.dual_var_manager.dual_value
    matrices: dict[Function, np.ndarray] = {}
    for func, frame in frames.items():
        frame["dual_value"] = frame.constraint_name.map(lookup_dual)
        pivot = pd.pivot_table(
            frame, values="dual_value", index="row", columns="col", dropna=False
        )
        matrices[func] = pivot.fillna(0.0).to_numpy().T
    return matrices
|
70
|
+
|
71
|
+
def get_psd_dual_matrix(self):
    """Return the dual value of the PSD constraint as a NumPy array."""
    psd_dual = self.dual_var_manager.dual_value(PSD_CONSTRAINT)
    return np.array(psd_dual)
|
22
73
|
|
23
74
|
|
24
75
|
class PEPBuilder:
|
@@ -37,7 +88,9 @@ class PEPBuilder:
|
|
37
88
|
self.relaxed_constraints = []
|
38
89
|
|
39
90
|
@contextlib.contextmanager
|
40
|
-
def make_context(
|
91
|
+
def make_context(
|
92
|
+
self, name: str, override: bool = False
|
93
|
+
) -> Iterator[pc.PEPContext]:
|
41
94
|
if not override and name in self.pep_context_dict:
|
42
95
|
raise KeyError(f"There is already a context {name} in the builder")
|
43
96
|
try:
|
@@ -63,7 +116,7 @@ class PEPBuilder:
|
|
63
116
|
def clear_all_context(self) -> None:
    """Forget every context registered in this builder's ``pep_context_dict``."""
    contexts = self.pep_context_dict
    contexts.clear()
|
65
118
|
|
66
|
-
def set_init_point(self, tag: str) -> pt.Point:
    """Create a new basis point tagged ``tag`` and return it."""
    init_point = pt.Point(is_basis=True)
    init_point.add_tag(tag)
    return init_point
|
@@ -74,6 +127,9 @@ class PEPBuilder:
|
|
74
127
|
def set_performance_metric(self, metric: sc.Scalar):
    """Record ``metric`` as this builder's performance metric.

    The builder keeps the scalar as given (no copy) in ``self.performance_metric``.
    """
    self.performance_metric = metric
|
76
129
|
|
130
|
+
def set_relaxed_constraints(self, relaxed_constraints: list[str]):
    """Add constraint names to ``self.relaxed_constraints``.

    Despite the name, existing entries are kept: the list is extended, not
    replaced. A name already present is skipped, so each name appears at most
    once. A duplicate would survive a single ``list.remove`` and leave the
    constraint relaxed after a caller "un-relaxes" it.
    """
    for name in relaxed_constraints:
        if name not in self.relaxed_constraints:
            self.relaxed_constraints.append(name)
|
132
|
+
|
77
133
|
def declare_func(self, function_class, **kwargs):
|
78
134
|
func = function_class(is_basis=True, composition=None, **kwargs)
|
79
135
|
self.functions.append(func)
|
@@ -85,11 +141,10 @@ class PEPBuilder:
|
|
85
141
|
if context is None:
|
86
142
|
raise RuntimeError("Did you forget to create a context?")
|
87
143
|
|
88
|
-
all_constraints = [*self.init_conditions]
|
144
|
+
all_constraints: list[Constraint] = [*self.init_conditions]
|
89
145
|
for f in self.functions:
|
90
|
-
|
91
|
-
|
92
|
-
all_constraints.extend(f.constraints)
|
146
|
+
all_constraints.extend(f.get_interpolation_constraints())
|
147
|
+
all_constraints.extend(context.opt_conditions[f])
|
93
148
|
|
94
149
|
# for now, we heavily rely on the CVX. We can make a wrapper class to avoid
|
95
150
|
# direct dependency in the future.
|
@@ -106,4 +161,5 @@ class PEPBuilder:
|
|
106
161
|
primal_opt_value=result,
|
107
162
|
dual_var_manager=solver.dual_var_manager,
|
108
163
|
solver_status=problem.status,
|
164
|
+
context=context,
|
109
165
|
)
|
pepflow/pep_context.py
CHANGED
@@ -1,5 +1,36 @@
|
|
1
|
+
# Copyright: 2025 The PEPFlow Developers
|
2
|
+
#
|
3
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
4
|
+
# or more contributor license agreements. See the NOTICE file
|
5
|
+
# distributed with this work for additional information
|
6
|
+
# regarding copyright ownership. The ASF licenses this file
|
7
|
+
# to you under the Apache License, Version 2.0 (the
|
8
|
+
# "License"); you may not use this file except in compliance
|
9
|
+
# with the License. You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing,
|
14
|
+
# software distributed under the License is distributed on an
|
15
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
16
|
+
# KIND, either express or implied. See the License for the
|
17
|
+
# specific language governing permissions and limitations
|
18
|
+
# under the License.
|
19
|
+
|
1
20
|
from __future__ import annotations
|
2
21
|
|
22
|
+
from collections import defaultdict
|
23
|
+
from typing import TYPE_CHECKING
|
24
|
+
|
25
|
+
import natsort
|
26
|
+
import pandas as pd
|
27
|
+
|
28
|
+
if TYPE_CHECKING:
|
29
|
+
from pepflow.constraint import Constraint
|
30
|
+
from pepflow.function import Function, Triplet
|
31
|
+
from pepflow.point import Point
|
32
|
+
from pepflow.scalar import Scalar
|
33
|
+
|
3
34
|
# A global variable for storing the current context that is used for points or scalars.
|
4
35
|
CURRENT_CONTEXT: PEPContext | None = None
|
5
36
|
|
@@ -16,15 +47,85 @@ def set_current_context(ctx: PEPContext | None):
|
|
16
47
|
|
17
48
|
class PEPContext:
    """Registry of the objects created while building one PEP.

    Points and scalars are kept in registration order. Triplets and
    optimality conditions are grouped by the function they belong to.
    """

    def __init__(self):
        self.points: list[Point] = []
        self.scalars: list[Scalar] = []
        self.triplets: dict[Function, list[Triplet]] = defaultdict(list)
        self.opt_conditions: dict[Function, list[Constraint]] = defaultdict(list)

    def set_as_current(self) -> PEPContext:
        """Pass this context to ``set_current_context`` and return it."""
        set_current_context(self)
        return self

    def add_point(self, point: Point):
        """Register ``point`` with this context."""
        self.points.append(point)

    def add_scalar(self, scalar: Scalar):
        """Register ``scalar`` with this context."""
        self.scalars.append(scalar)

    def add_triplet(self, function: Function, triplet: Triplet):
        """Record ``triplet`` under ``function``."""
        self.triplets[function].append(triplet)

    def add_opt_condition(self, function: Function, opt_condition: Constraint):
        """Record ``opt_condition`` as an optimality condition of ``function``."""
        self.opt_conditions[function].append(opt_condition)

    def get_by_tag(self, tag: str) -> Point | Scalar:
        """Return the first point, or failing that the first scalar, tagged ``tag``.

        Raises:
            ValueError: if no registered point or scalar carries ``tag``.
        """
        for point in self.points:
            if tag in point.tags:
                return point
        for scalar in self.scalars:
            if tag in scalar.tags:
                return scalar
        raise ValueError(f"Cannot find the point or scalar of given tag: {tag!r}")

    def clear(self):
        """Forget every point, scalar, triplet and optimality condition."""
        self.points.clear()
        self.scalars.clear()
        self.triplets.clear()
        self.opt_conditions.clear()

    def _triplets_of(self, func: Function) -> list[Triplet]:
        """Return the triplets recorded for ``func`` without registering ``func``.

        ``self.triplets`` is a ``defaultdict``, so subscripting it with an
        untracked function would silently insert an empty entry. That changes
        ``len(self.triplets)`` and the functions iterated by
        ``triplets_to_df_and_order``. ``dict.get`` has no such side effect.
        """
        return self.triplets.get(func, [])

    def tracked_point(self, func: Function) -> list[Point]:
        """Return the points of ``func``'s triplets, naturally sorted by tag."""
        return natsort.natsorted(
            [t.point for t in self._triplets_of(func)], key=lambda x: x.tag
        )

    def tracked_grad(self, func: Function) -> list[Point]:
        """Return the gradients of ``func``'s triplets, naturally sorted by tag."""
        return natsort.natsorted(
            [t.gradient for t in self._triplets_of(func)], key=lambda x: x.tag
        )

    def tracked_func_value(self, func: Function) -> list[Scalar]:
        """Return the function values of ``func``'s triplets, naturally sorted by tag."""
        return natsort.natsorted(
            [t.function_value for t in self._triplets_of(func)], key=lambda x: x.tag
        )

    def order_of_point(self, func: Function) -> list[str]:
        """Return the point tags of ``func``'s triplets in natural sort order."""
        return natsort.natsorted([t.point.tag for t in self._triplets_of(func)])

    def triplets_to_df_and_order(
        self,
    ) -> tuple[dict[Function, pd.DataFrame], dict[Function, list[str]]]:
        """Tabulate the interpolation constraints of every tracked function.

        Returns:
            ``(func_to_df, func_to_order)``. ``func_to_order[func]`` is
            ``self.order_of_point(func)``. ``func_to_df[func]`` has one row per
            constraint from ``func.get_interpolation_constraints(self)``, with
            these columns:

            - ``constraint_name``;
            - ``col_point`` and ``row_point``: the two point tags parsed from
              the name;
            - ``row`` and ``col``: the positions of those tags in the order.
        """
        func_to_df: dict[Function, pd.DataFrame] = {}
        func_to_order: dict[Function, list[str]] = {}

        def name_to_point_tuple(c_name: str) -> list[str]:
            # Expects names shaped like "<prefix>:<col_point>,<row_point>".
            _, points = c_name.split(":")
            return points.split(",")

        for func in self.triplets:
            order = self.order_of_point(func)
            df = pd.DataFrame(
                [
                    (
                        constraint.name,
                        *name_to_point_tuple(constraint.name),
                    )
                    for constraint in func.get_interpolation_constraints(self)
                ],
                columns=["constraint_name", "col_point", "row_point"],
            )
            df["row"] = df["row_point"].map(lambda x: order.index(x))
            df["col"] = df["col_point"].map(lambda x: order.index(x))
            func_to_df[func] = df
            func_to_order[func] = order

        return func_to_df, func_to_order
|