pepflow 0.1.3a1__py3-none-any.whl → 0.1.4a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -22,6 +22,7 @@ from __future__ import annotations
  import json
  from typing import TYPE_CHECKING
  
+ import attrs
  import dash
  import dash_bootstrap_components as dbc
  import numpy as np
@@ -29,11 +30,12 @@ import pandas as pd
  import plotly
  import plotly.express as px
  import plotly.graph_objects as go
- from dash import Dash, Input, Output, State, dcc, html
+ from dash import ALL, MATCH, Dash, Input, Output, State, dcc, html
  
  from pepflow.constants import PSD_CONSTRAINT
  
  if TYPE_CHECKING:
+     from pepflow.function import Function
      from pepflow.pep import PEPBuilder, PEPResult
      from pepflow.pep_context import PEPContext
  
@@ -42,26 +44,84 @@ plotly.io.renderers.default = "colab+vscode"
  plotly.io.templates.default = "plotly_white"
  
  
+ @attrs.frozen
+ class PlotData:
+     dataframe: pd.DataFrame
+     figure: go.Figure
+     function: Function
+
+     def dual_matrix_to_tab(self) -> html.Pre:
+         def get_matrix_of_dual_value(df: pd.DataFrame) -> np.ndarray:
+             # Check if we need to update the order.
+             return (
+                 pd.pivot_table(
+                     df, values="dual_value", index="row", columns="col", dropna=False
+                 )
+                 .fillna(0.0)
+                 .to_numpy()
+                 .T
+             )
+
+         with np.printoptions(precision=3, linewidth=100, suppress=True):
+             dual_value_tab = html.Pre(
+                 str(get_matrix_of_dual_value(self.dataframe)),
+                 id={"type": "dual-value-display", "index": self.function.tag},
+                 style={
+                     "border": "1px solid lightgrey",
+                     "padding": "10px",
+                     "height": "60vh",
+                     "overflowY": "auto",
+                 },
+             )
+         return dual_value_tab
+
+     def plot_data_to_tab(self) -> dbc.Tab:
+         tab = dbc.Tab(
+             html.Div(
+                 [
+                     html.P("Interactive Heat Map:"),
+                     dcc.Graph(
+                         id={
+                             "type": "interactive-scatter",
+                             "index": self.function.tag,
+                         },
+                         figure=self.figure,
+                     ),
+                     html.P("Dual Value Matrix:"),
+                     self.dual_matrix_to_tab(),
+                 ]
+             ),
+             label=f"{self.function.tag}-Interpolation Conditions",
+             tab_id=f"{self.function.tag}-interactive-scatter-tab",
+         )
+         return tab
+
+
  def solve_prob_and_get_figure(
      pep_builder: PEPBuilder, context: PEPContext
- ) -> tuple[go.Figure, pd.DataFrame, PEPResult]:
-     assert len(context.triplets) == 1, "Support single function only for now"
+ ) -> tuple[list[PlotData], PEPResult]:
+     plot_data_list: list[PlotData] = []
  
      result = pep_builder.solve(context=context)
  
      df_dict, order_dict = context.triplets_to_df_and_order()
-     f = pep_builder.functions[0]
  
-     df = df_dict[f]
-     order = order_dict[f]
+     for f in context.triplets.keys():
+         df = df_dict[f]
+         order = order_dict[f]
  
-     df["constraint"] = df.constraint_name.map(
-         lambda x: "inactive" if x in pep_builder.relaxed_constraints else "active"
-     )
-     df["dual_value"] = df.constraint_name.map(
-         lambda x: result.dual_var_manager.dual_value(x)
-     )
-     return processed_df_to_fig(df, order), df, result
+         df["constraint"] = df.constraint_name.map(
+             lambda x: "inactive" if x in pep_builder.relaxed_constraints else "active"
+         )
+         df["dual_value"] = df.constraint_name.map(
+             lambda x: result.dual_var_manager.dual_value(x)
+         )
+
+         plot_data_list.append(
+             PlotData(dataframe=df, figure=processed_df_to_fig(df, order), function=f)
+         )
+
+     return plot_data_list, result
  
  
  def processed_df_to_fig(df: pd.DataFrame, order: list[str]):
@@ -99,19 +159,7 @@ def get_matrix_of_dual_value(df: pd.DataFrame) -> np.ndarray:
  
  def launch(pep_builder: PEPBuilder, context: PEPContext):
      app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
-     fig, df, result = solve_prob_and_get_figure(pep_builder, context)
-
-     with np.printoptions(precision=3, linewidth=100, suppress=True):
-         dual_value_tab = html.Pre(
-             str(get_matrix_of_dual_value(df)),
-             id="dual-value-display",
-             style={
-                 "border": "1px solid lightgrey",
-                 "padding": "10px",
-                 "height": "60vh",
-                 "overflowY": "auto",
-             },
-         )
+     plot_data_list, result = solve_prob_and_get_figure(pep_builder, context)
      # Think how can we manipulate the pep_builder here.
      display_row = dbc.Row(
          [
@@ -129,21 +177,8 @@ def launch(pep_builder: PEPBuilder, context: PEPContext):
                      color="success",
                  ),
                  dbc.Tabs(
-                     [
-                         dbc.Tab(
-                             html.Div(
-                                 dcc.Graph(id="interactive-scatter", figure=fig),
-                             ),
-                             label="Interactive Heatmap",
-                             tab_id="interactive-scatter-tab",
-                         ),
-                         dbc.Tab(
-                             dual_value_tab,
-                             label="Dual Value Matrix",
-                             tab_id="dual_value_tab",
-                         ),
-                     ],
-                     active_tab="interactive-scatter-tab",
+                     [plot_data.plot_data_to_tab() for plot_data in plot_data_list],
+                     active_tab=f"{plot_data_list[0].function.tag}-interactive-scatter-tab",
                  ),
              ],
              width=5,
@@ -175,25 +210,37 @@ def launch(pep_builder: PEPBuilder, context: PEPContext):
          [
              html.H2("PEPFlow"),
              display_row,
-             # Store the entire DataFrame as a dictionary in dcc.Store
-             dcc.Store(id="dataframe-store", data=df.to_dict("records")),
+             # For each function, store the corresponding DataFrame as a dictionary in dcc.Store
+             *[
+                 dcc.Store(
+                     id={
+                         "type": "dataframe-store",
+                         "index": plot_data.function.tag,
+                     },
+                     data=(
+                         plot_data.function.tag,
+                         plot_data.dataframe.to_dict("records"),
+                     ),
+                 )
+                 for plot_data in plot_data_list
+             ],
          ]
      )
  
      @dash.callback(
          Output("result-card", "children"),
-         Output("dual-value-display", "children"),
-         Output("interactive-scatter", "figure"),
-         Output("dataframe-store", "data"),
+         Output({"type": "dual-value-display", "index": ALL}, "children"),
+         Output({"type": "interactive-scatter", "index": ALL}, "figure"),
+         Output({"type": "dataframe-store", "index": ALL}, "data"),
          Input("solve-button", "n_clicks"),
      )
      def solve(_):
-         fig, df, result = solve_prob_and_get_figure(pep_builder, context)
+         plot_data_list, result = solve_prob_and_get_figure(pep_builder, context)
          with np.printoptions(precision=3, linewidth=100, suppress=True):
              psd_dual_value = np.array(
                  result.dual_var_manager.dual_value(PSD_CONSTRAINT)
              )
-         reslt_card = dbc.CardBody(
+         result_card = dbc.CardBody(
              [
                  html.H2(f"Optimal Value {result.primal_opt_value:.4g}"),
                  html.H3(f"Solver Status: {result.solver_status}"),
@@ -203,50 +250,89 @@ def launch(pep_builder: PEPBuilder, context: PEPContext):
                  html.Pre(json.dumps(pep_builder.relaxed_constraints, indent=2)),
              ]
          )
-         dual_value_display = str(get_matrix_of_dual_value(df))
-         return reslt_card, dual_value_display, fig, df.to_dict("records")
+         dual_value_displays = [
+             str(get_matrix_of_dual_value(plot_data.dataframe))
+             for plot_data in plot_data_list
+         ]
+         figs = [plot_data.figure for plot_data in plot_data_list]
+
+         df_data = [
+             (plot_data.function.tag, plot_data.dataframe.to_dict("records"))
+             for plot_data in plot_data_list
+         ]
+
+         return result_card, dual_value_displays, figs, df_data
  
      @dash.callback(
-         Output("interactive-scatter", "figure", allow_duplicate=True),
-         Output("dataframe-store", "data", allow_duplicate=True),
+         Output(
+             {"type": "interactive-scatter", "index": ALL},
+             "figure",
+             allow_duplicate=True,
+         ),
+         Output({"type": "dataframe-store", "index": ALL}, "data", allow_duplicate=True),
          Input("restore-all-constraints-button", "n_clicks"),
-         State("dataframe-store", "data"),
+         State({"type": "dataframe-store", "index": ALL}, "data"),
          prevent_initial_call=True,
      )
-     def restore_all_constraints(_, previous_df):
+     def restore_all_constraints(_, list_previous_df_tuples):
          nonlocal pep_builder
-         df_updated = pd.DataFrame(previous_df)
          pep_builder.relaxed_constraints = []
-         df_updated["constraint"] = "active"
-         order = context.order_of_point(pep_builder.functions[0])
-         return processed_df_to_fig(df_updated, order), df_updated.to_dict("records")
+         updated_figs = []
+         df_data = []
+         for previous_df_tuple in list_previous_df_tuples:
+             tag, previous_df = previous_df_tuple
+             df_updated = pd.DataFrame(previous_df)
+             df_updated["constraint"] = "active"
+             order = context.order_of_point(pep_builder.get_func_by_tag(tag))
+             updated_figs.append(processed_df_to_fig(df_updated, order))
+             df_data.append((tag, df_updated.to_dict("records")))
+         return updated_figs, df_data
  
      @dash.callback(
-         Output("interactive-scatter", "figure", allow_duplicate=True),
-         Output("dataframe-store", "data", allow_duplicate=True),
+         Output(
+             {"type": "interactive-scatter", "index": ALL},
+             "figure",
+             allow_duplicate=True,
+         ),
+         Output({"type": "dataframe-store", "index": ALL}, "data", allow_duplicate=True),
          Input("relax-all-constraints-button", "n_clicks"),
-         State("dataframe-store", "data"),
+         State({"type": "dataframe-store", "index": ALL}, "data"),
          prevent_initial_call=True,
      )
-     def relax_all_constraints(_, previous_df):
+     def relax_all_constraints(_, list_previous_df_tuples):
          nonlocal pep_builder
-         df_updated = pd.DataFrame(previous_df)
-         pep_builder.relaxed_constraints = df_updated["constraint_name"].to_list()
-         df_updated["constraint"] = "inactive"
-         order = context.order_of_point(pep_builder.functions[0])
-         return processed_df_to_fig(df_updated, order), df_updated.to_dict("records")
+         pep_builder.relaxed_constraints = []
+         updated_figs = []
+         df_data = []
+         for previous_df_tuple in list_previous_df_tuples:
+             tag, previous_df = previous_df_tuple
+             df_updated = pd.DataFrame(previous_df)
+             pep_builder.relaxed_constraints.extend(
+                 df_updated["constraint_name"].to_list()
+             )
+             df_updated["constraint"] = "inactive"
+             order = context.order_of_point(pep_builder.get_func_by_tag(tag))
+             updated_figs.append(processed_df_to_fig(df_updated, order))
+             df_data.append((tag, df_updated.to_dict("records")))
+         return updated_figs, df_data
  
      @dash.callback(
-         Output("interactive-scatter", "figure", allow_duplicate=True),
-         Output("dataframe-store", "data", allow_duplicate=True),
-         Input("interactive-scatter", "clickData"),
-         State("dataframe-store", "data"),
+         Output(
+             {"type": "interactive-scatter", "index": MATCH},
+             "figure",
+             allow_duplicate=True,
+         ),
+         Output(
+             {"type": "dataframe-store", "index": MATCH}, "data", allow_duplicate=True
+         ),
+         Input({"type": "interactive-scatter", "index": MATCH}, "clickData"),
+         State({"type": "dataframe-store", "index": MATCH}, "data"),
          prevent_initial_call=True,
      )
-     def update_df_and_redraw(clickData, previous_df):
+     def update_df_and_redraw(clickData, previous_df_tuple):
          nonlocal pep_builder
          if not clickData["points"][0]["customdata"]:
-             return dash.no_update, dash.no_update, dash.no_update
+             return dash.no_update, dash.no_update
  
          clicked_name = clickData["points"][0]["customdata"][0]
          if clicked_name not in pep_builder.relaxed_constraints:
@@ -254,11 +340,15 @@ def launch(pep_builder: PEPBuilder, context: PEPContext):
          else:
              pep_builder.relaxed_constraints.remove(clicked_name)
  
+         tag, previous_df = previous_df_tuple
          df_updated = pd.DataFrame(previous_df)
          df_updated["constraint"] = df_updated.constraint_name.map(
              lambda x: "inactive" if x in pep_builder.relaxed_constraints else "active"
          )
-         order = context.order_of_point(pep_builder.functions[0])
-         return processed_df_to_fig(df_updated, order), df_updated.to_dict("records")
+         order = context.order_of_point(pep_builder.get_func_by_tag(tag))
+         return processed_df_to_fig(df_updated, order), (
+             tag,
+             df_updated.to_dict("records"),
+         )
  
      app.run(debug=True)
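
Note on the hunks above: the fixed component ids ("interactive-scatter", "dataframe-store", "dual-value-display") are replaced with Dash pattern-matching callbacks, where an id is a dict and the ALL / MATCH wildcards select every matching component or the one whose "index" matches the triggering input. The following standalone sketch is not part of the package; the "num", "echo", and "total" ids and the tag list are invented purely to illustrate how that Dash mechanism behaves:

import dash
from dash import ALL, MATCH, Dash, Input, Output, dcc, html

app = Dash(__name__)
tags = ["f", "g"]  # hypothetical per-function tags
app.layout = html.Div(
    [
        *[dcc.Input(id={"type": "num", "index": t}, value=1) for t in tags],
        *[html.Div(id={"type": "echo", "index": t}) for t in tags],
        html.Div(id="total"),
    ]
)

@dash.callback(
    Output("total", "children"),
    Input({"type": "num", "index": ALL}, "value"),
)
def show_total(values):
    # ALL collects the property from every component whose id matches the pattern,
    # so `values` is a list ordered like the layout.
    return f"sum = {sum(float(v or 0) for v in values)}"

@dash.callback(
    Output({"type": "echo", "index": MATCH}, "children"),
    Input({"type": "num", "index": MATCH}, "value"),
)
def echo(value):
    # MATCH pairs each output with the input that shares the same "index",
    # giving one callback instance per tag.
    return f"value = {value}"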
pepflow/pep.py CHANGED
@@ -87,6 +87,13 @@ class PEPBuilder:
          # We should think about a better choice like manager.
          self.relaxed_constraints = []
  
+     def clear_setup(self):
+         self.init_conditions.clear()
+         self.functions.clear()
+         self.interpolation_constraints.clear()
+         self.performance_metric = None
+         self.relaxed_constraints.clear()
+
      @contextlib.contextmanager
      def make_context(
          self, name: str, override: bool = False
@@ -94,7 +101,8 @@
          if not override and name in self.pep_context_dict:
              raise KeyError(f"There is already a context {name} in the builder")
          try:
-             ctx = pc.PEPContext()
+             self.clear_setup()
+             ctx = pc.PEPContext(name)
              self.pep_context_dict[name] = ctx
              pc.set_current_context(ctx)
              yield ctx
@@ -130,11 +138,19 @@
      def set_relaxed_constraints(self, relaxed_constraints: list[str]):
          self.relaxed_constraints.extend(relaxed_constraints)
  
-     def declare_func(self, function_class, **kwargs):
+     def declare_func(self, function_class: type[Function], tag: str, **kwargs):
          func = function_class(is_basis=True, composition=None, **kwargs)
+         func.add_tag(tag)
          self.functions.append(func)
          return func
  
+     def get_func_by_tag(self, tag: str):
+         # TODO: Add support to return composite functions as well. Right now we can only return base functions
+         for f in self.functions:
+             if tag in f.tags:
+                 return f
+         raise ValueError("Cannot find the function of given tag.")
+
      def solve(self, context: pc.PEPContext | None = None, **kwargs):
          if context is None:
              context = pc.get_current_context()
@@ -144,7 +160,6 @@
          all_constraints: list[Constraint] = [*self.init_conditions]
          for f in self.functions:
              all_constraints.extend(f.get_interpolation_constraints())
-             all_constraints.extend(context.opt_conditions[f])
  
          # for now, we heavily rely on the CVX. We can make a wrapper class to avoid
          # direct dependency in the future.
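
Note on the hunks above: a function must now be declared with a tag and can be recovered by it. A minimal usage sketch, mirroring the call pattern visible in pepflow/pep_test.py below (the context name "demo" is illustrative):

from pepflow import function as fc
from pepflow import pep

builder = pep.PEPBuilder()
with builder.make_context("demo"):
    # declare_func attaches the tag to the function before registering it.
    f = builder.declare_func(fc.SmoothConvexFunction, "f", L=1)

# The tag round-trips through get_func_by_tag, which the Dash callbacks rely on.
assert builder.get_func_by_tag("f") == f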
pepflow/pep_context.py CHANGED
@@ -26,13 +26,14 @@ import natsort
  import pandas as pd
  
  if TYPE_CHECKING:
-     from pepflow.constraint import Constraint
      from pepflow.function import Function, Triplet
      from pepflow.point import Point
      from pepflow.scalar import Scalar
  
  # A global variable for storing the current context that is used for points or scalars.
  CURRENT_CONTEXT: PEPContext | None = None
+ # Keep the track of all previous created context
+ GLOBAL_CONTEXT_DICT: dict[str, PEPContext] = {}
  
  
  def get_current_context() -> PEPContext | None:
@@ -46,11 +47,15 @@ def set_current_context(ctx: PEPContext | None):
  
  
  class PEPContext:
-     def __init__(self):
+     def __init__(self, name: str):
+         self.name = name
          self.points: list[Point] = []
         self.scalars: list[Scalar] = []
         self.triplets: dict[Function, list[Triplet]] = defaultdict(list)
-         self.opt_conditions: dict[Function, list[Constraint]] = defaultdict(list)
+         # self.triplets will contain all stationary_triplets. They are not mutually exclusive.
+         self.stationary_triplets: dict[Function, list[Triplet]] = defaultdict(list)
+
+         GLOBAL_CONTEXT_DICT[name] = self
  
      def set_as_current(self) -> PEPContext:
          set_current_context(self)
@@ -65,8 +70,8 @@ class PEPContext:
  
      def add_triplet(self, function: Function, triplet: Triplet):
          self.triplets[function].append(triplet)
  
-     def add_opt_condition(self, function: Function, opt_condition: Constraint):
-         self.opt_conditions[function].append(opt_condition)
+     def add_stationary_triplet(self, function: Function, stationary_triplet: Triplet):
+         self.stationary_triplets[function].append(stationary_triplet)
  
      def get_by_tag(self, tag: str) -> Point | Scalar:
          for p in self.points:
@@ -75,13 +80,13 @@ class PEPContext:
          for s in self.scalars:
              if tag in s.tags:
                  return s
-         raise ValueError("Cannot find the point or scalar of given tag")
+         raise ValueError("Cannot find the point, scalar, or function of given tag.")
  
      def clear(self):
          self.points.clear()
          self.scalars.clear()
          self.triplets.clear()
-         self.opt_conditions.clear()
+         self.stationary_triplets.clear()
  
      def tracked_point(self, func: Function) -> list[Point]:
          return natsort.natsorted(
@@ -17,17 +17,25 @@
  # specific language governing permissions and limitations
  # under the License.
  
+ from typing import Iterator
+
  import pandas as pd
+ import pytest
  
  from pepflow import pep_context as pc
  from pepflow.function import SmoothConvexFunction
  from pepflow.point import Point
  
  
- def test_tracked_points():
-     ctx = pc.PEPContext()
-     pc.set_current_context(ctx)
+ @pytest.fixture
+ def pep_context() -> Iterator[pc.PEPContext]:
+     """Prepare the pep context and reset the context to None at the end."""
+     ctx = pc.PEPContext("test").set_as_current()
+     yield ctx
+     pc.set_current_context(None)
+
  
+ def test_tracked_points(pep_context: pc.PEPContext):
      f = SmoothConvexFunction(L=1, is_basis=True)
      f.add_tag("f")
  
@@ -41,16 +49,11 @@ def test_tracked_points():
      _ = f.generate_triplet(p3)
      _ = f.generate_triplet(p_star)
  
-     assert ctx.order_of_point(f) == ["x_1", "x_2", "x_3", "x_*"]
-     assert ctx.tracked_point(f) == [p1, p3, p2, p_star]
-
-     pc.set_current_context(None)
-
+     assert pep_context.order_of_point(f) == ["x_1", "x_2", "x_3", "x_*"]
+     assert pep_context.tracked_point(f) == [p1, p3, p2, p_star]
  
- def test_triplets_to_dataframe():
-     ctx = pc.PEPContext()
-     pc.set_current_context(ctx)
  
+ def test_triplets_to_dataframe(pep_context: pc.PEPContext):
      f = SmoothConvexFunction(L=1, is_basis=True)
      f.add_tag("f")
  
@@ -62,7 +65,7 @@ def test_triplets_to_dataframe():
      _ = f.generate_triplet(p2)
      _ = f.generate_triplet(p3)
  
-     func_to_df, func_to_order = ctx.triplets_to_df_and_order()
+     func_to_df, func_to_order = pep_context.triplets_to_df_and_order()
      expected_df = pd.DataFrame(
          {
              "constraint_name": [
@@ -83,20 +86,19 @@
      pd.testing.assert_frame_equal(func_to_df[f], expected_df)
      assert func_to_order[f] == ["x1", "x2", "x3"]
  
-     pc.set_current_context(None)
-
-
- def test_get_by_tag():
-     ctx = pc.PEPContext()
-     pc.set_current_context(ctx)
  
+ def test_get_by_tag(pep_context: pc.PEPContext):
      f = SmoothConvexFunction(L=1, is_basis=True)
      f.add_tag("f")
      p1 = Point(is_basis=True, tags=["x1"])
+     p2 = Point(is_basis=True, tags=["x2"])
+     p3 = p1 + p2
  
      triplet = f.generate_triplet(p1)
+     _ = f.generate_triplet(p2)
  
-     assert ctx.get_by_tag("x1") == p1
-     assert ctx.get_by_tag("f(x1)") == triplet.function_value
-     assert ctx.get_by_tag("gradient_f(x1)") == triplet.gradient
+     assert pep_context.get_by_tag("x1") == p1
+     assert pep_context.get_by_tag("f(x1)") == triplet.function_value
+     assert pep_context.get_by_tag("gradient_f(x1)") == triplet.gradient
+     assert pep_context.get_by_tag("x1+x2") == p3
      pc.set_current_context(None)
pepflow/pep_test.py CHANGED
@@ -19,6 +19,7 @@
  
  import pytest
  
+ from pepflow import function as fc
  from pepflow import pep
  from pepflow import pep_context as pc
  
@@ -75,3 +76,10 @@ class TestPEPBuilder:
  
          with builder.make_context("test", override=True):
              pass
+
+     def test_get_func_by_tag(self) -> None:
+         builder = pep.PEPBuilder()
+         with builder.make_context("test"):
+             f = builder.declare_func(fc.SmoothConvexFunction, "f", L=1)
+
+         assert builder.get_func_by_tag("f") == f
pepflow/point.py CHANGED
@@ -134,75 +134,110 @@ class Point:
              return self.tag
          return super().__repr__()
  
+     def _repr_latex_(self):
+         s = repr(self)
+         s = s.replace("star", r"\star")
+         s = s.replace("gradient_", r"\nabla ")
+         s = s.replace("|", r"\|")
+         return rf"$\\displaystyle {s}$"
+
      # TODO: add a validator that `is_basis` and `eval_expression` are properly setup.
      def __add__(self, other):
-         assert is_numerical_or_point(other)
+         assert isinstance(other, Point)
          return Point(
              is_basis=False,
              eval_expression=EvalExpressionPoint(utils.Op.ADD, self, other),
+             tags=[f"{self.tag}+{other.tag}"],
         )
  
      def __radd__(self, other):
-         assert is_numerical_or_point(other)
+         # TODO: come up with better way to handle this
+         if other == 0:
+             return self
+         assert isinstance(other, Point)
          return Point(
              is_basis=False,
              eval_expression=EvalExpressionPoint(utils.Op.ADD, other, self),
+             tags=[f"{other.tag}+{self.tag}"],
          )
  
      def __sub__(self, other):
-         assert is_numerical_or_point(other)
+         assert isinstance(other, Point)
+         tag_other = utils.parenthesize_tag(other)
          return Point(
              is_basis=False,
              eval_expression=EvalExpressionPoint(utils.Op.SUB, self, other),
+             tags=[f"{self.tag}-{tag_other}"],
          )
  
      def __rsub__(self, other):
-         assert is_numerical_or_point(other)
+         assert isinstance(other, Point)
+         tag_self = utils.parenthesize_tag(self)
          return Point(
              is_basis=False,
              eval_expression=EvalExpressionPoint(utils.Op.SUB, other, self),
+             tags=[f"{other.tag}-{tag_self}"],
          )
  
      def __mul__(self, other):
          # TODO allow the other to be point so that we return a scalar.
          assert is_numerical_or_point(other)
+         tag_self = utils.parenthesize_tag(self)
          if utils.is_numerical(other):
              return Point(
                  is_basis=False,
                  eval_expression=EvalExpressionPoint(utils.Op.MUL, self, other),
+                 tags=[f"{tag_self}*{other:.4g}"],
              )
          else:
+             tag_other = utils.parenthesize_tag(other)
              return Scalar(
                  is_basis=False,
-                 eval_expression=EvalExpressionScalar(utils.Op.MUL, self, other),  # TODO
+                 eval_expression=EvalExpressionScalar(utils.Op.MUL, self, other),
+                 tags=[f"{tag_self}*{tag_other}"],
              )
  
      def __rmul__(self, other):
          # TODO allow the other to be point so that we return a scalar.
          assert is_numerical_or_point(other)
+         tag_self = utils.parenthesize_tag(self)
          if utils.is_numerical(other):
              return Point(
                  is_basis=False,
                  eval_expression=EvalExpressionPoint(utils.Op.MUL, other, self),
+                 tags=[f"{other:.4g}*{tag_self}"],
              )
          else:
+             tag_other = utils.parenthesize_tag(other)
              return Scalar(
                  is_basis=False,
-                 eval_expression=EvalExpressionScalar(utils.Op.MUL, other, self),  # TODO
+                 eval_expression=EvalExpressionScalar(utils.Op.MUL, other, self),
+                 tags=[f"{tag_other}*{tag_self}"],
              )
  
      def __pow__(self, power):
          assert power == 2
-         return self.__rmul__(self)
+         return Scalar(
+             is_basis=False,
+             eval_expression=EvalExpressionScalar(utils.Op.MUL, self, self),
+             tags=[rf"|{self.tag}|^{power}"],
+         )
  
      def __neg__(self):
-         return self.__rmul__(other=-1)
+         tag_self = utils.parenthesize_tag(self)
+         return Point(
+             is_basis=False,
+             eval_expression=EvalExpressionPoint(utils.Op.MUL, -1, self),
+             tags=[f"-{tag_self}"],
+         )
  
      def __truediv__(self, other):
          assert utils.is_numerical(other)
+         tag_self = utils.parenthesize_tag(self)
          return Point(
              is_basis=False,
              eval_expression=EvalExpressionPoint(utils.Op.DIV, self, other),
+             tags=[f"1/{other:.4g}*{tag_self}"],
          )
  
      def __hash__(self):