egogym 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,338 @@
|
|
|
1
|
+
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import altair as alt
|
|
5
|
+
from scipy.stats import beta, pearsonr, gaussian_kde
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def plot_bayesian_correlation_hardcoded():
    """Plot sim-to-real correlation for hardcoded checkpoint results.

    Each checkpoint's sim/real success counts are turned into Beta posteriors
    (uniform Beta(1, 1) prior). The plot shows: posterior-mean scatter points,
    a 95% credible interval on the sim axis, a KDE "violin" of the real-world
    posterior on the vertical axis, a least-squares regression line, and a
    bootstrap 95% CI on the Pearson correlation. Saves HTML/PNG/PDF files to
    the working directory.
    """
    # Seed the global NumPy RNG so Monte Carlo draws are reproducible.
    np.random.seed(42)

    # Register and enable custom font theme.
    alt.themes.register('custom_theme', lambda: {
        'config': {
            'title': {'font': 'Produkt'},
            'axis': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
            'legend': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
            'mark': {'font': 'Produkt'},
            'text': {'font': 'Produkt'},
        }
    })
    alt.themes.enable('custom_theme')

    # Hardcoded success counts per checkpoint (sim and real rollouts).
    checkpoint_data = {
        "checkpoint_10": {
            "sim_success": 1001,
            "sim_total": 5000,
            "real_success": 60,
            "real_total": 250
        },
        "checkpoint_50": {
            "sim_success": 1847,
            "sim_total": 5000,
            "real_success": 96,
            "real_total": 250
        },
        "checkpoint_64": {
            "sim_success": 3137,
            "sim_total": 5000,
            "real_success": 168,
            "real_total": 250
        },
        "checkpoint_80": {
            "sim_success": 3994,
            "sim_total": 5000,
            "real_success": 208,
            "real_total": 250
        }
    }

    stats = {}
    checkpoint_order = ["checkpoint_10", "checkpoint_50", "checkpoint_64", "checkpoint_80"]

    for checkpoint in checkpoint_order:
        data = checkpoint_data[checkpoint]

        # Simulation posterior: Beta(1 + successes, 1 + failures).
        a_sim = 1 + data["sim_success"]
        b_sim = 1 + (data["sim_total"] - data["sim_success"])
        sim_mean = 100 * a_sim / (a_sim + b_sim)

        # Real-robot posterior, same uniform prior.
        a_real = 1 + data["real_success"]
        b_real = 1 + (data["real_total"] - data["real_success"])
        real_mean = 100 * a_real / (a_real + b_real)

        print(f"{checkpoint}: sim={data['sim_success']}/{data['sim_total']} ({sim_mean:.1f}%), real={real_mean:.1f}% ({data['real_success']}/{data['real_total']})")

        stats[checkpoint] = {
            "sim_mean": sim_mean,
            "sim_alpha": a_sim,
            "sim_beta": b_sim,
            "real_mean": real_mean,
            "real_alpha": a_real,
            "real_beta": b_real,
        }

    if not stats:
        print("No data found!")
        return

    checkpoint_names = list(stats.keys())
    sim_rates = np.array([stats[c]["sim_mean"] for c in checkpoint_names])
    real_rates = np.array([stats[c]["real_mean"] for c in checkpoint_names])

    # Point estimate of the correlation between posterior means.
    r, p_value = pearsonr(sim_rates, real_rates)

    # Scatter points with 95% credible interval on the simulation axis.
    # NOTE(review): the original also built `violin_data` (10k beta draws per
    # checkpoint), `error_data`, and `violin_df` — all unused; removed here.
    point_data = []
    for i, checkpoint in enumerate(checkpoint_names):
        sim_lo = beta.ppf(0.025, stats[checkpoint]["sim_alpha"], stats[checkpoint]["sim_beta"]) * 100
        sim_hi = beta.ppf(0.975, stats[checkpoint]["sim_alpha"], stats[checkpoint]["sim_beta"]) * 100

        point_data.append({
            'checkpoint': checkpoint,
            'sim_rate': sim_rates[i],
            'real_rate': real_rates[i],
            'sim_lo': sim_lo,
            'sim_hi': sim_hi
        })

    # Bootstrap the correlation CI: resample both posteriors, recompute r.
    n_samples = 1000
    r_samples = []
    for _ in range(n_samples):
        sim_sample = [beta.rvs(stats[c]["sim_alpha"], stats[c]["sim_beta"]) * 100 for c in checkpoint_names]
        real_sample = [beta.rvs(stats[c]["real_alpha"], stats[c]["real_beta"]) * 100 for c in checkpoint_names]
        r_sample, _ = pearsonr(sim_sample, real_sample)
        r_samples.append(r_sample)
    r_samples = np.array(r_samples)
    r_lo = np.percentile(r_samples, 2.5)
    r_hi = np.percentile(r_samples, 97.5)

    point_df = pd.DataFrame(point_data)

    # Least-squares regression line over the posterior means.
    z = np.polyfit(sim_rates, real_rates, 1)
    p = np.poly1d(z)
    xs = np.linspace(sim_rates.min() - 5, sim_rates.max() + 5, 200)
    regression_df = pd.DataFrame({'sim_rate': xs, 'real_rate': p(xs)})

    # Manually create violin shapes positioned at sim_rate coordinates.
    violin_width_scale = 2.5
    violin_polygon_data = []

    # Color mapping from lowest to highest checkpoint.
    checkpoint_colors = {
        "checkpoint_10": "#F8F0FA",  # very light purple
        "checkpoint_50": "#D6BAE2",  # medium purple
        "checkpoint_64": "#9B66BB",  # medium-dark purple
        "checkpoint_80": "#4B136D"   # deep purple
    }

    for checkpoint in checkpoint_names:
        # Sample the real-world posterior and smooth it with a KDE.
        y_samples = beta.rvs(stats[checkpoint]["real_alpha"], stats[checkpoint]["real_beta"], size=2000) * 100
        kde = gaussian_kde(y_samples)

        # Density curve, normalized so the widest point spans the scale width.
        y_points = np.linspace(y_samples.min(), y_samples.max(), 100)
        densities = kde(y_points)
        densities = densities / densities.max() * violin_width_scale

        xc = point_df[point_df['checkpoint'] == checkpoint]['sim_rate'].values[0]
        color = checkpoint_colors[checkpoint]

        # Closed polygon: left edge going up, then right edge coming back down.
        for i, (y, d) in enumerate(zip(y_points, densities)):
            violin_polygon_data.append({
                'checkpoint': checkpoint,
                'x': xc - d,
                'y': y,
                'order': i,
                'color': color
            })
        for i, (y, d) in enumerate(zip(y_points[::-1], densities[::-1])):
            violin_polygon_data.append({
                'checkpoint': checkpoint,
                'x': xc + d,
                'y': y,
                'order': len(y_points) + i,
                'color': color
            })

    violin_polygon_df = pd.DataFrame(violin_polygon_data)

    # Violin shapes with per-checkpoint fill color.
    violins = alt.Chart(violin_polygon_df).mark_line(
        fillOpacity=0.6,
        stroke='#8B4789',
        strokeWidth=1.1,
        interpolate='linear',
        filled=True
    ).encode(
        x=alt.X('x:Q', title='EgoGym Performance (%)').scale(zero=False),
        y=alt.Y('y:Q', title='Real Performance (%)').scale(zero=False),
        order='order:Q',
        detail='checkpoint:N',
        fill=alt.Fill('color:N', scale=None, legend=None)
    )

    # Dashed regression line.
    regression_line = alt.Chart(regression_df).mark_line(
        strokeDash=[5, 5],
        color='black',
        opacity=0.5,
        size=1.5
    ).encode(
        x=alt.X('sim_rate:Q').scale(zero=False),
        y=alt.Y('real_rate:Q').scale(zero=False)
    )

    # Horizontal error bars: 95% credible interval on the sim axis.
    error_bars = alt.Chart(point_df).mark_errorbar(ticks=True, thickness=1.5).encode(
        x=alt.X('sim_lo:Q', title=''),
        x2=alt.X2('sim_hi:Q'),
        y='real_rate:Q'
    )

    # Short horizontal rules marking each violin's posterior mean.
    mean_line_data = []
    for checkpoint in checkpoint_names:
        xc = point_df[point_df['checkpoint'] == checkpoint]['sim_rate'].values[0]
        yc = point_df[point_df['checkpoint'] == checkpoint]['real_rate'].values[0]
        mean_line_data.append({
            'checkpoint': checkpoint,
            'x_left': xc - 1.0,
            'x_right': xc + 1.0,
            'y': yc
        })
    mean_line_df = pd.DataFrame(mean_line_data)

    mean_lines = alt.Chart(mean_line_df).mark_rule(color='black', size=1.2).encode(
        x=alt.X('x_left:Q'),
        x2=alt.X2('x_right:Q'),
        y='y:Q'
    )

    # Central diamond markers.
    points = alt.Chart(point_df).mark_point(
        filled=True,
        size=100,
        color='#A894DB',
        stroke='#8B4789',
        strokeWidth=1.5,
        shape='diamond'
    ).encode(
        x=alt.X('sim_rate:Q').scale(zero=False),
        y=alt.Y('real_rate:Q').scale(zero=False),
        tooltip=['checkpoint', 'sim_rate', 'real_rate']
    )

    # Background box for the correlation annotation (wider and thinner).
    correlation_box = alt.Chart(pd.DataFrame([{
        'x': point_df['sim_rate'].max() - 18,
        'y': point_df['real_rate'].min() - 10,
        'x2': point_df['sim_rate'].max() + 4,
        'y2': point_df['real_rate'].min() - 4
    }])).mark_rect(
        fill='#F7D45B',
        stroke='gray',
        strokeWidth=1.5,
        opacity=0.9
    ).encode(
        x=alt.X('x:Q'),
        y=alt.Y('y:Q'),
        x2='x2:Q',
        y2='y2:Q'
    )

    # Correlation CI text placed inside the box.
    correlation_text = alt.Chart(pd.DataFrame([{
        'x': point_df['sim_rate'].max() - 7,
        'y': point_df['real_rate'].min() - 7,
        'text': f'95% CI r: {r_lo:.3f}, {r_hi:.3f}'
    }])).mark_text(
        align='center',
        baseline='middle',
        fontSize=16,
        font='Produkt'
    ).encode(
        x=alt.X('x:Q'),
        y=alt.Y('y:Q'),
        text='text:N'
    )

    # BUGFIX: layer ALL marks first, THEN apply properties/configure_*.
    # The original applied configure_axis/configure_view and afterwards did
    # `chart + correlation_box + correlation_text`; Altair rejects layering a
    # chart that already carries top-level config, so chart.save() failed.
    chart = (violins + regression_line + mean_lines + error_bars + points
             + correlation_box + correlation_text).properties(
        title={
            'text': ' Blind EgoGym-Pick Sim-to-Real Correlation',
            'fontSize': 24,
            'anchor': 'start',
            'dx': 15,
            'dy': -8
        },
        width=500,
        height=400
    ).configure_axis(
        labelFontSize=16,
        titleFontSize=18,
        titleFontStyle='normal',
        grid=True,
        gridOpacity=0.3,
        tickCount=6
    ).configure_view(
        strokeWidth=0
    )

    # Save chart in all three formats.
    chart.save("checkpoint_correlation.html")
    chart.save("checkpoint_correlation.png", scale_factor=2.0)
    chart.save("checkpoint_correlation.pdf", scale_factor=2.0)
    print(f"\nCorrelation: r = {r:.3f}, p = {p_value:.4f}")
    print("\nPlot saved to: checkpoint_correlation.html and checkpoint_correlation.png")
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
# Script entry point: generate and save the correlation plot.
if __name__ == "__main__":
    plot_bayesian_correlation_hardcoded()
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
|
|
2
|
+
import os
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import altair as alt
|
|
6
|
+
|
|
7
|
+
# Register custom font for export
alt.themes.register('custom_theme', lambda: {
    'config': {
        'title': {'font': 'Produkt'},
        'axis': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
        'legend': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
        'mark': {'font': 'Produkt'},
        'text': {'font': 'Produkt'},
    }
})
alt.themes.enable('custom_theme')

# Root directory containing one subdirectory per checkpoint, each with a
# tab-separated `log.csv` of episode outcomes.
BASE_DIR = "logs/5_objects"
# max_reward above this counts the episode as a success
# (presumably a lift height in meters — TODO confirm against the task reward).
REWARD_THRESHOLD = 0.03
# max_reward at or above this (but below REWARD_THRESHOLD) counts as a
# partial lift when classifying failure modes.
PARTIAL_LIFT_THRESHOLD = 0.005
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def compute_outcomes_from_csv(csv_path):
    """Tally per-episode outcome categories from a tab-separated episode log.

    Args:
        csv_path: Path to a tab-separated log file with one row per episode;
            expected columns include ``max_reward`` and (optionally)
            ``grasped_bodies``, ``object_name``, ``grasping_object``,
            ``is_grasping``.

    Returns:
        ``(outcomes, total_episodes)`` where ``outcomes`` maps each outcome
        label to its episode count, or ``(None, None)`` when the file is
        missing.
    """
    if not os.path.exists(csv_path):
        return None, None
    df = pd.read_csv(csv_path, sep="\t")
    total_episodes = len(df)

    def get_failure_mode(row):
        # Lifting past the reward threshold is a success regardless of contacts.
        if row["max_reward"] > REWARD_THRESHOLD:
            return "Success"

        bodies_contacted = str(row.get("grasped_bodies", ""))
        object_name = str(row.get("object_name", ""))
        grasping_object = row.get("grasping_object", False)  # target ever grasped
        is_grasping = row.get("is_grasping", False)  # final gripper state only

        # Substring tests against the serialized grasped_bodies list.
        has_target_contact = object_name in bodies_contacted
        has_gripper_contact = "left" in bodies_contacted or "right" in bodies_contacted
        has_any_object_contact = "object" in bodies_contacted
        has_wrong_object_contact = has_any_object_contact and not has_target_contact

        # Decision tree (most specific to least specific). The original had
        # two pairs of branches returning identical labels; they are merged
        # here with identical behavior.

        # 1. Successfully grasped target but didn't lift high enough.
        if grasping_object and row["max_reward"] >= PARTIAL_LIFT_THRESHOLD:
            return "Did not lift enough"

        # 2. Grasped the target with negligible lift, or contacted it but
        #    never achieved a grasp.
        if grasping_object or has_target_contact:
            return "Object touched but not grasped"

        # 3. Contacted/grasped a non-target object (lifted or not).
        if has_wrong_object_contact:
            return "Picked wrong object"

        # 4. Gripper closed at the end, or gripper contact with nothing identified.
        if is_grasping or has_gripper_contact:
            return "Empty Grasp"

        # 5. Never made meaningful contact with anything.
        return "Did not grasp"

    df["outcome"] = df.apply(get_failure_mode, axis=1)

    outcomes = {
        "Success": 0,
        "Did not lift enough": 0,
        "Object touched but not grasped": 0,
        "Picked wrong object": 0,
        "Empty Grasp": 0,
        "Did not grasp": 0,
    }

    # Tally in one vectorized pass instead of the original iterrows() loop.
    for mode, count in df["outcome"].value_counts().items():
        if mode in outcomes:
            outcomes[mode] += int(count)

    return outcomes, total_episodes
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def plot_failure_modes():
    """Build a stacked bar chart of episode-outcome shares per checkpoint.

    Reads each checkpoint's ``log.csv`` under BASE_DIR, tallies outcomes via
    ``compute_outcomes_from_csv``, prints a per-checkpoint breakdown to the
    terminal, and saves the chart as HTML/PDF/PNG in the working directory.
    """
    checkpoint_order = ["checkpoint_10", "checkpoint_50", "checkpoint_64", "checkpoint_80"]
    # Display labels (presumably the checkpoints' sim success rates — the
    # mapping is hardcoded here; verify against the evaluation runs).
    checkpoint_labels = {
        "checkpoint_10": "Checkpoint 24%",
        "checkpoint_50": "Checkpoint 39%",
        "checkpoint_64": "Checkpoint 68%",
        "checkpoint_80": "Checkpoint 83%"
    }
    outcome_data = {}
    episode_totals = {}

    for checkpoint in checkpoint_order:
        csv_path = os.path.join(BASE_DIR, checkpoint, "log.csv")
        outcomes, total_episodes = compute_outcomes_from_csv(csv_path)
        if outcomes is None:
            # Missing logs are skipped rather than aborting the whole plot.
            print(f"Warning: Could not find CSV data for {checkpoint}")
            continue
        outcome_data[checkpoint] = outcomes
        episode_totals[checkpoint] = total_episodes

    if not outcome_data:
        print("No data found!")
        return

    checkpoint_names = list(outcome_data.keys())

    # Print results to terminal.
    print("\n" + "="*80)
    print("FAILURE MODE ANALYSIS BY CHECKPOINT")
    print("="*80)
    for checkpoint in checkpoint_names:
        checkpoint_label = checkpoint_labels.get(checkpoint, checkpoint)
        total = episode_totals.get(checkpoint, 0)
        print(f"\n{checkpoint_label} (n={total}):")
        print("-" * 60)
        for outcome, count in outcome_data[checkpoint].items():
            percentage = (count / total * 100) if total > 0 else 0
            print(f" {outcome:40s}: {count:3d} ({percentage:5.1f}%)")
    print("\n" + "="*80 + "\n")

    outcome_order = [
        "Success",
        "Did not lift enough",
        "Object touched but not grasped",
        "Picked wrong object",
        "Empty Grasp",
        "Did not grasp",
    ]

    colors = {
        "Success": "#388038",                          # Green
        "Did not lift enough": "#F7D45B",              # Yellow
        "Object touched but not grasped": "#66ACF7",   # Blue
        "Picked wrong object": "#F0529C",              # Pink
        "Empty Grasp": "#9B66BB",                      # Purple
        "Did not grasp": "#870927",                    # Dark red
    }

    # Sum each outcome's percentage across checkpoints so stacks can be
    # ordered by overall prevalence. (An unnormalized sum sorts the same as
    # the mean — the original comment said "average" but summed.)
    outcome_totals = {}
    for checkpoint in checkpoint_names:
        total = episode_totals.get(checkpoint, 0) or 1  # guard divide-by-zero
        for outcome in outcome_order:
            count = outcome_data[checkpoint].get(outcome, 0)
            outcome_totals[outcome] = outcome_totals.get(outcome, 0) + count / total * 100

    # Sort failure modes by total share (largest first), excluding Success.
    failure_modes = [o for o in outcome_order if o != "Success"]
    sorted_failures = sorted(failure_modes, key=lambda x: outcome_totals.get(x, 0), reverse=True)

    # Stacking order: Success at the bottom, then failures largest-to-smallest.
    stacking_order = ["Success"] + sorted_failures

    chart_data = []
    for checkpoint in checkpoint_names:
        checkpoint_label = checkpoint_labels.get(checkpoint, checkpoint)
        total = episode_totals.get(checkpoint, 0) or 1
        for outcome in stacking_order:
            count = outcome_data[checkpoint].get(outcome, 0)
            percentage = count / total * 100
            if percentage > 0:  # only include non-zero segments
                # The original also stored a per-row 'Color' column that the
                # chart never read (colors come from color_scale); dropped.
                chart_data.append({
                    'Checkpoint': checkpoint_label,
                    'Outcome': outcome,
                    'Percentage': percentage,
                })

    df = pd.DataFrame(chart_data)

    # Restrict legend/colors/order to outcomes that actually appear.
    outcomes_in_data = df['Outcome'].unique().tolist()
    filtered_stacking_order = [o for o in stacking_order if o in outcomes_in_data]

    # Explicit sort index drives the within-bar stacking order.
    outcome_to_index = {outcome: i for i, outcome in enumerate(filtered_stacking_order)}
    df['sort_index'] = df['Outcome'].map(outcome_to_index)

    color_scale = alt.Scale(
        domain=filtered_stacking_order,
        range=[colors[o] for o in filtered_stacking_order]
    )

    # Stacked bar chart: one bar per checkpoint, segments per outcome.
    chart = alt.Chart(df).mark_bar(
        stroke='white',
        strokeWidth=1
    ).encode(
        x=alt.X('Checkpoint:N', title=None, axis=alt.Axis(labelFontSize=18, labelAngle=0)),
        y=alt.Y('Percentage:Q', title='Share of Episodes (%)', axis=alt.Axis(labelFontSize=18, titleFontSize=20)).scale(domain=[0, 100]),
        color=alt.Color('Outcome:N', scale=color_scale, sort=filtered_stacking_order, legend=alt.Legend(
            title=None,
            labelFontSize=16,
            symbolSize=200,
            orient="bottom",
            direction="horizontal",
            labelLimit=0,
            columns=3
        )),
        order=alt.Order('sort_index:Q'),
        tooltip=['Checkpoint', 'Outcome', alt.Tooltip('Percentage:Q', format='.1f')]
    ).properties(
        width=600,
        height=400,
        title={
            'text': ' EgoGym-Pick Failure Modes by Checkpoint',
            'fontSize': 22,
            'anchor': 'start',
            'dx': 60,
            'dy': -20
        },
        padding={"left": 5, "right": 5, "top": 20, "bottom": 40}
    ).configure_view(
        strokeWidth=0
    )

    # Save chart in all three formats.
    chart.save("failure_modes_by_checkpoint.html")
    chart.save("failure_modes_by_checkpoint.pdf", scale_factor=3)
    chart.save("failure_modes_by_checkpoint.png", scale_factor=3)
    print("\nPlot saved to: failure_modes_by_checkpoint.html and failure_modes_by_checkpoint.png")
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
# Script entry point: generate and save the failure-mode chart.
if __name__ == "__main__":
    plot_failure_modes()
|