egogym 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
import os

import numpy as np
import pandas as pd
import altair as alt
from scipy.stats import beta, gaussian_kde

# Register a custom font theme so PNG/PDF exports use the house font.
_FONT = 'Produkt'
_THEME_CONFIG = {
    'config': {
        'title': {'font': _FONT},
        'axis': {'labelFont': _FONT, 'titleFont': _FONT},
        'legend': {'labelFont': _FONT, 'titleFont': _FONT},
        'mark': {'font': _FONT},
        'text': {'font': _FONT},
    }
}
alt.themes.register('custom_theme', lambda: _THEME_CONFIG)
alt.themes.enable('custom_theme')

alt.renderers.enable("colab")

# Root directory containing per-model evaluation logs.
BASE_DIR = "logs"
# An episode counts as a success when its max_reward exceeds this value.
REWARD_THRESHOLD = 0.03
# Set to True to show percentage labels above points.
SHOW_LABELS = 0

# Display name -> log sub-folder under BASE_DIR.
MODELS = {
    "CAP + Gemini-ER": "rum/gemini",
    "π-0.5": "pi0",
    "CAP + Oracle": "rum/oracle",
    "CAP + Molmo": "rum/molmo",
    "CAP + Moondream": "rum/moondream",
}

# Scene sizes evaluated; n objects corresponds to n - 1 distractors.
NUM_OBJECTS = [1, 2, 3, 4, 5]
+
|
|
36
|
+
def compute_success_from_csv(csv_path):
    """Count successful episodes recorded in a tab-separated log file.

    An episode is a success when its ``max_reward`` column value exceeds
    the module-level ``REWARD_THRESHOLD``.

    Parameters
    ----------
    csv_path : str
        Path to a tab-separated log file containing a ``max_reward`` column.

    Returns
    -------
    tuple
        ``(successes, total)``: the number of successful episodes and the
        total number of episodes (rows) in the file.
    """
    log = pd.read_csv(csv_path, sep="\t")
    hit_mask = log["max_reward"] > REWARD_THRESHOLD
    return hit_mask.sum(), len(log)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def plot_bayesian_success_by_objects():
    """Build a layered Altair chart of relative success rate vs. distractor count.

    For every model in ``MODELS`` and every scene size in ``NUM_OBJECTS``,
    locates a ``log.csv`` under ``BASE_DIR``, fits a Beta(1+s, 1+t-s)
    posterior to the success counts, normalizes each model's curve by its
    own 0-distractor mean, and overlays per-point posterior "violin"
    densities on the mean lines.

    Returns:
        An Altair chart object (violins + lines, plus text labels when
        ``SHOW_LABELS`` is truthy) with title/size/padding applied.
    """
    rows = []

    for model_name, model_folder in MODELS.items():
        for n_obj in NUM_OBJECTS:
            # Folder naming on disk is inconsistent; probe the known variants.
            possible_folders = [
                f"{n_obj}_objects",
                f"{n_obj}_object",
                f"{n_obj}-objects",
                f"{n_obj}-object",
            ]

            csv_path = None
            for folder in possible_folders:
                # First try direct path: BASE_DIR/<model>/<folder>/log.csv
                candidate = os.path.join(
                    BASE_DIR, model_folder, folder, "log.csv"
                )
                if os.path.exists(candidate):
                    csv_path = candidate
                    break

                # Try nested evaluation folder structure:
                # BASE_DIR/<model>/<folder>/<run subdir>/log.csv
                # NOTE(review): takes the first subdir (os.listdir order is
                # arbitrary) that contains a log.csv — confirm runs are unique.
                folder_path = os.path.join(BASE_DIR, model_folder, folder)
                if os.path.exists(folder_path) and os.path.isdir(folder_path):
                    for subdir in os.listdir(folder_path):
                        subdir_path = os.path.join(folder_path, subdir)
                        if os.path.isdir(subdir_path):
                            candidate = os.path.join(subdir_path, "log.csv")
                            if os.path.exists(candidate):
                                csv_path = candidate
                                break
                    if csv_path:
                        break

            if csv_path is None:
                # Missing combinations are skipped, not fatal.
                print(f"Missing data: {model_name}, {n_obj} objects")
                continue

            s, t = compute_success_from_csv(csv_path)

            # Beta posterior with a uniform Beta(1, 1) prior over the
            # success probability; mean and central 95% credible interval.
            a, b = 1 + s, 1 + (t - s)
            mean = a / (a + b)
            lo = beta.ppf(0.025, a, b)
            hi = beta.ppf(0.975, a, b)

            print(f"{model_name} | {n_obj} objects: {s}/{t} = {mean:.2%}")

            rows.append({
                "model": model_name,
                "num_objects": n_obj,
                "num_distractors": n_obj - 1,
                "mean": mean,
                "lo": lo,
                "hi": hi,
                "err": hi - mean,
                "alpha": a,
                "beta": b,
            })

    df = pd.DataFrame(rows)

    # Normalize each model's success rates relative to 0 distractors (1 object)
    # so curves show degradation relative to the model's own baseline.
    normalized_rows = []
    for model_name in df["model"].unique():
        model_df = df[df["model"] == model_name].copy()
        # Get the baseline (0 distractors = 1 object)
        baseline_row = model_df[model_df["num_distractors"] == 0]
        if len(baseline_row) == 0:
            # Without a baseline the model is dropped from the plot entirely.
            print(f"Warning: No 0 distractor data for {model_name}")
            continue
        baseline_mean = baseline_row.iloc[0]["mean"]
        if baseline_mean == 0:
            print(f"Warning: Baseline is 0 for {model_name}, skipping normalization")
            continue

        # Store baseline for violin normalization (reused when scaling
        # posterior samples below).
        model_df["baseline_mean"] = baseline_mean

        # Normalize all metrics; the baseline row itself maps to mean == 1.0.
        model_df["mean"] = model_df["mean"] / baseline_mean
        model_df["lo"] = model_df["lo"] / baseline_mean
        model_df["hi"] = model_df["hi"] / baseline_mean
        model_df["err"] = model_df["hi"] - model_df["mean"]
        normalized_rows.append(model_df)

    df = pd.concat(normalized_rows, ignore_index=True)

    # Calculate average success rate per model and sort by descending order
    # (controls legend/color order).
    model_means = df.groupby("model")["mean"].mean().sort_values(ascending=False)
    sorted_models = model_means.index.tolist()

    # Create color mapping maintaining original colors
    color_map = {
        "CAP + Gemini-ER": "#6095e1",
        "π-0.5": "#E0BE16",
        "CAP + Oracle": "#A918C0",
        "CAP + Molmo": "#F0529c",
        "CAP + Moondream": "#000000",
        "CAP + 4o": "#FF6B35"
    }

    # Ensure all sorted models have colors
    # NOTE(review): color_map[model] raises KeyError for any model name
    # missing from color_map — confirm MODELS stays a subset of color_map.
    sorted_colors = []
    for model in sorted_models:
        sorted_colors.append(color_map[model])

    print(f"Sorted models: {sorted_models}")
    print(f"Sorted colors: {sorted_colors}")

    # Shared x/color encodings for the line layer; x is quantitative so the
    # violin layer (also quantitative) can share the same scale.
    base = alt.Chart(df).encode(
        x=alt.X(
            "num_distractors:Q",
            title="Number of Distractor Objects",
            axis=alt.Axis(labelFontSize=20, titleFontSize=24, labelAngle=0, titlePadding=15, tickMinStep=1),
            scale=alt.Scale(domain=[0, 4.5])
        ),
        color=alt.Color(
            "model:N",
            scale=alt.Scale(
                domain=sorted_models,
                range=sorted_colors
            ),
            legend=alt.Legend(
                title=None,
                labelFontSize=18,
                symbolSize=200,
                orient="bottom",
                direction="horizontal",
                symbolType="square",
                labelLimit=0
            )
        )
    )

    # Mean relative success per model, with point markers.
    lines = base.mark_line(
        strokeWidth=3,
        point=alt.OverlayMarkDef(size=100)
    ).encode(
        y=alt.Y(
            "mean:Q",
            title="Relative Success Rate (%)",
            scale=alt.Scale(domain=[0.6, 1.0]),
            axis=alt.Axis(
                values=[0.6, 0.7, 0.8, 0.9, 1.0],
                # Renders 0.6 -> "60" etc. (no % sign; the unit lives in the
                # axis title).
                labelExpr="datum.value * 100",
                labelFontSize=20,
                titleFontSize=24,
                titlePadding=14,
                grid=False
            )
        )
    )

    # Create violin plots only for non-zero distractors (the baseline point
    # is exactly 1.0 by construction, so a violin there is uninformative).
    df_with_violins = df[df["num_distractors"] != 0]

    # Set random seed for reproducibility of the posterior samples.
    np.random.seed(42)

    # Create violin polygons
    violin_width_scale = 0.08  # Width of violins (in x-axis units)
    violin_polygon_data = []

    for _, row in df_with_violins.iterrows():
        # Generate samples from beta distribution (these are in 0-1 range)
        samples = beta.rvs(row["alpha"], row["beta"], size=2000)
        # Normalize samples using the stored baseline_mean so they live on
        # the same relative scale as the line layer.
        baseline_mean = row["baseline_mean"]
        if baseline_mean > 0:
            samples = samples / baseline_mean

        # Compute KDE of the normalized posterior samples.
        kde = gaussian_kde(samples)

        # Create density curve points; width normalized so the widest point
        # of every violin equals violin_width_scale.
        y_points = np.linspace(samples.min(), samples.max(), 100)
        densities = kde(y_points)
        densities = densities / densities.max() * violin_width_scale

        xc = row["num_distractors"]
        model = row["model"]
        color = sorted_colors[sorted_models.index(model)]

        # Create closed polygon: left side up, then right side down.
        # "order" preserves vertex order; "group" isolates each violin.
        for i, (y, d) in enumerate(zip(y_points, densities)):
            violin_polygon_data.append({
                "model": model,
                "x": xc - d,
                "y": y,
                "order": i,
                "color": color,
                "group": f"{model}_{xc}"
            })
        # Right side going back down
        for i, (y, d) in enumerate(zip(y_points[::-1], densities[::-1])):
            violin_polygon_data.append({
                "model": model,
                "x": xc + d,
                "y": y,
                "order": len(y_points) + i,
                "color": color,
                "group": f"{model}_{xc}"
            })

    violin_polygon_df = pd.DataFrame(violin_polygon_data)

    # Create violin shapes - now uses same quantitative scale as base.
    # A filled mark_line traces each polygon; fill/stroke take literal hex
    # colors (scale=None) so they match the line layer's palette.
    violins = alt.Chart(violin_polygon_df).mark_line(
        fillOpacity=0.4,
        strokeWidth=0.5,
        interpolate='linear',
        filled=True
    ).encode(
        x=alt.X('x:Q', scale=alt.Scale(domain=[0, 4.5]), title=''),
        y=alt.Y('y:Q', scale=alt.Scale(domain=[0.6, 1.0]), title=''),
        order='order:Q',
        detail='group:N',
        fill=alt.Fill('color:N', scale=None, legend=None),
        stroke=alt.Stroke('color:N', scale=None, legend=None)
    )

    # Optional percentage labels placed just above the upper credible bound.
    # NOTE(review): this layer encodes x as ordinal while the other layers
    # are quantitative — if SHOW_LABELS is enabled, confirm the layered
    # chart still resolves its x scale as expected.
    labels = alt.Chart(df).mark_text(
        dy=-8,
        fontSize=14,
        color="black"
    ).encode(
        x=alt.X("num_distractors:O"),
        y="hi:Q",
        text=alt.Text("mean:Q", format=".0%")
    )

    # Combine layers based on SHOW_LABELS setting
    if SHOW_LABELS:
        chart = violins + lines + labels
    else:
        chart = violins + lines

    return (
        chart
        .properties(
            width=720,
            height=420,
            title={
                "text": " EgoGym Pick Relative Success Rate vs Number of Distractors",
                "fontSize": 22,
                "anchor": "start",
                "dx": 65,
                "offset": 20
            },
            padding={"left": 10, "right": 10, "top": 40, "bottom": 40}
        )
        .configure_view(stroke=None)
    )
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
# Build the chart and export it in both raster and vector form.
chart = plot_bayesian_success_by_objects()
chart.save("success_by_objects.png", scale_factor=3)
chart.save("success_by_objects.pdf", scale_factor=3)
|
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
import numpy as np
import pandas as pd
import altair as alt
from scipy.stats import beta, gaussian_kde

# Register custom font for PNG export
alt.themes.register('custom_theme', lambda: {
    'config': {
        'title': {'font': 'Produkt'},
        'axis': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
        'legend': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
        'mark': {'font': 'Produkt'},
        'text': {'font': 'Produkt'},
    }
})
alt.themes.enable('custom_theme')

alt.renderers.enable("colab")

SHOW_LABELS = 0  # Set to True to show percentage labels above points

# Hardcoded data: {model: {num_objects: (successes, total)}}
# Snapshot of evaluation results so the plot can be regenerated without
# access to the raw log files; num_objects n means n - 1 distractors.
HARDCODED_DATA = {
    "CAP + Gemini-ER": {
        1: (3706, 5000),
        2: (3576, 5000),
        3: (3548, 5000),
        4: (3482, 5000),
        5: (3441, 5000),
    },
    "π-0.5": {
        1: (1478, 5000),
        2: (1366, 5000),
        3: (1221, 5000),
        4: (1058, 5000),
        5: (1046, 5000),
    },
    "CAP + Oracle": {
        1: (4034, 5000),
        2: (3962, 5000),
        3: (4000, 5000),
        4: (3967, 5000),
        5: (3994, 5000),
    },
    "CAP + Molmo": {
        1: (3672, 5000),
        2: (3350, 5000),
        3: (3213, 5000),
        4: (3069, 5000),
        5: (3008, 5000),
    },
    "CAP + Moondream": {
        1: (3407, 5000),
        2: (3254, 5000),
        3: (3184, 5000),
        4: (3044, 5000),
        5: (3033, 5000),
    },
}
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def plot_bayesian_success_by_objects():
    """Build a layered Altair chart of relative success rate vs. distractor count.

    Hardcoded-data variant: reads success counts from ``HARDCODED_DATA``
    instead of scanning log files. For each (model, scene size) it fits a
    Beta(1+s, 1+t-s) posterior, normalizes each model's curve by its own
    0-distractor mean, and overlays per-point posterior "violin" densities
    on the mean lines.

    Returns:
        An Altair chart object (violins + lines, plus text labels when
        ``SHOW_LABELS`` is truthy) with title/size/padding applied.
    """
    rows = []

    for model_name, data_dict in HARDCODED_DATA.items():
        for n_obj, (s, t) in data_dict.items():
            # Beta posterior with a uniform Beta(1, 1) prior over the
            # success probability; mean and central 95% credible interval.
            a, b = 1 + s, 1 + (t - s)
            mean = a / (a + b)
            lo = beta.ppf(0.025, a, b)
            hi = beta.ppf(0.975, a, b)

            print(f"{model_name} | {n_obj} objects: {s}/{t} = {mean:.2%}")

            rows.append({
                "model": model_name,
                "num_objects": n_obj,
                "num_distractors": n_obj - 1,
                "mean": mean,
                "lo": lo,
                "hi": hi,
                "err": hi - mean,
                "alpha": a,
                "beta": b,
            })

    df = pd.DataFrame(rows)

    # Normalize each model's success rates relative to 0 distractors (1 object)
    # so curves show degradation relative to the model's own baseline.
    normalized_rows = []
    for model_name in df["model"].unique():
        model_df = df[df["model"] == model_name].copy()
        # Get the baseline (0 distractors = 1 object)
        baseline_row = model_df[model_df["num_distractors"] == 0]
        if len(baseline_row) == 0:
            # Without a baseline the model is dropped from the plot entirely.
            print(f"Warning: No 0 distractor data for {model_name}")
            continue
        baseline_mean = baseline_row.iloc[0]["mean"]
        if baseline_mean == 0:
            print(f"Warning: Baseline is 0 for {model_name}, skipping normalization")
            continue

        # Store baseline for violin normalization (reused when scaling
        # posterior samples below).
        model_df["baseline_mean"] = baseline_mean

        # Normalize all metrics; the baseline row itself maps to mean == 1.0.
        model_df["mean"] = model_df["mean"] / baseline_mean
        model_df["lo"] = model_df["lo"] / baseline_mean
        model_df["hi"] = model_df["hi"] / baseline_mean
        model_df["err"] = model_df["hi"] - model_df["mean"]
        normalized_rows.append(model_df)

    df = pd.concat(normalized_rows, ignore_index=True)

    # Calculate average success rate per model and sort by descending order
    # (controls legend/color order).
    model_means = df.groupby("model")["mean"].mean().sort_values(ascending=False)
    sorted_models = model_means.index.tolist()

    # Create color mapping maintaining original colors
    color_map = {
        "CAP + Gemini-ER": "#6095e1",
        "π-0.5": "#E0BE16",
        "CAP + Oracle": "#A918C0",
        "CAP + Molmo": "#F0529c",
        "CAP + Moondream": "#000000",
        "CAP + 4o": "#FF6B35"
    }

    # Ensure all sorted models have colors
    # NOTE(review): color_map[model] raises KeyError for any model name
    # missing from color_map — confirm HARDCODED_DATA keys stay a subset.
    sorted_colors = []
    for model in sorted_models:
        sorted_colors.append(color_map[model])

    print(f"Sorted models: {sorted_models}")
    print(f"Sorted colors: {sorted_colors}")

    # Shared x/color encodings for the line layer; x is quantitative so the
    # violin layer (also quantitative) can share the same scale.
    base = alt.Chart(df).encode(
        x=alt.X(
            "num_distractors:Q",
            title="Number of Distractor Objects",
            axis=alt.Axis(labelFontSize=20, titleFontSize=24, labelAngle=0, titlePadding=15, tickMinStep=1),
            scale=alt.Scale(domain=[0, 4.5])
        ),
        color=alt.Color(
            "model:N",
            scale=alt.Scale(
                domain=sorted_models,
                range=sorted_colors
            ),
            legend=alt.Legend(
                title=None,
                labelFontSize=18,
                symbolSize=200,
                orient="bottom",
                direction="horizontal",
                symbolType="square",
                labelLimit=0
            )
        )
    )

    # Mean relative success per model, with point markers.
    lines = base.mark_line(
        strokeWidth=3,
        point=alt.OverlayMarkDef(size=100)
    ).encode(
        y=alt.Y(
            "mean:Q",
            title="Relative Success Rate (%)",
            scale=alt.Scale(domain=[0.6, 1.0]),
            axis=alt.Axis(
                # Labels rendered as percentages (0.6 -> "60%"); note the
                # file-based variant of this script uses labelExpr instead.
                format=".0%",
                values=[0.6, 0.7, 0.8, 0.9, 1.0],
                labelFontSize=20,
                titleFontSize=24,
                titlePadding=14,
                grid=False
            )
        )
    )

    # Create violin plots only for non-zero distractors (the baseline point
    # is exactly 1.0 by construction, so a violin there is uninformative).
    df_with_violins = df[df["num_distractors"] != 0]

    # Set random seed for reproducibility of the posterior samples.
    np.random.seed(42)

    # Create violin polygons
    violin_width_scale = 0.08  # Width of violins (in x-axis units)
    violin_polygon_data = []

    for _, row in df_with_violins.iterrows():
        # Generate samples from beta distribution (these are in 0-1 range)
        samples = beta.rvs(row["alpha"], row["beta"], size=2000)
        # Normalize samples using the stored baseline_mean so they live on
        # the same relative scale as the line layer.
        baseline_mean = row["baseline_mean"]
        if baseline_mean > 0:
            samples = samples / baseline_mean

        # Compute KDE of the normalized posterior samples.
        kde = gaussian_kde(samples)

        # Create density curve points; width normalized so the widest point
        # of every violin equals violin_width_scale.
        y_points = np.linspace(samples.min(), samples.max(), 100)
        densities = kde(y_points)
        densities = densities / densities.max() * violin_width_scale

        xc = row["num_distractors"]
        model = row["model"]
        color = sorted_colors[sorted_models.index(model)]

        # Create closed polygon: left side up, then right side down.
        # "order" preserves vertex order; "group" isolates each violin.
        for i, (y, d) in enumerate(zip(y_points, densities)):
            violin_polygon_data.append({
                "model": model,
                "x": xc - d,
                "y": y,
                "order": i,
                "color": color,
                "group": f"{model}_{xc}"
            })
        # Right side going back down
        for i, (y, d) in enumerate(zip(y_points[::-1], densities[::-1])):
            violin_polygon_data.append({
                "model": model,
                "x": xc + d,
                "y": y,
                "order": len(y_points) + i,
                "color": color,
                "group": f"{model}_{xc}"
            })

    violin_polygon_df = pd.DataFrame(violin_polygon_data)

    # Create violin shapes - now uses same quantitative scale as base.
    # A filled mark_line traces each polygon; fill/stroke take literal hex
    # colors (scale=None) so they match the line layer's palette.
    violins = alt.Chart(violin_polygon_df).mark_line(
        fillOpacity=0.4,
        strokeWidth=0.5,
        interpolate='linear',
        filled=True
    ).encode(
        x=alt.X('x:Q', scale=alt.Scale(domain=[0, 4.5]), title=''),
        y=alt.Y('y:Q', scale=alt.Scale(domain=[0.6, 1.0]), title=''),
        order='order:Q',
        detail='group:N',
        fill=alt.Fill('color:N', scale=None, legend=None),
        stroke=alt.Stroke('color:N', scale=None, legend=None)
    )

    # Optional percentage labels placed just above the upper credible bound.
    # NOTE(review): this layer encodes x as ordinal while the other layers
    # are quantitative — if SHOW_LABELS is enabled, confirm the layered
    # chart still resolves its x scale as expected.
    labels = alt.Chart(df).mark_text(
        dy=-8,
        fontSize=14,
        color="black"
    ).encode(
        x=alt.X("num_distractors:O"),
        y="hi:Q",
        text=alt.Text("mean:Q", format=".0%")
    )

    # Combine layers based on SHOW_LABELS setting
    if SHOW_LABELS:
        chart = violins + lines + labels
    else:
        chart = violins + lines

    return (
        chart
        .properties(
            width=720,
            height=420,
            title={
                "text": " EgoGym Pick Relative Success Rate vs Number of Distractors",
                "fontSize": 22,
                "anchor": "start",
                "dx": 40
            },
            padding={"left": 10, "right": 10, "top": 40, "bottom": 40}
        )
        .configure_view(stroke=None)
    )
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
# Build the chart and export it in both raster and vector form.
chart = plot_bayesian_success_by_objects()
chart.save("success_by_objects.png", scale_factor=3)
chart.save("success_by_objects.pdf", scale_factor=3)
|