egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,303 @@
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import altair as alt
5
+ from scipy.stats import beta, gaussian_kde
6
+
7
# Register custom font for PNG export
alt.themes.register('custom_theme', lambda: {
    'config': {
        'title': {'font': 'Produkt'},
        'axis': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
        'legend': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
        'mark': {'font': 'Produkt'},
        'text': {'font': 'Produkt'},
    }
})
alt.themes.enable('custom_theme')

# Render charts inline when running in a Colab notebook.
alt.renderers.enable("colab")

# Root directory that holds the per-model evaluation logs.
BASE_DIR = "logs"
# Episodes whose max_reward exceeds this cutoff count as successes.
REWARD_THRESHOLD = 0.03
SHOW_LABELS = 0  # Set to True to show percentage labels above points

# Display name -> log subfolder under BASE_DIR.
MODELS = {
    "CAP + Gemini-ER": "rum/gemini",
    "π-0.5": "pi0",
    "CAP + Oracle": "rum/oracle",
    "CAP + Molmo": "rum/molmo",
    "CAP + Moondream": "rum/moondream",
}

# Object counts evaluated per model (1 object == 0 distractors).
NUM_OBJECTS = [1, 2, 3, 4, 5]
34
+
35
+
36
def compute_success_from_csv(csv_path, threshold=None):
    """Count successful episodes in an evaluation log.

    Args:
        csv_path: Path to a log file with a ``max_reward`` column.
            NOTE: the file is named ``.csv`` but is tab-separated.
        threshold: Reward cutoff above which an episode counts as a
            success. Defaults to the module-level ``REWARD_THRESHOLD``
            when None, preserving the original behavior.

    Returns:
        Tuple ``(successes, total)``: the number of episodes above the
        threshold and the total number of episodes in the log.
    """
    if threshold is None:
        threshold = REWARD_THRESHOLD
    df = pd.read_csv(csv_path, sep="\t")
    successes = (df["max_reward"] > threshold).sum()
    total = len(df)
    return successes, total
41
+
42
+
43
def plot_bayesian_success_by_objects():
    """Plot relative pick success rate vs. number of distractor objects.

    For every model in MODELS and object count in NUM_OBJECTS, this
    locates the evaluation log under BASE_DIR, fits a Beta(1+s, 1+f)
    posterior to the success counts, normalizes each model's curve by
    its own 0-distractor baseline, and renders an Altair chart of mean
    lines with KDE "violin" posterior shapes at each point.

    Returns:
        The configured layered Altair chart, ready to be saved.
    """
    rows = []

    for model_name, model_folder in MODELS.items():
        for n_obj in NUM_OBJECTS:
            # Log folders are named inconsistently across runs; try all variants.
            possible_folders = [
                f"{n_obj}_objects",
                f"{n_obj}_object",
                f"{n_obj}-objects",
                f"{n_obj}-object",
            ]

            csv_path = None
            for folder in possible_folders:
                # First try the direct path.
                candidate = os.path.join(
                    BASE_DIR, model_folder, folder, "log.csv"
                )
                if os.path.exists(candidate):
                    csv_path = candidate
                    break

                # Fall back to a nested evaluation-run folder structure.
                folder_path = os.path.join(BASE_DIR, model_folder, folder)
                if os.path.exists(folder_path) and os.path.isdir(folder_path):
                    for subdir in os.listdir(folder_path):
                        subdir_path = os.path.join(folder_path, subdir)
                        if os.path.isdir(subdir_path):
                            candidate = os.path.join(subdir_path, "log.csv")
                            if os.path.exists(candidate):
                                csv_path = candidate
                                break
                if csv_path:
                    break

            if csv_path is None:
                print(f"Missing data: {model_name}, {n_obj} objects")
                continue

            s, t = compute_success_from_csv(csv_path)

            # Beta posterior under a uniform Beta(1, 1) prior.
            a, b = 1 + s, 1 + (t - s)
            mean = a / (a + b)
            lo = beta.ppf(0.025, a, b)
            hi = beta.ppf(0.975, a, b)

            print(f"{model_name} | {n_obj} objects: {s}/{t} = {mean:.2%}")

            rows.append({
                "model": model_name,
                "num_objects": n_obj,
                "num_distractors": n_obj - 1,
                "mean": mean,
                "lo": lo,
                "hi": hi,
                "err": hi - mean,
                "alpha": a,
                "beta": b,
            })

    df = pd.DataFrame(rows)

    # Normalize each model's success rates relative to 0 distractors (1 object)
    normalized_rows = []
    for model_name in df["model"].unique():
        model_df = df[df["model"] == model_name].copy()
        # Get the baseline (0 distractors = 1 object)
        baseline_row = model_df[model_df["num_distractors"] == 0]
        if len(baseline_row) == 0:
            print(f"Warning: No 0 distractor data for {model_name}")
            continue
        baseline_mean = baseline_row.iloc[0]["mean"]
        if baseline_mean == 0:
            print(f"Warning: Baseline is 0 for {model_name}, skipping normalization")
            continue

        # Store baseline for violin normalization
        model_df["baseline_mean"] = baseline_mean

        # Normalize all metrics
        model_df["mean"] = model_df["mean"] / baseline_mean
        model_df["lo"] = model_df["lo"] / baseline_mean
        model_df["hi"] = model_df["hi"] / baseline_mean
        model_df["err"] = model_df["hi"] - model_df["mean"]
        normalized_rows.append(model_df)

    df = pd.concat(normalized_rows, ignore_index=True)

    # Calculate average success rate per model and sort by descending order
    model_means = df.groupby("model")["mean"].mean().sort_values(ascending=False)
    sorted_models = model_means.index.tolist()

    # Create color mapping maintaining original colors
    color_map = {
        "CAP + Gemini-ER": "#6095e1",
        "π-0.5": "#E0BE16",
        "CAP + Oracle": "#A918C0",
        "CAP + Molmo": "#F0529c",
        "CAP + Moondream": "#000000",
        "CAP + 4o": "#FF6B35"
    }

    # Ensure all sorted models have colors
    sorted_colors = []
    for model in sorted_models:
        sorted_colors.append(color_map[model])

    print(f"Sorted models: {sorted_models}")
    print(f"Sorted colors: {sorted_colors}")

    base = alt.Chart(df).encode(
        x=alt.X(
            "num_distractors:Q",
            title="Number of Distractor Objects",
            axis=alt.Axis(labelFontSize=20, titleFontSize=24, labelAngle=0, titlePadding=15, tickMinStep=1),
            scale=alt.Scale(domain=[0, 4.5])
        ),
        color=alt.Color(
            "model:N",
            scale=alt.Scale(
                domain=sorted_models,
                range=sorted_colors
            ),
            legend=alt.Legend(
                title=None,
                labelFontSize=18,
                symbolSize=200,
                orient="bottom",
                direction="horizontal",
                symbolType="square",
                labelLimit=0
            )
        )
    )

    lines = base.mark_line(
        strokeWidth=3,
        point=alt.OverlayMarkDef(size=100)
    ).encode(
        y=alt.Y(
            "mean:Q",
            title="Relative Success Rate (%)",
            scale=alt.Scale(domain=[0.6, 1.0]),
            axis=alt.Axis(
                values=[0.6, 0.7, 0.8, 0.9, 1.0],
                labelExpr="datum.value * 100",
                labelFontSize=20,
                titleFontSize=24,
                titlePadding=14,
                grid=False
            )
        )
    )

    # Create violin plots only for non-zero distractors
    df_with_violins = df[df["num_distractors"] != 0]

    # Set random seed for reproducibility
    np.random.seed(42)

    # Create violin polygons
    violin_width_scale = 0.08  # Width of violins
    violin_polygon_data = []

    for _, row in df_with_violins.iterrows():
        # Generate samples from beta distribution (these are in 0-1 range)
        samples = beta.rvs(row["alpha"], row["beta"], size=2000)
        # Normalize samples using the stored baseline_mean
        baseline_mean = row["baseline_mean"]
        if baseline_mean > 0:
            samples = samples / baseline_mean

        # Compute KDE
        kde = gaussian_kde(samples)

        # Create density curve points
        y_points = np.linspace(samples.min(), samples.max(), 100)
        densities = kde(y_points)
        densities = densities / densities.max() * violin_width_scale

        xc = row["num_distractors"]
        model = row["model"]
        color = sorted_colors[sorted_models.index(model)]

        # Create closed polygon: left side up, then right side down
        for i, (y, d) in enumerate(zip(y_points, densities)):
            violin_polygon_data.append({
                "model": model,
                "x": xc - d,
                "y": y,
                "order": i,
                "color": color,
                "group": f"{model}_{xc}"
            })
        # Right side going back down
        for i, (y, d) in enumerate(zip(y_points[::-1], densities[::-1])):
            violin_polygon_data.append({
                "model": model,
                "x": xc + d,
                "y": y,
                "order": len(y_points) + i,
                "color": color,
                "group": f"{model}_{xc}"
            })

    violin_polygon_df = pd.DataFrame(violin_polygon_data)

    # Create violin shapes - now uses same quantitative scale as base
    violins = alt.Chart(violin_polygon_df).mark_line(
        fillOpacity=0.4,
        strokeWidth=0.5,
        interpolate='linear',
        filled=True
    ).encode(
        x=alt.X('x:Q', scale=alt.Scale(domain=[0, 4.5]), title=''),
        y=alt.Y('y:Q', scale=alt.Scale(domain=[0.6, 1.0]), title=''),
        order='order:Q',
        detail='group:N',
        fill=alt.Fill('color:N', scale=None, legend=None),
        stroke=alt.Stroke('color:N', scale=None, legend=None)
    )

    labels = alt.Chart(df).mark_text(
        dy=-8,
        fontSize=14,
        color="black"
    ).encode(
        # BUGFIX: was "num_distractors:O" (ordinal). The other layers use a
        # quantitative x-scale; layering incompatible scale types on a shared
        # channel fails in Vega-Lite when SHOW_LABELS is enabled.
        x=alt.X("num_distractors:Q"),
        y="hi:Q",
        text=alt.Text("mean:Q", format=".0%")
    )

    # Combine layers based on SHOW_LABELS setting
    if SHOW_LABELS:
        chart = violins + lines + labels
    else:
        chart = violins + lines

    return (
        chart
        .properties(
            width=720,
            height=420,
            title={
                "text": " EgoGym Pick Relative Success Rate vs Number of Distractors",
                "fontSize": 22,
                "anchor": "start",
                "dx": 65,
                "offset": 20
            },
            padding={"left": 10, "right": 10, "top": 40, "bottom": 40}
        )
        .configure_view(stroke=None)
    )
299
+
300
+
301
if __name__ == "__main__":
    # Entry-point guard: build the chart once and export raster + vector copies.
    chart = plot_bayesian_success_by_objects()
    chart.save("success_by_objects.png", scale_factor=3)
    chart.save("success_by_objects.pdf", scale_factor=3)
@@ -0,0 +1,285 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import altair as alt
4
+ from scipy.stats import beta, gaussian_kde
5
+
6
# Register custom font for PNG export
alt.themes.register('custom_theme', lambda: {
    'config': {
        'title': {'font': 'Produkt'},
        'axis': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
        'legend': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
        'mark': {'font': 'Produkt'},
        'text': {'font': 'Produkt'},
    }
})
alt.themes.enable('custom_theme')

# Render charts inline when running in a Colab notebook.
alt.renderers.enable("colab")

SHOW_LABELS = 0  # Set to True to show percentage labels above points

# Hardcoded data: {model: {num_objects: (successes, total)}}
# Snapshot of evaluation results so the plot can be regenerated without logs.
HARDCODED_DATA = {
    "CAP + Gemini-ER": {
        1: (3706, 5000),
        2: (3576, 5000),
        3: (3548, 5000),
        4: (3482, 5000),
        5: (3441, 5000),
    },
    "π-0.5": {
        1: (1478, 5000),
        2: (1366, 5000),
        3: (1221, 5000),
        4: (1058, 5000),
        5: (1046, 5000),
    },
    "CAP + Oracle": {
        1: (4034, 5000),
        2: (3962, 5000),
        3: (4000, 5000),
        4: (3967, 5000),
        5: (3994, 5000),
    },
    "CAP + Molmo": {
        1: (3672, 5000),
        2: (3350, 5000),
        3: (3213, 5000),
        4: (3069, 5000),
        5: (3008, 5000),
    },
    "CAP + Moondream": {
        1: (3407, 5000),
        2: (3254, 5000),
        3: (3184, 5000),
        4: (3044, 5000),
        5: (3033, 5000),
    },
}
60
+
61
+
62
def plot_bayesian_success_by_objects():
    """Plot relative pick success rate vs. number of distractor objects.

    Uses the hardcoded (successes, total) counts in HARDCODED_DATA: fits
    a Beta(1+s, 1+f) posterior per model/object count, normalizes each
    model's curve by its own 0-distractor baseline, and renders an
    Altair chart of mean lines with KDE "violin" posterior shapes.

    Returns:
        The configured layered Altair chart, ready to be saved.
    """
    rows = []

    for model_name, data_dict in HARDCODED_DATA.items():
        for n_obj, (s, t) in data_dict.items():
            # Beta posterior under a uniform Beta(1, 1) prior.
            a, b = 1 + s, 1 + (t - s)
            mean = a / (a + b)
            lo = beta.ppf(0.025, a, b)
            hi = beta.ppf(0.975, a, b)

            print(f"{model_name} | {n_obj} objects: {s}/{t} = {mean:.2%}")

            rows.append({
                "model": model_name,
                "num_objects": n_obj,
                "num_distractors": n_obj - 1,
                "mean": mean,
                "lo": lo,
                "hi": hi,
                "err": hi - mean,
                "alpha": a,
                "beta": b,
            })

    df = pd.DataFrame(rows)

    # Normalize each model's success rates relative to 0 distractors (1 object)
    normalized_rows = []
    for model_name in df["model"].unique():
        model_df = df[df["model"] == model_name].copy()
        # Get the baseline (0 distractors = 1 object)
        baseline_row = model_df[model_df["num_distractors"] == 0]
        if len(baseline_row) == 0:
            print(f"Warning: No 0 distractor data for {model_name}")
            continue
        baseline_mean = baseline_row.iloc[0]["mean"]
        if baseline_mean == 0:
            print(f"Warning: Baseline is 0 for {model_name}, skipping normalization")
            continue

        # Store baseline for violin normalization
        model_df["baseline_mean"] = baseline_mean

        # Normalize all metrics
        model_df["mean"] = model_df["mean"] / baseline_mean
        model_df["lo"] = model_df["lo"] / baseline_mean
        model_df["hi"] = model_df["hi"] / baseline_mean
        model_df["err"] = model_df["hi"] - model_df["mean"]
        normalized_rows.append(model_df)

    df = pd.concat(normalized_rows, ignore_index=True)

    # Calculate average success rate per model and sort by descending order
    model_means = df.groupby("model")["mean"].mean().sort_values(ascending=False)
    sorted_models = model_means.index.tolist()

    # Create color mapping maintaining original colors
    color_map = {
        "CAP + Gemini-ER": "#6095e1",
        "π-0.5": "#E0BE16",
        "CAP + Oracle": "#A918C0",
        "CAP + Molmo": "#F0529c",
        "CAP + Moondream": "#000000",
        "CAP + 4o": "#FF6B35"
    }

    # Ensure all sorted models have colors
    sorted_colors = []
    for model in sorted_models:
        sorted_colors.append(color_map[model])

    print(f"Sorted models: {sorted_models}")
    print(f"Sorted colors: {sorted_colors}")

    base = alt.Chart(df).encode(
        x=alt.X(
            "num_distractors:Q",
            title="Number of Distractor Objects",
            axis=alt.Axis(labelFontSize=20, titleFontSize=24, labelAngle=0, titlePadding=15, tickMinStep=1),
            scale=alt.Scale(domain=[0, 4.5])
        ),
        color=alt.Color(
            "model:N",
            scale=alt.Scale(
                domain=sorted_models,
                range=sorted_colors
            ),
            legend=alt.Legend(
                title=None,
                labelFontSize=18,
                symbolSize=200,
                orient="bottom",
                direction="horizontal",
                symbolType="square",
                labelLimit=0
            )
        )
    )

    lines = base.mark_line(
        strokeWidth=3,
        point=alt.OverlayMarkDef(size=100)
    ).encode(
        y=alt.Y(
            "mean:Q",
            title="Relative Success Rate (%)",
            scale=alt.Scale(domain=[0.6, 1.0]),
            axis=alt.Axis(
                format=".0%",
                values=[0.6, 0.7, 0.8, 0.9, 1.0],
                labelFontSize=20,
                titleFontSize=24,
                titlePadding=14,
                grid=False
            )
        )
    )

    # Create violin plots only for non-zero distractors
    df_with_violins = df[df["num_distractors"] != 0]

    # Set random seed for reproducibility
    np.random.seed(42)

    # Create violin polygons
    violin_width_scale = 0.08  # Width of violins
    violin_polygon_data = []

    for _, row in df_with_violins.iterrows():
        # Generate samples from beta distribution (these are in 0-1 range)
        samples = beta.rvs(row["alpha"], row["beta"], size=2000)
        # Normalize samples using the stored baseline_mean
        baseline_mean = row["baseline_mean"]
        if baseline_mean > 0:
            samples = samples / baseline_mean

        # Compute KDE
        kde = gaussian_kde(samples)

        # Create density curve points
        y_points = np.linspace(samples.min(), samples.max(), 100)
        densities = kde(y_points)
        densities = densities / densities.max() * violin_width_scale

        xc = row["num_distractors"]
        model = row["model"]
        color = sorted_colors[sorted_models.index(model)]

        # Create closed polygon: left side up, then right side down
        for i, (y, d) in enumerate(zip(y_points, densities)):
            violin_polygon_data.append({
                "model": model,
                "x": xc - d,
                "y": y,
                "order": i,
                "color": color,
                "group": f"{model}_{xc}"
            })
        # Right side going back down
        for i, (y, d) in enumerate(zip(y_points[::-1], densities[::-1])):
            violin_polygon_data.append({
                "model": model,
                "x": xc + d,
                "y": y,
                "order": len(y_points) + i,
                "color": color,
                "group": f"{model}_{xc}"
            })

    violin_polygon_df = pd.DataFrame(violin_polygon_data)

    # Create violin shapes - now uses same quantitative scale as base
    violins = alt.Chart(violin_polygon_df).mark_line(
        fillOpacity=0.4,
        strokeWidth=0.5,
        interpolate='linear',
        filled=True
    ).encode(
        x=alt.X('x:Q', scale=alt.Scale(domain=[0, 4.5]), title=''),
        y=alt.Y('y:Q', scale=alt.Scale(domain=[0.6, 1.0]), title=''),
        order='order:Q',
        detail='group:N',
        fill=alt.Fill('color:N', scale=None, legend=None),
        stroke=alt.Stroke('color:N', scale=None, legend=None)
    )

    labels = alt.Chart(df).mark_text(
        dy=-8,
        fontSize=14,
        color="black"
    ).encode(
        # BUGFIX: was "num_distractors:O" (ordinal). The other layers use a
        # quantitative x-scale; layering incompatible scale types on a shared
        # channel fails in Vega-Lite when SHOW_LABELS is enabled.
        x=alt.X("num_distractors:Q"),
        y="hi:Q",
        text=alt.Text("mean:Q", format=".0%")
    )

    # Combine layers based on SHOW_LABELS setting
    if SHOW_LABELS:
        chart = violins + lines + labels
    else:
        chart = violins + lines

    return (
        chart
        .properties(
            width=720,
            height=420,
            title={
                "text": " EgoGym Pick Relative Success Rate vs Number of Distractors",
                "fontSize": 22,
                "anchor": "start",
                "dx": 40
            },
            padding={"left": 10, "right": 10, "top": 40, "bottom": 40}
        )
        .configure_view(stroke=None)
    )
281
+
282
+
283
if __name__ == "__main__":
    # Entry-point guard: build the chart once and export raster + vector copies.
    chart = plot_bayesian_success_by_objects()
    chart.save("success_by_objects.png", scale_factor=3)
    chart.save("success_by_objects.pdf", scale_factor=3)