pepflow 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pepflow/function_test.py CHANGED
@@ -38,8 +38,8 @@ def pep_context() -> Iterator[pc.PEPContext]:
38
38
 
39
39
 
40
40
  def test_function_add_tag(pep_context: pc.PEPContext) -> None:
41
- f1 = fc.Function(is_basis=True, reuse_gradient=False, tags=["f1"])
42
- f2 = fc.Function(is_basis=True, reuse_gradient=False, tags=["f2"])
41
+ f1 = fc.Function(is_basis=True, tags=["f1"])
42
+ f2 = fc.Function(is_basis=True, tags=["f2"])
43
43
 
44
44
  f_add = f1 + f2
45
45
  assert f_add.tag == "f1+f2"
@@ -55,7 +55,7 @@ def test_function_add_tag(pep_context: pc.PEPContext) -> None:
55
55
 
56
56
 
57
57
  def test_function_mul_tag(pep_context: pc.PEPContext) -> None:
58
- f = fc.Function(is_basis=True, reuse_gradient=False, tags=["f"])
58
+ f = fc.Function(is_basis=True, tags=["f"])
59
59
 
60
60
  f_mul = f * 0.1
61
61
  assert f_mul.tag == "0.1*f"
@@ -71,8 +71,8 @@ def test_function_mul_tag(pep_context: pc.PEPContext) -> None:
71
71
 
72
72
 
73
73
  def test_function_add_and_mul_tag(pep_context: pc.PEPContext) -> None:
74
- f1 = fc.Function(is_basis=True, reuse_gradient=False, tags=["f1"])
75
- f2 = fc.Function(is_basis=True, reuse_gradient=False, tags=["f2"])
74
+ f1 = fc.Function(is_basis=True, tags=["f1"])
75
+ f2 = fc.Function(is_basis=True, tags=["f2"])
76
76
 
77
77
  f_add_mul = (f1 + f2) * 0.1
78
78
  assert f_add_mul.tag == "0.1*(f1+f2)"
@@ -94,116 +94,182 @@ def test_function_add_and_mul_tag(pep_context: pc.PEPContext) -> None:
94
94
 
95
95
 
def test_function_call(pep_context: pc.PEPContext) -> None:
    """Calling a function on a point is equivalent to ``function_value``."""
    func = fc.Function(is_basis=True, tags=["f"])
    pt = point.Point(is_basis=True, eval_expression=None, tags=["x"])
    via_method = func.function_value(pt)
    via_call = func(pt)
    assert via_method == via_call
100
100
 
101
101
 
102
- def test_function_repr():
103
- pep_builder = pep.PEPBuilder()
104
- with pep_builder.make_context("test"):
105
- f = fc.Function(is_basis=True, reuse_gradient=False)
106
- print(f) # it should be fine without tag
107
- f.add_tag("f")
108
- assert str(f) == "f"
109
-
110
-
111
- def test_stationary_point():
112
- pep_builder = pep.PEPBuilder()
113
- with pep_builder.make_context("test") as ctx:
114
- f = fc.Function(is_basis=True, reuse_gradient=False, tags=["f"])
115
- f.add_stationary_point("x_star")
116
-
117
- assert len(ctx.triplets) == 1
118
- assert len(ctx.triplets[f]) == 1
119
-
120
- f_triplet = ctx.triplets[f][0]
121
- assert f_triplet.name == "x_star_f(x_star)_gradient_f(x_star)"
122
- assert f_triplet.gradient.tag == "gradient_f(x_star)"
123
- assert f_triplet.function_value.tag == "f(x_star)"
124
-
125
- em = exm.ExpressionManager(ctx)
126
- np.testing.assert_allclose(
127
- em.eval_point(f_triplet.gradient).vector, np.array([0])
128
- )
129
- np.testing.assert_allclose(em.eval_point(f_triplet.point).vector, np.array([1]))
130
-
131
-
132
- def test_stationary_point_scaled():
133
- pep_builder = pep.PEPBuilder()
134
- with pep_builder.make_context("test") as ctx:
135
- f = fc.Function(is_basis=True, reuse_gradient=False, tags=["f"])
136
- g = 5 * f
137
- g.add_stationary_point("x_star")
138
-
139
- assert len(ctx.triplets) == 1
140
- assert len(ctx.triplets[f]) == 1
141
-
142
- f_triplet = ctx.triplets[f][0]
143
- assert f_triplet.name == "x_star_f(x_star)_gradient_f(x_star)"
144
- assert f_triplet.gradient.tag == "gradient_f(x_star)"
145
- assert f_triplet.function_value.tag == "f(x_star)"
146
-
147
- em = exm.ExpressionManager(ctx)
148
- np.testing.assert_allclose(
149
- em.eval_point(f_triplet.gradient).vector, np.array([0])
150
- )
151
- np.testing.assert_allclose(em.eval_point(f_triplet.point).vector, np.array([1]))
152
-
153
-
154
- def test_stationary_point_additive():
155
- pep_builder = pep.PEPBuilder()
156
- with pep_builder.make_context("test") as ctx:
157
- f = fc.Function(is_basis=True, reuse_gradient=False)
158
- f.add_tag("f")
159
- g = fc.Function(is_basis=True, reuse_gradient=False)
160
- g.add_tag("g")
161
- h = f + g
162
- h.add_tag("h")
163
-
164
- h.add_stationary_point("x_star")
165
- assert len(ctx.triplets) == 2
166
- assert len(ctx.triplets[f]) == 1
167
- assert len(ctx.triplets[g]) == 1
168
-
169
- f_triplet = ctx.triplets[f][0]
170
- g_triplet = ctx.triplets[g][0]
171
- assert f_triplet.name == "x_star_f(x_star)_gradient_f(x_star)"
172
- assert g_triplet.name == "x_star_g(x_star)_gradient_g(x_star)"
173
-
174
- em = exm.ExpressionManager(ctx)
175
- np.testing.assert_allclose(
176
- em.eval_point(f_triplet.gradient).vector, np.array([0, 1])
177
- )
178
- np.testing.assert_allclose(
179
- em.eval_point(g_triplet.gradient).vector, np.array([0, -1])
180
- )
181
-
182
-
183
- def test_stationary_point_linear_combination():
184
- pep_builder = pep.PEPBuilder()
185
- with pep_builder.make_context("test") as ctx:
186
- f = fc.Function(is_basis=True, reuse_gradient=False)
187
- f.add_tag("f")
188
- g = fc.Function(is_basis=True, reuse_gradient=False)
189
- g.add_tag("g")
190
- h = 3 * f + 2 * g
191
- h.add_tag("h")
192
-
193
- h.add_stationary_point("x_star")
194
- assert len(ctx.triplets) == 2
195
- assert len(ctx.triplets[f]) == 1
196
- assert len(ctx.triplets[g]) == 1
197
-
198
- f_triplet = ctx.triplets[f][0]
199
- g_triplet = ctx.triplets[g][0]
200
- assert f_triplet.name == "x_star_f(x_star)_gradient_f(x_star)"
201
- assert g_triplet.name == "x_star_g(x_star)_gradient_g(x_star)"
202
-
203
- em = exm.ExpressionManager(ctx)
204
- np.testing.assert_allclose(
205
- em.eval_point(f_triplet.gradient).vector, np.array([0, 1])
206
- )
207
- np.testing.assert_allclose(
208
- em.eval_point(g_triplet.gradient).vector, np.array([0, -1.5])
209
- )
def test_function_repr(pep_context: pc.PEPContext):
    """``str`` works on an untagged function and returns the tag once one is set."""
    func = fc.Function(is_basis=True)
    # Printing before any tag exists must not raise.
    print(func)
    func.add_tag("f")
    assert str(func) == "f"
109
+
110
+
def test_stationary_point(pep_context: pc.PEPContext):
    """A stationary point of one basis function yields a single triplet.

    The triplet's gradient must evaluate to ``[0]`` and the new point to ``[1]``.
    """
    f = fc.Function(is_basis=True, tags=["f"])
    f.add_stationary_point("x_star")

    all_triplets = pep_context.triplets
    assert len(all_triplets) == 1
    assert len(all_triplets[f]) == 1

    star = all_triplets[f][0]
    assert star.name == "x_star_f(x_star)_gradient_f(x_star)"
    assert star.gradient.tag == "gradient_f(x_star)"
    assert star.function_value.tag == "f(x_star)"

    manager = exm.ExpressionManager(pep_context)
    grad_vec = manager.eval_point(star.gradient).vector
    np.testing.assert_allclose(grad_vec, np.array([0]))
    point_vec = manager.eval_point(star.point).vector
    np.testing.assert_allclose(point_vec, np.array([1]))
129
+
130
+
def test_stationary_point_scaled(pep_context: pc.PEPContext):
    """A stationary point of ``5 * f`` is recorded on the basis function ``f``.

    Exactly one triplet is expected, keyed by ``f`` and named after ``f``. Its
    gradient must evaluate to ``[0]`` and its point to ``[1]``.
    """
    f = fc.Function(is_basis=True, tags=["f"])
    scaled = 5 * f
    scaled.add_stationary_point("x_star")

    all_triplets = pep_context.triplets
    assert len(all_triplets) == 1
    assert len(all_triplets[f]) == 1

    star = all_triplets[f][0]
    assert star.name == "x_star_f(x_star)_gradient_f(x_star)"
    assert star.gradient.tag == "gradient_f(x_star)"
    assert star.function_value.tag == "f(x_star)"

    manager = exm.ExpressionManager(pep_context)
    grad_vec = manager.eval_point(star.gradient).vector
    np.testing.assert_allclose(grad_vec, np.array([0]))
    point_vec = manager.eval_point(star.point).vector
    np.testing.assert_allclose(point_vec, np.array([1]))
150
+
151
+
def test_stationary_point_additive(pep_context: pc.PEPContext):
    """A stationary point of ``h = f + g`` creates one triplet per basis function.

    The expected gradients are ``[0, 1]`` for ``f`` and ``[0, -1]`` for ``g``,
    which sum to zero as required at a stationary point of ``h``.
    """
    f = fc.Function(is_basis=True)
    f.add_tag("f")
    g = fc.Function(is_basis=True)
    g.add_tag("g")
    h = f + g
    h.add_tag("h")

    h.add_stationary_point("x_star")
    triplets = pep_context.triplets
    assert len(triplets) == 2
    for component in (f, g):
        assert len(triplets[component]) == 1

    f_triplet, g_triplet = triplets[f][0], triplets[g][0]
    assert f_triplet.name == "x_star_f(x_star)_gradient_f(x_star)"
    assert g_triplet.name == "x_star_g(x_star)_gradient_g(x_star)"

    manager = exm.ExpressionManager(pep_context)
    expected_gradients = ((f_triplet, [0, 1]), (g_triplet, [0, -1]))
    for triplet, expected in expected_gradients:
        np.testing.assert_allclose(
            manager.eval_point(triplet.gradient).vector, np.array(expected)
        )
177
+
178
+
def test_stationary_point_linear_combination(pep_context: pc.PEPContext):
    """A stationary point of ``h = 3f + 2g`` creates one triplet per basis function.

    The expected gradients are ``[0, 1]`` for ``f`` and ``[0, -1.5]`` for ``g``,
    so that ``3 * grad f + 2 * grad g`` is zero.
    """
    f = fc.Function(is_basis=True)
    f.add_tag("f")
    g = fc.Function(is_basis=True)
    g.add_tag("g")
    h = 3 * f + 2 * g
    h.add_tag("h")

    h.add_stationary_point("x_star")
    triplets = pep_context.triplets
    assert len(triplets) == 2
    for component in (f, g):
        assert len(triplets[component]) == 1

    f_triplet, g_triplet = triplets[f][0], triplets[g][0]
    assert f_triplet.name == "x_star_f(x_star)_gradient_f(x_star)"
    assert g_triplet.name == "x_star_g(x_star)_gradient_g(x_star)"

    manager = exm.ExpressionManager(pep_context)
    expected_gradients = ((f_triplet, [0, 1]), (g_triplet, [0, -1.5]))
    for triplet, expected in expected_gradients:
        np.testing.assert_allclose(
            manager.eval_point(triplet.gradient).vector, np.array(expected)
        )
208
+
209
+
def test_function_generate_triplet(pep_context: pc.PEPContext):
    """``generate_triplet`` on ``h = 5f + 5g`` at a basis point.

    The triplet is generated twice for the same point. Both results must
    evaluate to gradient ``[0, 5, 5]`` and function value ``[5, 5]``, while
    ``p1`` itself evaluates to ``[1, 0, 0]``.
    """
    f = fc.Function(is_basis=True)
    f.add_tag("f")
    g = fc.Function(is_basis=True)
    g.add_tag("g")
    h = 5 * f + 5 * g
    h.add_tag("h")

    p1 = point.Point(is_basis=True)
    p1.add_tag("p1")
    first = h.generate_triplet(p1)
    second = h.generate_triplet(p1)

    manager = exm.ExpressionManager(pep_context)

    np.testing.assert_allclose(manager.eval_point(p1).vector, np.array([1, 0, 0]))

    for triplet in (first, second):
        np.testing.assert_allclose(
            manager.eval_point(triplet.gradient).vector, np.array([0, 5, 5])
        )
        np.testing.assert_allclose(
            manager.eval_scalar(triplet.function_value).vector, np.array([5, 5])
        )
240
+
241
+
def test_function_add_stationary_point(pep_context: pc.PEPContext):
    """``add_stationary_point`` returns a point that evaluates to ``[1]``."""
    f = fc.Function(is_basis=True)
    f.add_tag("f")
    x_opt = f.add_stationary_point("x_opt")

    manager = exm.ExpressionManager(pep_context)

    x_opt_vector = manager.eval_point(x_opt).vector
    np.testing.assert_allclose(x_opt_vector, np.array([1]))
250
+
251
+
def test_smooth_interpolability_constraints(pep_context: pc.PEPContext):
    """Check one interpolation constraint of an ``L=1`` smooth convex function.

    Triplets exist at ``x_opt`` (stationary) and ``x_0``. The constraint at
    index 1 of ``get_interpolation_constraints()`` is expected to have:

    - ``.vector == [1, -1]``;
    - a 3x3 ``.matrix`` that is zero except 0.5 at position ``[2][2]``;
    - ``.constant == 0``.

    NOTE(review): which triplet pair index 1 corresponds to depends on the
    ordering inside ``get_interpolation_constraints``; confirm if it changes.
    """
    f = fc.SmoothConvexFunction(L=1)
    f.add_tag("f")
    _ = f.add_stationary_point("x_opt")

    x_0 = point.Point(is_basis=True)
    x_0.add_tag("x_0")
    _ = f.generate_triplet(x_0)

    constraints = f.get_interpolation_constraints()
    checked = constraints[1]

    manager = exm.ExpressionManager(pep_context)

    np.testing.assert_allclose(manager.eval_scalar(checked.scalar).vector, [1, -1])
    expected_matrix = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.5]]
    np.testing.assert_allclose(
        manager.eval_scalar(checked.scalar).matrix, expected_matrix
    )

    np.testing.assert_allclose(manager.eval_scalar(checked.scalar).constant, 0)
@@ -22,6 +22,7 @@ from __future__ import annotations
22
22
  import json
23
23
  from typing import TYPE_CHECKING
24
24
 
25
+ import attrs
25
26
  import dash
26
27
  import dash_bootstrap_components as dbc
27
28
  import numpy as np
@@ -29,11 +30,12 @@ import pandas as pd
29
30
  import plotly
30
31
  import plotly.express as px
31
32
  import plotly.graph_objects as go
32
- from dash import Dash, Input, Output, State, dcc, html
33
+ from dash import ALL, MATCH, Dash, Input, Output, State, dcc, html
33
34
 
34
35
  from pepflow.constants import PSD_CONSTRAINT
35
36
 
36
37
  if TYPE_CHECKING:
38
+ from pepflow.function import Function
37
39
  from pepflow.pep import PEPBuilder, PEPResult
38
40
  from pepflow.pep_context import PEPContext
39
41
 
@@ -42,26 +44,84 @@ plotly.io.renderers.default = "colab+vscode"
42
44
  plotly.io.templates.default = "plotly_white"
43
45
 
44
46
 
@attrs.frozen
class PlotData:
    """Everything the dashboard renders for a single function.

    One instance is built per function by ``solve_prob_and_get_figure``. The
    function's ``tag`` is used as the ``index`` of every pattern-matching Dash
    component id created here. Two functions sharing a tag would therefore
    produce duplicate component ids.
    """

    # Constraint table for this function; read by get_matrix_of_dual_value.
    dataframe: pd.DataFrame
    # Interactive heat-map figure built from ``dataframe``.
    figure: go.Figure
    # The function whose interpolation conditions are displayed.
    function: Function

    def dual_matrix_to_tab(self) -> html.Pre:
        """Return a scrollable ``html.Pre`` panel showing the dual-value matrix.

        Despite its name, this returns the panel that ``plot_data_to_tab``
        embeds, not a ``dbc.Tab``. The matrix is produced by the module-level
        ``get_matrix_of_dual_value``, which is the same helper the ``solve``
        callback uses. This keeps the initial render and later re-solves
        formatted identically instead of maintaining a second, drifting copy.
        """
        with np.printoptions(precision=3, linewidth=100, suppress=True):
            # str() must run inside the context for the print options to apply.
            matrix_text = str(get_matrix_of_dual_value(self.dataframe))
        return html.Pre(
            matrix_text,
            id={"type": "dual-value-display", "index": self.function.tag},
            style={
                "border": "1px solid lightgrey",
                "padding": "10px",
                "height": "60vh",
                "overflowY": "auto",
            },
        )

    def plot_data_to_tab(self) -> dbc.Tab:
        """Build the tab holding this function's heat map and dual-value matrix."""
        tag = self.function.tag
        return dbc.Tab(
            html.Div(
                [
                    html.P("Interactive Heat Map:"),
                    dcc.Graph(
                        id={"type": "interactive-scatter", "index": tag},
                        figure=self.figure,
                    ),
                    html.P("Dual Value Matrix:"),
                    self.dual_matrix_to_tab(),
                ]
            ),
            label=f"{tag}-Interpolation Conditions",
            tab_id=f"{tag}-interactive-scatter-tab",
        )
98
+
99
+
def solve_prob_and_get_figure(
    pep_builder: PEPBuilder, context: PEPContext
) -> tuple[list[PlotData], PEPResult]:
    """Solve the PEP and build one ``PlotData`` per function that has triplets.

    Args:
        pep_builder: Builder whose ``solve`` is run and whose
            ``relaxed_constraints`` decide each row's active/inactive label.
        context: Context supplying the triplets and the per-function
            DataFrames and point orderings.

    Returns:
        ``(plot_data_list, result)``, with ``plot_data_list`` ordered like
        ``context.triplets``.

    Side effects:
        The DataFrames returned by ``context.triplets_to_df_and_order()`` are
        mutated in place, gaining ``constraint`` and ``dual_value`` columns.
    """
    result = pep_builder.solve(context=context)
    df_dict, order_dict = context.triplets_to_df_and_order()

    # Build the lookup set once. relaxed_constraints is a list and can hold
    # every constraint name (after "relax all"), so testing `in list` for each
    # row would be quadratic.
    relaxed = set(pep_builder.relaxed_constraints)

    plot_data_list: list[PlotData] = []
    for f in context.triplets:
        df = df_dict[f]
        df["constraint"] = df.constraint_name.map(
            lambda name: "inactive" if name in relaxed else "active"
        )
        df["dual_value"] = df.constraint_name.map(result.dual_var_manager.dual_value)
        plot_data_list.append(
            PlotData(
                dataframe=df,
                figure=processed_df_to_fig(df, order_dict[f]),
                function=f,
            )
        )

    return plot_data_list, result
65
125
 
66
126
 
67
127
  def processed_df_to_fig(df: pd.DataFrame, order: list[str]):
@@ -99,19 +159,7 @@ def get_matrix_of_dual_value(df: pd.DataFrame) -> np.ndarray:
99
159
 
100
160
  def launch(pep_builder: PEPBuilder, context: PEPContext):
101
161
  app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
102
- fig, df, result = solve_prob_and_get_figure(pep_builder, context)
103
-
104
- with np.printoptions(precision=3, linewidth=100, suppress=True):
105
- dual_value_tab = html.Pre(
106
- str(get_matrix_of_dual_value(df)),
107
- id="dual-value-display",
108
- style={
109
- "border": "1px solid lightgrey",
110
- "padding": "10px",
111
- "height": "60vh",
112
- "overflowY": "auto",
113
- },
114
- )
162
+ plot_data_list, result = solve_prob_and_get_figure(pep_builder, context)
115
163
  # Think how can we manipulate the pep_builder here.
116
164
  display_row = dbc.Row(
117
165
  [
@@ -129,21 +177,8 @@ def launch(pep_builder: PEPBuilder, context: PEPContext):
129
177
  color="success",
130
178
  ),
131
179
  dbc.Tabs(
132
- [
133
- dbc.Tab(
134
- html.Div(
135
- dcc.Graph(id="interactive-scatter", figure=fig),
136
- ),
137
- label="Interactive Heatmap",
138
- tab_id="interactive-scatter-tab",
139
- ),
140
- dbc.Tab(
141
- dual_value_tab,
142
- label="Dual Value Matrix",
143
- tab_id="dual_value_tab",
144
- ),
145
- ],
146
- active_tab="interactive-scatter-tab",
180
+ [plot_data.plot_data_to_tab() for plot_data in plot_data_list],
181
+ active_tab=f"{plot_data_list[0].function.tag}-interactive-scatter-tab",
147
182
  ),
148
183
  ],
149
184
  width=5,
@@ -175,25 +210,37 @@ def launch(pep_builder: PEPBuilder, context: PEPContext):
175
210
  [
176
211
  html.H2("PEPFlow"),
177
212
  display_row,
178
- # Store the entire DataFrame as a dictionary in dcc.Store
179
- dcc.Store(id="dataframe-store", data=df.to_dict("records")),
213
+ # For each function, store the corresponding DataFrame as a dictionary in dcc.Store
214
+ *[
215
+ dcc.Store(
216
+ id={
217
+ "type": "dataframe-store",
218
+ "index": plot_data.function.tag,
219
+ },
220
+ data=(
221
+ plot_data.function.tag,
222
+ plot_data.dataframe.to_dict("records"),
223
+ ),
224
+ )
225
+ for plot_data in plot_data_list
226
+ ],
180
227
  ]
181
228
  )
182
229
 
183
230
  @dash.callback(
184
231
  Output("result-card", "children"),
185
- Output("dual-value-display", "children"),
186
- Output("interactive-scatter", "figure"),
187
- Output("dataframe-store", "data"),
232
+ Output({"type": "dual-value-display", "index": ALL}, "children"),
233
+ Output({"type": "interactive-scatter", "index": ALL}, "figure"),
234
+ Output({"type": "dataframe-store", "index": ALL}, "data"),
188
235
  Input("solve-button", "n_clicks"),
189
236
  )
190
237
  def solve(_):
191
- fig, df, result = solve_prob_and_get_figure(pep_builder, context)
238
+ plot_data_list, result = solve_prob_and_get_figure(pep_builder, context)
192
239
  with np.printoptions(precision=3, linewidth=100, suppress=True):
193
240
  psd_dual_value = np.array(
194
241
  result.dual_var_manager.dual_value(PSD_CONSTRAINT)
195
242
  )
196
- reslt_card = dbc.CardBody(
243
+ result_card = dbc.CardBody(
197
244
  [
198
245
  html.H2(f"Optimal Value {result.primal_opt_value:.4g}"),
199
246
  html.H3(f"Solver Status: {result.solver_status}"),
@@ -203,50 +250,89 @@ def launch(pep_builder: PEPBuilder, context: PEPContext):
203
250
  html.Pre(json.dumps(pep_builder.relaxed_constraints, indent=2)),
204
251
  ]
205
252
  )
206
- dual_value_display = str(get_matrix_of_dual_value(df))
207
- return reslt_card, dual_value_display, fig, df.to_dict("records")
253
+ dual_value_displays = [
254
+ str(get_matrix_of_dual_value(plot_data.dataframe))
255
+ for plot_data in plot_data_list
256
+ ]
257
+ figs = [plot_data.figure for plot_data in plot_data_list]
258
+
259
+ df_data = [
260
+ (plot_data.function.tag, plot_data.dataframe.to_dict("records"))
261
+ for plot_data in plot_data_list
262
+ ]
263
+
264
+ return result_card, dual_value_displays, figs, df_data
208
265
 
209
266
  @dash.callback(
210
- Output("interactive-scatter", "figure", allow_duplicate=True),
211
- Output("dataframe-store", "data", allow_duplicate=True),
267
+ Output(
268
+ {"type": "interactive-scatter", "index": ALL},
269
+ "figure",
270
+ allow_duplicate=True,
271
+ ),
272
+ Output({"type": "dataframe-store", "index": ALL}, "data", allow_duplicate=True),
212
273
  Input("restore-all-constraints-button", "n_clicks"),
213
- State("dataframe-store", "data"),
274
+ State({"type": "dataframe-store", "index": ALL}, "data"),
214
275
  prevent_initial_call=True,
215
276
  )
216
- def restore_all_constraints(_, previous_df):
277
+ def restore_all_constraints(_, list_previous_df_tuples):
217
278
  nonlocal pep_builder
218
- df_updated = pd.DataFrame(previous_df)
219
279
  pep_builder.relaxed_constraints = []
220
- df_updated["constraint"] = "active"
221
- order = context.order_of_point(pep_builder.functions[0])
222
- return processed_df_to_fig(df_updated, order), df_updated.to_dict("records")
280
+ updated_figs = []
281
+ df_data = []
282
+ for previous_df_tuple in list_previous_df_tuples:
283
+ tag, previous_df = previous_df_tuple
284
+ df_updated = pd.DataFrame(previous_df)
285
+ df_updated["constraint"] = "active"
286
+ order = context.order_of_point(pep_builder.get_func_by_tag(tag))
287
+ updated_figs.append(processed_df_to_fig(df_updated, order))
288
+ df_data.append((tag, df_updated.to_dict("records")))
289
+ return updated_figs, df_data
223
290
 
224
291
  @dash.callback(
225
- Output("interactive-scatter", "figure", allow_duplicate=True),
226
- Output("dataframe-store", "data", allow_duplicate=True),
292
+ Output(
293
+ {"type": "interactive-scatter", "index": ALL},
294
+ "figure",
295
+ allow_duplicate=True,
296
+ ),
297
+ Output({"type": "dataframe-store", "index": ALL}, "data", allow_duplicate=True),
227
298
  Input("relax-all-constraints-button", "n_clicks"),
228
- State("dataframe-store", "data"),
299
+ State({"type": "dataframe-store", "index": ALL}, "data"),
229
300
  prevent_initial_call=True,
230
301
  )
231
- def relax_all_constraints(_, previous_df):
302
+ def relax_all_constraints(_, list_previous_df_tuples):
232
303
  nonlocal pep_builder
233
- df_updated = pd.DataFrame(previous_df)
234
- pep_builder.relaxed_constraints = df_updated["constraint_name"].to_list()
235
- df_updated["constraint"] = "inactive"
236
- order = context.order_of_point(pep_builder.functions[0])
237
- return processed_df_to_fig(df_updated, order), df_updated.to_dict("records")
304
+ pep_builder.relaxed_constraints = []
305
+ updated_figs = []
306
+ df_data = []
307
+ for previous_df_tuple in list_previous_df_tuples:
308
+ tag, previous_df = previous_df_tuple
309
+ df_updated = pd.DataFrame(previous_df)
310
+ pep_builder.relaxed_constraints.extend(
311
+ df_updated["constraint_name"].to_list()
312
+ )
313
+ df_updated["constraint"] = "inactive"
314
+ order = context.order_of_point(pep_builder.get_func_by_tag(tag))
315
+ updated_figs.append(processed_df_to_fig(df_updated, order))
316
+ df_data.append((tag, df_updated.to_dict("records")))
317
+ return updated_figs, df_data
238
318
 
239
319
  @dash.callback(
240
- Output("interactive-scatter", "figure", allow_duplicate=True),
241
- Output("dataframe-store", "data", allow_duplicate=True),
242
- Input("interactive-scatter", "clickData"),
243
- State("dataframe-store", "data"),
320
+ Output(
321
+ {"type": "interactive-scatter", "index": MATCH},
322
+ "figure",
323
+ allow_duplicate=True,
324
+ ),
325
+ Output(
326
+ {"type": "dataframe-store", "index": MATCH}, "data", allow_duplicate=True
327
+ ),
328
+ Input({"type": "interactive-scatter", "index": MATCH}, "clickData"),
329
+ State({"type": "dataframe-store", "index": MATCH}, "data"),
244
330
  prevent_initial_call=True,
245
331
  )
246
- def update_df_and_redraw(clickData, previous_df):
332
+ def update_df_and_redraw(clickData, previous_df_tuple):
247
333
  nonlocal pep_builder
248
334
  if not clickData["points"][0]["customdata"]:
249
- return dash.no_update, dash.no_update, dash.no_update
335
+ return dash.no_update, dash.no_update
250
336
 
251
337
  clicked_name = clickData["points"][0]["customdata"][0]
252
338
  if clicked_name not in pep_builder.relaxed_constraints:
@@ -254,11 +340,15 @@ def launch(pep_builder: PEPBuilder, context: PEPContext):
254
340
  else:
255
341
  pep_builder.relaxed_constraints.remove(clicked_name)
256
342
 
343
+ tag, previous_df = previous_df_tuple
257
344
  df_updated = pd.DataFrame(previous_df)
258
345
  df_updated["constraint"] = df_updated.constraint_name.map(
259
346
  lambda x: "inactive" if x in pep_builder.relaxed_constraints else "active"
260
347
  )
261
- order = context.order_of_point(pep_builder.functions[0])
262
- return processed_df_to_fig(df_updated, order), df_updated.to_dict("records")
348
+ order = context.order_of_point(pep_builder.get_func_by_tag(tag))
349
+ return processed_df_to_fig(df_updated, order), (
350
+ tag,
351
+ df_updated.to_dict("records"),
352
+ )
263
353
 
264
354
  app.run(debug=True)