egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,338 @@
import numpy as np
import pandas as pd
import altair as alt
from scipy.stats import beta, pearsonr, gaussian_kde


def plot_bayesian_correlation_hardcoded():
    """Plot sim-to-real correlation of per-checkpoint success rates.

    Success counts are hardcoded. Each rate gets a Beta(1 + successes,
    1 + failures) posterior (uniform prior); the real-world posterior is
    drawn as a violin centered at the simulated posterior mean, and a
    bootstrap over both posteriors gives a 95% CI for Pearson's r.
    Saves the chart as HTML, PNG and PDF.
    """
    # Set random seed for reproducibility
    np.random.seed(42)

    # Register and enable custom font theme
    alt.themes.register('custom_theme', lambda: {
        'config': {
            'title': {'font': 'Produkt'},
            'axis': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
            'legend': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
            'mark': {'font': 'Produkt'},
            'text': {'font': 'Produkt'},
        }
    })
    alt.themes.enable('custom_theme')

    # Hardcoded success counts per checkpoint
    checkpoint_data = {
        "checkpoint_10": {"sim_success": 1001, "sim_total": 5000, "real_success": 60, "real_total": 250},
        "checkpoint_50": {"sim_success": 1847, "sim_total": 5000, "real_success": 96, "real_total": 250},
        "checkpoint_64": {"sim_success": 3137, "sim_total": 5000, "real_success": 168, "real_total": 250},
        "checkpoint_80": {"sim_success": 3994, "sim_total": 5000, "real_success": 208, "real_total": 250},
    }

    stats = {}
    checkpoint_order = ["checkpoint_10", "checkpoint_50", "checkpoint_64", "checkpoint_80"]

    for checkpoint in checkpoint_order:
        data = checkpoint_data[checkpoint]

        # Beta posterior with a uniform prior: alpha = 1 + successes,
        # beta = 1 + failures; mean reported in percent.
        a_sim = 1 + data["sim_success"]
        b_sim = 1 + (data["sim_total"] - data["sim_success"])
        sim_mean = 100 * a_sim / (a_sim + b_sim)

        a_real = 1 + data["real_success"]
        b_real = 1 + (data["real_total"] - data["real_success"])
        real_mean = 100 * a_real / (a_real + b_real)

        print(f"{checkpoint}: sim={data['sim_success']}/{data['sim_total']} ({sim_mean:.1f}%), real={real_mean:.1f}% ({data['real_success']}/{data['real_total']})")

        stats[checkpoint] = {
            "sim_mean": sim_mean,
            "sim_alpha": a_sim,
            "sim_beta": b_sim,
            "real_mean": real_mean,
            "real_alpha": a_real,
            "real_beta": b_real,
        }

    if not stats:
        print("No data found!")
        return

    checkpoint_names = list(stats.keys())
    sim_rates = np.array([stats[c]["sim_mean"] for c in checkpoint_names])
    real_rates = np.array([stats[c]["real_mean"] for c in checkpoint_names])

    r, p_value = pearsonr(sim_rates, real_rates)

    # Central points with a 95% credible interval on the simulated rate
    # (used for the horizontal error bars).
    point_data = []
    for i, checkpoint in enumerate(checkpoint_names):
        sim_lo = beta.ppf(0.025, stats[checkpoint]["sim_alpha"], stats[checkpoint]["sim_beta"]) * 100
        sim_hi = beta.ppf(0.975, stats[checkpoint]["sim_alpha"], stats[checkpoint]["sim_beta"]) * 100
        point_data.append({
            'checkpoint': checkpoint,
            'sim_rate': sim_rates[i],
            'real_rate': real_rates[i],
            'sim_lo': sim_lo,
            'sim_hi': sim_hi,
        })

    # Bootstrap a 95% CI for the correlation: resample each checkpoint's
    # sim/real rates from their posteriors and recompute Pearson's r.
    n_samples = 1000
    r_samples = []
    for _ in range(n_samples):
        sim_sample = [beta.rvs(stats[c]["sim_alpha"], stats[c]["sim_beta"]) * 100 for c in checkpoint_names]
        real_sample = [beta.rvs(stats[c]["real_alpha"], stats[c]["real_beta"]) * 100 for c in checkpoint_names]
        r_sample, _ = pearsonr(sim_sample, real_sample)
        r_samples.append(r_sample)
    r_samples = np.array(r_samples)
    r_lo = np.percentile(r_samples, 2.5)
    r_hi = np.percentile(r_samples, 97.5)

    point_df = pd.DataFrame(point_data)

    # Least-squares regression line over the checkpoint means
    z = np.polyfit(sim_rates, real_rates, 1)
    p = np.poly1d(z)
    xs = np.linspace(sim_rates.min() - 5, sim_rates.max() + 5, 200)
    regression_df = pd.DataFrame({'sim_rate': xs, 'real_rate': p(xs)})

    # Manually create violin shapes positioned at sim_rate coordinates
    violin_width_scale = 2.5
    violin_polygon_data = []

    # Color mapping from lowest to highest checkpoint
    checkpoint_colors = {
        "checkpoint_10": "#F8F0FA",  # very light purple
        "checkpoint_50": "#D6BAE2",  # medium purple
        "checkpoint_64": "#9B66BB",  # medium-dark purple
        "checkpoint_80": "#4B136D",  # deep purple
    }

    # One lookup table instead of filtering point_df per checkpoint.
    sim_rate_of = dict(zip(checkpoint_names, sim_rates))
    real_rate_of = dict(zip(checkpoint_names, real_rates))

    for checkpoint in checkpoint_names:
        # KDE of samples from the real-rate posterior gives the violin outline
        y_samples = beta.rvs(stats[checkpoint]["real_alpha"], stats[checkpoint]["real_beta"], size=2000) * 100
        kde = gaussian_kde(y_samples)

        y_points = np.linspace(y_samples.min(), y_samples.max(), 100)
        densities = kde(y_points)
        densities = densities / densities.max() * violin_width_scale

        xc = sim_rate_of[checkpoint]
        color = checkpoint_colors[checkpoint]

        # Create closed polygon: left side up, then right side down
        for i, (y, d) in enumerate(zip(y_points, densities)):
            violin_polygon_data.append({
                'checkpoint': checkpoint,
                'x': xc - d,
                'y': y,
                'order': i,
                'color': color,
            })
        # Right side going back down
        for i, (y, d) in enumerate(zip(y_points[::-1], densities[::-1])):
            violin_polygon_data.append({
                'checkpoint': checkpoint,
                'x': xc + d,
                'y': y,
                'order': len(y_points) + i,
                'color': color,
            })

    violin_polygon_df = pd.DataFrame(violin_polygon_data)

    # Create violin shapes with color encoding
    violins = alt.Chart(violin_polygon_df).mark_line(
        fillOpacity=0.6,
        stroke='#8B4789',
        strokeWidth=1.1,
        interpolate='linear',
        filled=True
    ).encode(
        x=alt.X('x:Q', title='EgoGym Performance (%)').scale(zero=False),
        y=alt.Y('y:Q', title='Real Performance (%)').scale(zero=False),
        order='order:Q',
        detail='checkpoint:N',
        fill=alt.Fill('color:N', scale=None, legend=None)
    )

    # Create regression line
    regression_line = alt.Chart(regression_df).mark_line(
        strokeDash=[5, 5],
        color='black',
        opacity=0.5,
        size=1.5
    ).encode(
        x=alt.X('sim_rate:Q').scale(zero=False),
        y=alt.Y('real_rate:Q').scale(zero=False)
    )

    # Create error bars for simulation
    error_bars = alt.Chart(point_df).mark_errorbar(ticks=True, thickness=1.5).encode(
        x=alt.X('sim_lo:Q', title=''),
        x2=alt.X2('sim_hi:Q'),
        y='real_rate:Q'
    )

    # Create horizontal lines at mean (for each violin)
    mean_line_data = [{
        'checkpoint': checkpoint,
        'x_left': sim_rate_of[checkpoint] - 1.0,
        'x_right': sim_rate_of[checkpoint] + 1.0,
        'y': real_rate_of[checkpoint],
    } for checkpoint in checkpoint_names]
    mean_line_df = pd.DataFrame(mean_line_data)

    mean_lines = alt.Chart(mean_line_df).mark_rule(color='black', size=1.2).encode(
        x=alt.X('x_left:Q'),
        x2=alt.X2('x_right:Q'),
        y='y:Q'
    )

    # Create central points
    points = alt.Chart(point_df).mark_point(
        filled=True,
        size=100,
        color='#A894DB',
        stroke='#8B4789',
        strokeWidth=1.5,
        shape='diamond'
    ).encode(
        x=alt.X('sim_rate:Q').scale(zero=False),
        y=alt.Y('real_rate:Q').scale(zero=False),
        tooltip=['checkpoint', 'sim_rate', 'real_rate']
    )

    # Combine layers (violins, regression line, error bars, mean lines and points on top)
    chart = (violins + regression_line + mean_lines + error_bars + points).properties(
        title={
            'text': ' Blind EgoGym-Pick Sim-to-Real Correlation',
            'fontSize': 24,
            'anchor': 'start',
            'dx': 15,
            'dy': -8
        },
        width=500,
        height=400
    ).configure_axis(
        labelFontSize=16,
        titleFontSize=18,
        titleFontStyle='normal',
        grid=True,
        gridOpacity=0.3,
        tickCount=6
    ).configure_view(
        strokeWidth=0
    )

    # Add background box for correlation text (wider and thinner)
    correlation_box = alt.Chart(pd.DataFrame([{
        'x': point_df['sim_rate'].max() - 18,
        'y': point_df['real_rate'].min() - 10,
        'x2': point_df['sim_rate'].max() + 4,
        'y2': point_df['real_rate'].min() - 4
    }])).mark_rect(
        fill='#F7D45B',
        stroke='gray',
        strokeWidth=1.5,
        opacity=0.9
    ).encode(
        x=alt.X('x:Q'),
        y=alt.Y('y:Q'),
        x2='x2:Q',
        y2='y2:Q'
    )

    # Add correlation text annotation in box
    correlation_text = alt.Chart(pd.DataFrame([{
        'x': point_df['sim_rate'].max() - 7,
        'y': point_df['real_rate'].min() - 7,
        'text': f'95% CI r: {r_lo:.3f}, {r_hi:.3f}'
    }])).mark_text(
        align='center',
        baseline='middle',
        fontSize=16,
        font='Produkt'
    ).encode(
        x=alt.X('x:Q'),
        y=alt.Y('y:Q'),
        text='text:N'
    )

    # Combine with correlation box and text
    chart = chart + correlation_box + correlation_text

    # Save chart
    chart.save("checkpoint_correlation.html")
    chart.save("checkpoint_correlation.png", scale_factor=2.0)
    chart.save("checkpoint_correlation.pdf", scale_factor=2.0)
    print(f"\nCorrelation: r = {r:.3f}, p = {p_value:.4f}")
    print("\nPlot saved to: checkpoint_correlation.html and checkpoint_correlation.png")


if __name__ == "__main__":
    plot_bayesian_correlation_hardcoded()
@@ -0,0 +1,248 @@

import os
import numpy as np
import pandas as pd
import altair as alt

# Register custom font for export. The theme only sets fonts; all other
# styling is configured per-chart in plot_failure_modes().
alt.themes.register('custom_theme', lambda: {
    'config': {
        'title': {'font': 'Produkt'},
        'axis': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
        'legend': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
        'mark': {'font': 'Produkt'},
        'text': {'font': 'Produkt'},
    }
})
alt.themes.enable('custom_theme')

# Root directory holding one <checkpoint>/log.csv per evaluated checkpoint.
BASE_DIR = "logs/5_objects"
# Episodes whose max_reward exceeds this count as successes.
REWARD_THRESHOLD = 0.03
# Lifts above this (but below REWARD_THRESHOLD) count as partial lifts.
PARTIAL_LIFT_THRESHOLD = 0.005


def compute_outcomes_from_csv(csv_path):
    """Classify every episode in a tab-separated log into an outcome bucket.

    Parameters
    ----------
    csv_path : str
        Path to a tab-separated log. Requires a ``max_reward`` column;
        ``grasped_bodies``, ``object_name``, ``grasping_object`` and
        ``is_grasping`` are read when present (missing ones default to
        empty / False).

    Returns
    -------
    tuple
        ``(outcomes, total_episodes)`` where ``outcomes`` maps each outcome
        label to its episode count, or ``(None, None)`` when the file does
        not exist.
    """
    if not os.path.exists(csv_path):
        return None, None
    df = pd.read_csv(csv_path, sep="\t")
    total_episodes = len(df)

    def get_failure_mode(row):
        # Success: the object was lifted above the reward threshold.
        if row["max_reward"] > REWARD_THRESHOLD:
            return "Success"

        bodies_contacted = str(row.get("grasped_bodies", ""))
        object_name = str(row.get("object_name", ""))
        grasping_object = row.get("grasping_object", False)
        is_grasping = row.get("is_grasping", False)  # Final gripper state only

        # Substring checks against the stringified grasped_bodies list.
        has_target_contact = object_name in bodies_contacted
        has_gripper_contact = "left" in bodies_contacted or "right" in bodies_contacted
        has_any_object_contact = "object" in bodies_contacted
        has_wrong_object_contact = has_any_object_contact and not has_target_contact

        # Decision tree, most specific first.
        # Note: is_grasping is only the final state; grasping_object tracks
        # whether the target was ever grasped during the episode.

        # 1. Successfully grasped target but didn't lift high enough
        if grasping_object and row["max_reward"] >= PARTIAL_LIFT_THRESHOLD:
            return "Did not lift enough"

        # 2. Grasped the target without a meaningful lift, or merely touched it
        #    (the original branches 2 and 3 returned the same label).
        if grasping_object or has_target_contact:
            return "Object touched but not grasped"

        # 3. Contacted a non-target object (lifted or not — same label either way)
        if has_wrong_object_contact:
            return "Picked wrong object"

        # 4. Gripper closed at end, or gripper self-contact with no object identified
        if is_grasping or has_gripper_contact:
            return "Empty Grasp"

        # 5. Never made meaningful contact with anything
        return "Did not grasp"

    df["outcome"] = df.apply(get_failure_mode, axis=1)

    # Fixed bucket order so downstream plots get a stable key set.
    outcomes = {
        "Success": 0,
        "Did not lift enough": 0,
        "Object touched but not grasped": 0,
        "Picked wrong object": 0,
        "Empty Grasp": 0,
        "Did not grasp": 0,
    }
    for mode, count in df["outcome"].value_counts().items():
        if mode in outcomes:
            outcomes[mode] = int(count)

    return outcomes, total_episodes
95
+
def plot_failure_modes():
    """Build and save a stacked-bar chart of per-checkpoint episode outcomes.

    Reads one ``log.csv`` per checkpoint under ``BASE_DIR``, classifies the
    episodes via ``compute_outcomes_from_csv``, prints a per-checkpoint
    breakdown to the terminal, and writes the chart to HTML/PDF/PNG.
    """
    checkpoint_order = ["checkpoint_10", "checkpoint_50", "checkpoint_64", "checkpoint_80"]
    checkpoint_labels = {
        "checkpoint_10": "Checkpoint 24%",
        "checkpoint_50": "Checkpoint 39%",
        "checkpoint_64": "Checkpoint 68%",
        "checkpoint_80": "Checkpoint 83%",
    }

    # Gather outcome counts per checkpoint; skip any checkpoint whose log is missing.
    outcome_data = {}
    episode_totals = {}
    for ckpt in checkpoint_order:
        counts, n_episodes = compute_outcomes_from_csv(os.path.join(BASE_DIR, ckpt, "log.csv"))
        if counts is None:
            print(f"Warning: Could not find CSV data for {ckpt}")
            continue
        outcome_data[ckpt] = counts
        episode_totals[ckpt] = n_episodes

    if not outcome_data:
        print("No data found!")
        return

    checkpoint_names = list(outcome_data.keys())

    # Terminal report
    print("\n" + "=" * 80)
    print("FAILURE MODE ANALYSIS BY CHECKPOINT")
    print("=" * 80)
    for ckpt in checkpoint_names:
        label = checkpoint_labels.get(ckpt, ckpt)
        n_episodes = episode_totals.get(ckpt, 0)
        print(f"\n{label} (n={n_episodes}):")
        print("-" * 60)
        for outcome, count in outcome_data[ckpt].items():
            share = (count / n_episodes * 100) if n_episodes > 0 else 0
            print(f"  {outcome:40s}: {count:3d} ({share:5.1f}%)")
    print("\n" + "=" * 80 + "\n")

    outcome_order = [
        "Success",
        "Did not lift enough",
        "Object touched but not grasped",
        "Picked wrong object",
        "Empty Grasp",
        "Did not grasp",
    ]

    colors = {
        "Success": "#388038",                        # Green
        "Did not lift enough": "#F7D45B",            # Yellow
        "Object touched but not grasped": "#66ACF7", # Blue
        "Picked wrong object": "#F0529C",            # Pink
        "Empty Grasp": "#9B66BB",                    # Purple
        "Did not grasp": "#870927",                  # Dark red
    }

    # Sum each outcome's percentage across checkpoints to rank stack segments.
    outcome_totals = {outcome: 0 for outcome in outcome_order}
    for ckpt in checkpoint_names:
        denom = episode_totals.get(ckpt, 0) or 1
        for outcome in outcome_order:
            outcome_totals[outcome] += outcome_data[ckpt].get(outcome, 0) / denom * 100

    # Stacking order: Success at the bottom, then failures largest-first going up.
    sorted_failures = sorted(
        (o for o in outcome_order if o != "Success"),
        key=lambda o: outcome_totals.get(o, 0),
        reverse=True,
    )
    stacking_order = ["Success"] + sorted_failures

    # Long-format rows for the chart; zero-share segments are dropped.
    chart_data = []
    for ckpt in checkpoint_names:
        label = checkpoint_labels.get(ckpt, ckpt)
        denom = episode_totals.get(ckpt, 0) or 1
        for outcome in stacking_order:
            share = outcome_data[ckpt].get(outcome, 0) / denom * 100
            if share > 0:
                chart_data.append({
                    'Checkpoint': label,
                    'Outcome': outcome,
                    'Percentage': share,
                    'Color': colors.get(outcome, "#999999"),
                })

    df = pd.DataFrame(chart_data)

    # Restrict ordering/colors to outcomes actually present in the data.
    outcomes_in_data = df['Outcome'].unique().tolist()
    filtered_stacking_order = [o for o in stacking_order if o in outcomes_in_data]

    # Numeric sort index drives the within-bar stacking order.
    df['sort_index'] = df['Outcome'].map(
        {outcome: i for i, outcome in enumerate(filtered_stacking_order)}
    )

    color_scale = alt.Scale(
        domain=filtered_stacking_order,
        range=[colors[o] for o in filtered_stacking_order],
    )

    # Stacked bar chart, one bar per checkpoint.
    chart = alt.Chart(df).mark_bar(
        stroke='white',
        strokeWidth=1
    ).encode(
        x=alt.X('Checkpoint:N', title=None, axis=alt.Axis(labelFontSize=18, labelAngle=0)),
        y=alt.Y('Percentage:Q', title='Share of Episodes (%)', axis=alt.Axis(labelFontSize=18, titleFontSize=20)).scale(domain=[0, 100]),
        color=alt.Color('Outcome:N', scale=color_scale, sort=filtered_stacking_order, legend=alt.Legend(
            title=None,
            labelFontSize=16,
            symbolSize=200,
            orient="bottom",
            direction="horizontal",
            labelLimit=0,
            columns=3
        )),
        order=alt.Order('sort_index:Q'),
        tooltip=['Checkpoint', 'Outcome', alt.Tooltip('Percentage:Q', format='.1f')]
    ).properties(
        width=600,
        height=400,
        title={
            'text': ' EgoGym-Pick Failure Modes by Checkpoint',
            'fontSize': 22,
            'anchor': 'start',
            'dx': 60,
            'dy': -20
        },
        padding={"left": 5, "right": 5, "top": 20, "bottom": 40}
    ).configure_view(
        strokeWidth=0
    )

    # Save chart
    chart.save("failure_modes_by_checkpoint.html")
    chart.save("failure_modes_by_checkpoint.pdf", scale_factor=3)
    chart.save("failure_modes_by_checkpoint.png", scale_factor=3)
    print("\nPlot saved to: failure_modes_by_checkpoint.html and failure_modes_by_checkpoint.png")


if __name__ == "__main__":
    plot_failure_modes()