pepflow 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,264 @@
1
+ # Copyright: 2025 The PEPFlow Developers
2
+ #
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ from typing import TYPE_CHECKING
24
+
25
+ import dash
26
+ import dash_bootstrap_components as dbc
27
+ import numpy as np
28
+ import pandas as pd
29
+ import plotly
30
+ import plotly.express as px
31
+ import plotly.graph_objects as go
32
+ from dash import Dash, Input, Output, State, dcc, html
33
+
34
+ from pepflow.constants import PSD_CONSTRAINT
35
+
36
+ if TYPE_CHECKING:
37
+ from pepflow.pep import PEPBuilder, PEPResult
38
+ from pepflow.pep_context import PEPContext
39
+
40
+
41
# Import-time plotly configuration: figures use the clean white template and
# are rendered for Colab/VS Code front-ends. NOTE(review): this overrides the
# user's global plotly defaults as a side effect of importing this module.
plotly.io.templates.default = "plotly_white"
plotly.io.renderers.default = "colab+vscode"
43
+
44
+
45
def solve_prob_and_get_figure(
    pep_builder: PEPBuilder, context: PEPContext
) -> tuple[go.Figure, pd.DataFrame, PEPResult]:
    """Solve the PEP problem and build the constraint scatter plot.

    Only contexts whose ``triplets`` mapping has exactly one function are
    supported. The per-constraint DataFrame of ``pep_builder.functions[0]``
    gains two columns:

    * ``constraint``: ``"inactive"`` if the name is in
      ``pep_builder.relaxed_constraints``, otherwise ``"active"``;
    * ``dual_value``: the dual value reported by the solve result.

    Returns:
        ``(figure, dataframe, result)``.

    Raises:
        AssertionError: if ``context.triplets`` does not have exactly one entry
            (note: ``assert`` is stripped under ``python -O``).
    """
    assert len(context.triplets) == 1, "Support single function only for now"

    result = pep_builder.solve(context=context)

    frames, orders = context.triplets_to_df_and_order()
    func = pep_builder.functions[0]
    frame = frames[func]
    point_order = orders[func]

    names = frame.constraint_name
    frame["constraint"] = names.map(
        lambda name: "active"
        if name not in pep_builder.relaxed_constraints
        else "inactive"
    )
    frame["dual_value"] = names.map(result.dual_var_manager.dual_value)

    figure = processed_df_to_fig(frame, point_order)
    return figure, frame, result
65
+
66
+
67
def processed_df_to_fig(df: pd.DataFrame, order: list[str]):
    """Draw one marker per constraint row of ``df``.

    Markers sit at (``row``, ``col``), are colored by ``dual_value`` on a
    Viridis scale, and use an open cross for ``"inactive"`` constraints and a
    circle for ``"active"`` ones. The constraint name is attached as custom
    data so click events can identify it. Both axes are labeled with the point
    tags in ``order``, and the y axis is reversed so index 0 is at the top.
    """
    figure = px.scatter(
        df,
        x="row",
        y="col",
        color="dual_value",
        symbol="constraint",
        symbol_map={"inactive": "x-open", "active": "circle"},
        custom_data="constraint_name",
        color_continuous_scale="Viridis",
    )
    figure.update_traces(marker=dict(size=15))
    figure.update_layout(
        yaxis=dict(autorange="reversed"),
        coloraxis_colorbar=dict(yanchor="top", y=1, x=1.3, ticks="outside"),
    )

    tick_settings = dict(
        tickmode="array", tickvals=list(range(len(order))), ticktext=order
    )
    figure.update_xaxes(**tick_settings)
    figure.update_yaxes(**tick_settings)
    return figure
86
+
87
+
88
def get_matrix_of_dual_value(df: pd.DataFrame) -> np.ndarray:
    """Lay the ``dual_value`` column of ``df`` out as a dense matrix.

    Values are pivoted with ``row`` as the index and ``col`` as the columns
    (duplicates are averaged, pandas' default ``aggfunc``). Missing cells
    become 0.0. The result is transposed, so entry ``[i, j]`` holds the value
    at ``col == i`` and ``row == j``.
    """
    table = df.pivot_table(
        values="dual_value", index="row", columns="col", dropna=False
    )
    dense = table.fillna(0.0).to_numpy()
    return dense.T
98
+
99
+
100
def launch(pep_builder: PEPBuilder, context: PEPContext, debug: bool = True):
    """Solve the PEP problem and serve an interactive Dash app to explore it.

    The app plots the interpolation constraints of the single function tracked
    by ``context`` (see ``solve_prob_and_get_figure``), colored by dual value.
    The controls work as follows:

    * Clicking a marker toggles whether that constraint's name is in
      ``pep_builder.relaxed_constraints``.
    * The two buttons relax or restore all constraints at once.
    * "Solve PEP Problem" re-solves and refreshes the result card, the
      dual-value matrix and the plot.

    Args:
        pep_builder: Builder to solve with. Its ``relaxed_constraints`` list is
            modified by the UI callbacks.
        context: Context to solve in. Its ``triplets`` must have exactly one
            function.
        debug: Forwarded to ``Dash.run``. Defaults to ``True``, the value that
            was previously hard-coded.
    """
    app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
    fig, df, _ = solve_prob_and_get_figure(pep_builder, context)
    app.layout = _build_layout(fig, df)
    _register_callbacks(app, pep_builder, context)
    app.run(debug=debug)


def _format_dual_matrix(df: pd.DataFrame) -> str:
    """Render ``get_matrix_of_dual_value(df)`` as text with compact numpy printing."""
    with np.printoptions(precision=3, linewidth=100, suppress=True):
        return str(get_matrix_of_dual_value(df))


def _build_layout(fig: go.Figure, df: pd.DataFrame) -> html.Div:
    """Build the page: plot/matrix tabs on the left, solve results on the right."""
    dual_value_tab = html.Pre(
        _format_dual_matrix(df),
        id="dual-value-display",
        style={
            "border": "1px solid lightgrey",
            "padding": "10px",
            "height": "60vh",
            "overflowY": "auto",
        },
    )
    # Column 1 (5 of 12 grid columns): bulk relax/restore buttons and the tabs.
    controls_col = dbc.Col(
        [
            dbc.Button(
                "Relax All Constraints",
                id="relax-all-constraints-button",
                style={"margin-bottom": "5px", "margin-right": "5px"},
            ),
            dbc.Button(
                "Restore All Constraints",
                id="restore-all-constraints-button",
                style={"margin-bottom": "5px"},
                color="success",
            ),
            dbc.Tabs(
                [
                    dbc.Tab(
                        html.Div(
                            dcc.Graph(id="interactive-scatter", figure=fig),
                        ),
                        label="Interactive Heatmap",
                        tab_id="interactive-scatter-tab",
                    ),
                    dbc.Tab(
                        dual_value_tab,
                        label="Dual Value Matrix",
                        tab_id="dual_value_tab",
                    ),
                ],
                active_tab="interactive-scatter-tab",
            ),
        ],
        width=5,
    )
    # Column 2 (7 of 12 grid columns): the solve button and the result card.
    results_col = dbc.Col(
        [
            dbc.Button(
                "Solve PEP Problem",
                id="solve-button",
                color="primary",
                className="me-1",
                style={"margin-bottom": "5px"},
            ),
            dcc.Loading(
                dbc.Card(
                    id="result-card",
                    style={"height": "60vh", "overflow-y": "auto"},
                )
            ),
        ],
        width=7,
    )
    return html.Div(
        [
            html.H2("PEPFlow"),
            dbc.Row([controls_col, results_col]),
            # The constraint table is kept in the browser so every callback
            # can rebuild the figure from the latest state.
            dcc.Store(id="dataframe-store", data=df.to_dict("records")),
        ]
    )


def _register_callbacks(
    app: Dash, pep_builder: PEPBuilder, context: PEPContext
) -> None:
    """Attach the UI callbacks to ``app`` (not to Dash's global registry)."""

    def redraw(df_updated: pd.DataFrame):
        # Shared tail of the relax/restore/toggle callbacks: figure + store payload.
        order = context.order_of_point(pep_builder.functions[0])
        return processed_df_to_fig(df_updated, order), df_updated.to_dict("records")

    @app.callback(
        Output("result-card", "children"),
        Output("dual-value-display", "children"),
        Output("interactive-scatter", "figure"),
        Output("dataframe-store", "data"),
        Input("solve-button", "n_clicks"),
    )
    def solve(_):
        fig, df, result = solve_prob_and_get_figure(pep_builder, context)
        with np.printoptions(precision=3, linewidth=100, suppress=True):
            psd_dual_text = str(
                np.array(result.dual_var_manager.dual_value(PSD_CONSTRAINT))
            )
        result_card = dbc.CardBody(
            [
                html.H2(f"Optimal Value {result.primal_opt_value:.4g}"),
                html.H3(f"Solver Status: {result.solver_status}"),
                html.P("PSD Dual Variable:"),
                html.Pre(psd_dual_text),
                html.P("Relaxed Constraints:"),
                html.Pre(json.dumps(pep_builder.relaxed_constraints, indent=2)),
            ]
        )
        return result_card, _format_dual_matrix(df), fig, df.to_dict("records")

    @app.callback(
        Output("interactive-scatter", "figure", allow_duplicate=True),
        Output("dataframe-store", "data", allow_duplicate=True),
        Input("restore-all-constraints-button", "n_clicks"),
        State("dataframe-store", "data"),
        prevent_initial_call=True,
    )
    def restore_all_constraints(_, previous_df):
        df_updated = pd.DataFrame(previous_df)
        pep_builder.relaxed_constraints = []
        df_updated["constraint"] = "active"
        return redraw(df_updated)

    @app.callback(
        Output("interactive-scatter", "figure", allow_duplicate=True),
        Output("dataframe-store", "data", allow_duplicate=True),
        Input("relax-all-constraints-button", "n_clicks"),
        State("dataframe-store", "data"),
        prevent_initial_call=True,
    )
    def relax_all_constraints(_, previous_df):
        df_updated = pd.DataFrame(previous_df)
        pep_builder.relaxed_constraints = df_updated["constraint_name"].to_list()
        df_updated["constraint"] = "inactive"
        return redraw(df_updated)

    @app.callback(
        Output("interactive-scatter", "figure", allow_duplicate=True),
        Output("dataframe-store", "data", allow_duplicate=True),
        Input("interactive-scatter", "clickData"),
        State("dataframe-store", "data"),
        prevent_initial_call=True,
    )
    def update_df_and_redraw(click_data, previous_df):
        points = (click_data or {}).get("points") or []
        custom_data = points[0].get("customdata") if points else None
        if not custom_data:
            # Exactly one no_update per declared Output (there are two).
            return dash.no_update, dash.no_update

        clicked_name = custom_data[0]
        relaxed = pep_builder.relaxed_constraints
        if clicked_name in relaxed:
            # Drop every occurrence, in place, so a duplicated name is fully restored.
            relaxed[:] = [name for name in relaxed if name != clicked_name]
        else:
            relaxed.append(clicked_name)

        df_updated = pd.DataFrame(previous_df)
        df_updated["constraint"] = df_updated.constraint_name.map(
            lambda x: "inactive" if x in relaxed else "active"
        )
        return redraw(df_updated)
pepflow/pep.py CHANGED
@@ -1,16 +1,40 @@
1
+ # Copyright: 2025 The PEPFlow Developers
2
+ #
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
1
20
  from __future__ import annotations
2
21
 
3
22
  import contextlib
4
- from typing import TYPE_CHECKING, Any
23
+ from typing import TYPE_CHECKING, Any, Iterator
5
24
 
6
25
  import attrs
26
+ import numpy as np
27
+ import pandas as pd
7
28
 
8
29
  from pepflow import pep_context as pc
9
30
  from pepflow import point as pt
10
31
  from pepflow import scalar as sc
11
32
  from pepflow import solver as ps
33
+ from pepflow.constants import PSD_CONSTRAINT
12
34
 
13
35
  if TYPE_CHECKING:
36
+ from pepflow.constraint import Constraint
37
+ from pepflow.function import Function
14
38
  from pepflow.solver import DualVariableManager
15
39
 
16
40
 
@@ -19,6 +43,33 @@ class PEPResult:
19
43
  primal_opt_value: float
20
44
  dual_var_manager: DualVariableManager
21
45
  solver_status: Any
46
+ context: pc.PEPContext
47
+
48
def get_function_dual_variables(self) -> dict[Function, np.ndarray]:
    """Return one dual-value matrix per function tracked by ``self.context``.

    For every function's constraint table (from
    ``self.context.triplets_to_df_and_order()``), the dual value of each
    constraint is looked up via ``self.dual_var_manager.dual_value``. The
    values are pivoted with ``row`` as index and ``col`` as columns, missing
    cells are filled with 0.0, and the matrix is returned transposed.
    """
    frames, _ = self.context.triplets_to_df_and_order()
    lookup = self.dual_var_manager.dual_value

    matrices = {}
    for func, frame in frames.items():
        frame["dual_value"] = frame.constraint_name.map(lookup)
        table = frame.pivot_table(
            values="dual_value", index="row", columns="col", dropna=False
        )
        matrices[func] = table.fillna(0.0).to_numpy().T
    return matrices
70
+
71
def get_psd_dual_matrix(self) -> np.ndarray:
    """Return the dual value of the PSD constraint as a numpy array."""
    psd_dual = self.dual_var_manager.dual_value(PSD_CONSTRAINT)
    return np.array(psd_dual)
22
73
 
23
74
 
24
75
  class PEPBuilder:
@@ -37,7 +88,9 @@ class PEPBuilder:
37
88
  self.relaxed_constraints = []
38
89
 
39
90
  @contextlib.contextmanager
40
- def make_context(self, name: str, override: bool = False) -> pc.PEPContext:
91
+ def make_context(
92
+ self, name: str, override: bool = False
93
+ ) -> Iterator[pc.PEPContext]:
41
94
  if not override and name in self.pep_context_dict:
42
95
  raise KeyError(f"There is already a context {name} in the builder")
43
96
  try:
@@ -63,7 +116,7 @@ class PEPBuilder:
63
116
def clear_all_context(self) -> None:
    """Forget every context registered on this builder (the dict is emptied in place)."""
    contexts = self.pep_context_dict
    contexts.clear()
65
118
 
66
- def set_init_point(self, tag: str | None = None) -> pt.Point:
119
def set_init_point(self, tag: str) -> pt.Point:
    """Create a fresh basis point, tag it with ``tag`` and return it."""
    init_point = pt.Point(is_basis=True)
    init_point.add_tag(tag)
    return init_point
@@ -74,6 +127,9 @@ class PEPBuilder:
74
127
  def set_performance_metric(self, metric: sc.Scalar):
75
128
  self.performance_metric = metric
76
129
 
130
+ def set_relaxed_constraints(self, relaxed_constraints: list[str]):
131
+ self.relaxed_constraints.extend(relaxed_constraints)
132
+
77
133
  def declare_func(self, function_class, **kwargs):
78
134
  func = function_class(is_basis=True, composition=None, **kwargs)
79
135
  self.functions.append(func)
@@ -85,11 +141,10 @@ class PEPBuilder:
85
141
  if context is None:
86
142
  raise RuntimeError("Did you forget to create a context?")
87
143
 
88
- all_constraints = [*self.init_conditions]
144
+ all_constraints: list[Constraint] = [*self.init_conditions]
89
145
  for f in self.functions:
90
- if f.is_basis:
91
- all_constraints.extend(f.get_interpolation_constraints())
92
- all_constraints.extend(f.constraints)
146
+ all_constraints.extend(f.get_interpolation_constraints())
147
+ all_constraints.extend(context.opt_conditions[f])
93
148
 
94
149
  # for now, we heavily rely on the CVX. We can make a wrapper class to avoid
95
150
  # direct dependency in the future.
@@ -106,4 +161,5 @@ class PEPBuilder:
106
161
  primal_opt_value=result,
107
162
  dual_var_manager=solver.dual_var_manager,
108
163
  solver_status=problem.status,
164
+ context=context,
109
165
  )
pepflow/pep_context.py CHANGED
@@ -1,5 +1,36 @@
1
+ # Copyright: 2025 The PEPFlow Developers
2
+ #
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
1
20
  from __future__ import annotations
2
21
 
22
+ from collections import defaultdict
23
+ from typing import TYPE_CHECKING
24
+
25
+ import natsort
26
+ import pandas as pd
27
+
28
+ if TYPE_CHECKING:
29
+ from pepflow.constraint import Constraint
30
+ from pepflow.function import Function, Triplet
31
+ from pepflow.point import Point
32
+ from pepflow.scalar import Scalar
33
+
3
34
# Module-wide "current" PEPContext consulted for points and scalars (see
# set_current_context); None means no context is active.
CURRENT_CONTEXT: PEPContext | None
CURRENT_CONTEXT = None
5
36
 
@@ -16,15 +47,81 @@ def set_current_context(ctx: PEPContext | None):
16
47
 
17
48
class PEPContext:
    """Bookkeeping for one PEP problem.

    Holds the points and scalars added via ``add_point``/``add_scalar``. Per
    function, it also holds the triplets and optimality-condition constraints
    added via ``add_triplet``/``add_opt_condition``.
    """

    def __init__(self):
        self.points: list[Point] = []
        self.scalars: list[Scalar] = []
        # defaultdicts so the add_* methods can append for a new function directly.
        self.triplets: dict[Function, list[Triplet]] = defaultdict(list)
        self.opt_conditions: dict[Function, list[Constraint]] = defaultdict(list)

    def add_point(self, point: Point):
        """Register ``point`` with this context."""
        self.points.append(point)

    def add_scalar(self, scalar: Scalar):
        """Register ``scalar`` with this context."""
        self.scalars.append(scalar)

    def add_triplet(self, function: Function, triplet: Triplet):
        """Record ``triplet`` as belonging to ``function``."""
        self.triplets[function].append(triplet)

    def add_opt_condition(self, function: Function, opt_condition: Constraint):
        """Record an optimality-condition constraint for ``function``."""
        self.opt_conditions[function].append(opt_condition)

    def get_by_tag(self, tag: str) -> Point | Scalar:
        """Return the first registered point, then scalar, whose ``tag`` equals ``tag``.

        Objects whose ``tag`` raises ``ValueError`` are skipped instead of
        aborting the whole search. ``Point.tag`` raises this for a point with
        no tags.

        Raises:
            ValueError: if no registered point or scalar has the given tag.
        """
        for candidate in (*self.points, *self.scalars):
            try:
                candidate_tag = candidate.tag
            except ValueError:
                continue
            if candidate_tag == tag:
                return candidate
        raise ValueError("Cannot find the point or scalar of given tag")

    def clear(self):
        """Remove everything registered, emptying the existing containers in place."""
        self.points.clear()
        self.scalars.clear()
        self.triplets.clear()
        self.opt_conditions.clear()

    def _triplets_of(self, func: Function) -> list[Triplet]:
        # Read-only lookup. Indexing the defaultdict would insert an empty entry
        # for an unknown function and make it appear in ``self.triplets``.
        return self.triplets.get(func, [])

    def tracked_point(self, func: Function) -> list[Point]:
        """Points of ``func``'s triplets, naturally sorted by tag."""
        return natsort.natsorted(
            [t.point for t in self._triplets_of(func)], key=lambda x: x.tag
        )

    def tracked_grad(self, func: Function) -> list[Point]:
        """Gradients of ``func``'s triplets, naturally sorted by tag."""
        return natsort.natsorted(
            [t.gradient for t in self._triplets_of(func)], key=lambda x: x.tag
        )

    def tracked_func_value(self, func: Function) -> list[Scalar]:
        """Function values of ``func``'s triplets, naturally sorted by tag."""
        return natsort.natsorted(
            [t.function_value for t in self._triplets_of(func)], key=lambda x: x.tag
        )

    def order_of_point(self, func: Function) -> list[str]:
        """Tags of the points in ``func``'s triplets, naturally sorted."""
        return natsort.natsorted([t.point.tag for t in self._triplets_of(func)])

    def triplets_to_df_and_order(
        self,
    ) -> tuple[dict[Function, pd.DataFrame], dict[Function, list[str]]]:
        """Tabulate each function's interpolation constraints by point position.

        Returns:
            Two dicts keyed by function:

            * A DataFrame with one row per constraint from
              ``func.get_interpolation_constraints(self)`` and columns
              ``constraint_name``, ``col_point``, ``row_point``, ``row``,
              ``col``.
            * The ``order_of_point`` tag list whose indices ``row`` and
              ``col`` refer to.
        """
        func_to_df: dict[Function, pd.DataFrame] = {}
        func_to_order: dict[Function, list[str]] = {}

        def name_to_point_tuple(c_name: str) -> list[str]:
            # Names look like "f:x1,x3": function tag, then two point tags.
            # The first point tag fills col_point, the second row_point.
            _, points = c_name.split(":")
            return points.split(",")

        # Iterate over a snapshot of the keys in case a new function gets
        # registered while constraints are generated.
        for func in list(self.triplets):
            order = self.order_of_point(func)
            df = pd.DataFrame(
                [
                    (
                        constraint.name,
                        *name_to_point_tuple(constraint.name),
                    )
                    for constraint in func.get_interpolation_constraints(self)
                ],
                columns=["constraint_name", "col_point", "row_point"],
            )
            df["row"] = df["row_point"].map(lambda x: order.index(x))
            df["col"] = df["col_point"].map(lambda x: order.index(x))
            func_to_df[func] = df
            func_to_order[func] = order

        return func_to_df, func_to_order
@@ -0,0 +1,102 @@
1
+ # Copyright: 2025 The PEPFlow Developers
2
+ #
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
20
+ import pandas as pd
21
+
22
+ from pepflow import pep_context as pc
23
+ from pepflow.function import SmoothConvexFunction
24
+ from pepflow.point import Point
25
+
26
+
27
def test_tracked_points():
    """Tags and points of a function's triplets come back naturally sorted."""
    ctx = pc.PEPContext()
    pc.set_current_context(ctx)
    try:
        f = SmoothConvexFunction(L=1, is_basis=True)
        f.add_tag("f")

        p1 = Point(is_basis=True, tags=["x_1"])
        p2 = Point(is_basis=True, tags=["x_3"])
        p3 = Point(is_basis=True, tags=["x_2"])
        p_star = Point(is_basis=True, tags=["x_*"])

        for p in (p1, p2, p3, p_star):
            _ = f.generate_triplet(p)

        assert ctx.order_of_point(f) == ["x_1", "x_2", "x_3", "x_*"]
        tracked = ctx.tracked_point(f)
        expected = [p1, p3, p2, p_star]
        # Identity checks, so the assertion does not depend on Point.__eq__.
        assert len(tracked) == len(expected)
        assert all(a is b for a, b in zip(tracked, expected))
    finally:
        # Always reset the global context so a failure cannot leak into other tests.
        pc.set_current_context(None)
48
+
49
+
50
def test_triplets_to_dataframe():
    """Constraint names are split into point tags and mapped to sorted positions."""
    ctx = pc.PEPContext()
    pc.set_current_context(ctx)
    try:
        f = SmoothConvexFunction(L=1, is_basis=True)
        f.add_tag("f")

        p1 = Point(is_basis=True, tags=["x1"])
        p2 = Point(is_basis=True, tags=["x3"])
        p3 = Point(is_basis=True, tags=["x2"])

        for p in (p1, p2, p3):
            _ = f.generate_triplet(p)

        func_to_df, func_to_order = ctx.triplets_to_df_and_order()
        expected_df = pd.DataFrame(
            {
                "constraint_name": [
                    "f:x1,x3",
                    "f:x1,x2",
                    "f:x3,x1",
                    "f:x3,x2",
                    "f:x2,x1",
                    "f:x2,x3",
                ],
                "col_point": ["x1", "x1", "x3", "x3", "x2", "x2"],
                "row_point": ["x3", "x2", "x1", "x2", "x1", "x3"],
                "row": [2, 1, 0, 1, 0, 2],
                "col": [0, 0, 2, 2, 1, 1],
            }
        )

        pd.testing.assert_frame_equal(func_to_df[f], expected_df)
        assert func_to_order[f] == ["x1", "x2", "x3"]
    finally:
        # Always reset the global context so a failure cannot leak into other tests.
        pc.set_current_context(None)
87
+
88
+
89
def test_get_by_tag():
    """A triplet's point, function value and gradient are retrievable by tag."""
    ctx = pc.PEPContext()
    pc.set_current_context(ctx)
    try:
        f = SmoothConvexFunction(L=1, is_basis=True)
        f.add_tag("f")
        p1 = Point(is_basis=True, tags=["x1"])

        triplet = f.generate_triplet(p1)

        # Identity checks: get_by_tag must return the registered objects themselves,
        # and this avoids relying on any overloaded __eq__.
        assert ctx.get_by_tag("x1") is p1
        assert ctx.get_by_tag("f(x1)") is triplet.function_value
        assert ctx.get_by_tag("gradient_f(x1)") is triplet.gradient
    finally:
        # Always reset the global context so a failure cannot leak into other tests.
        pc.set_current_context(None)
pepflow/pep_test.py CHANGED
@@ -1,3 +1,22 @@
1
+ # Copyright: 2025 The PEPFlow Developers
2
+ #
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
1
20
  import pytest
2
21
 
3
22
  from pepflow import pep
pepflow/point.py CHANGED
@@ -1,3 +1,22 @@
1
+ # Copyright: 2025 The PEPFlow Developers
2
+ #
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
1
20
  from __future__ import annotations
2
21
 
3
22
  import uuid
@@ -28,7 +47,7 @@ class EvalExpressionPoint:
28
47
 
29
48
  @attrs.frozen
30
49
  class EvaluatedPoint:
31
- vector: np.array
50
+ vector: np.ndarray
32
51
 
33
52
  def __add__(self, other):
34
53
  if isinstance(other, EvaluatedPoint):
@@ -101,9 +120,20 @@ class Point:
101
120
  raise RuntimeError("Did you forget to create a context?")
102
121
  pep_context.add_point(self)
103
122
 
123
@property
def tag(self):
    """The most recently added tag of this point.

    Raises:
        ValueError: if the point has no tags yet.
    """
    if not self.tags:
        raise ValueError("Point should have a name.")
    return self.tags[-1]
128
+
104
129
def add_tag(self, tag: str) -> None:
    """Append ``tag`` to this point's tags; the newest tag is the one ``tag`` reports."""
    existing_tags = self.tags
    existing_tags.append(tag)
106
131
 
132
def __repr__(self):
    """Show the point's current tag, or the inherited repr when it has no tags."""
    return self.tag if self.tags else super().__repr__()
136
+
107
137
  # TODO: add a validator that `is_basis` and `eval_expression` are properly setup.
108
138
  def __add__(self, other):
109
139
  assert is_numerical_or_point(other)