egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,195 @@
1
+
2
+ import numpy as np
3
+ import pandas as pd
4
+ import altair as alt
5
+
6
def _custom_theme():
    """Return an Altair theme config that applies the 'Produkt' font to all text."""
    font = 'Produkt'
    return {
        'config': {
            'title': {'font': font},
            'axis': {'labelFont': font, 'titleFont': font},
            'legend': {'labelFont': font, 'titleFont': font},
            'mark': {'font': font},
            'text': {'font': font},
        }
    }


# Register the custom font theme and make it active for every chart exported below.
alt.themes.register('custom_theme', _custom_theme)
alt.themes.enable('custom_theme')
17
+
18
# Hardcoded data from experiments: episode outcome counts per CAP checkpoint.
# Each row below is (checkpoint, total episodes, counts in _CHECKPOINT_OUTCOMES order).
# NOTE(review): "CAP-D" matches MODEL_DATA["CAP + Oracle"] in
# plot_failure_vlm_hardcoded.py on Success / Picked wrong object / Empty Grasp,
# but splits the remaining 501 episodes 139/362 between "Did not lift enough" and
# "Object touched but not grasped" (that table has 89/412). Confirm whether both
# describe the same run and which split is current.
_CHECKPOINT_OUTCOMES = (
    "Success",
    "Did not lift enough",
    "Object touched but not grasped",
    "Picked wrong object",
    "Empty Grasp",
    "Did not grasp",
)
_CHECKPOINT_ROWS = (
    ("CAP-A", 5000, (1001, 677, 1295, 337, 1690, 0)),
    ("CAP-B", 5000, (1847, 995, 1213, 208, 737, 0)),
    ("CAP-C", 5000, (3137, 263, 597, 110, 893, 0)),
    ("CAP-D", 5000, (3994, 139, 362, 59, 446, 0)),
)
CHECKPOINT_DATA = {
    checkpoint: {"total": total, "outcomes": dict(zip(_CHECKPOINT_OUTCOMES, counts))}
    for checkpoint, total, counts in _CHECKPOINT_ROWS
}
65
+
66
+
67
def plot_failure_modes():
    """Print per-checkpoint outcome tables and save a stacked failure-mode chart.

    Reads the module-level ``CHECKPOINT_DATA`` and writes
    ``failure_modes_by_checkpoint.html``, ``.pdf`` and ``.png`` to the current
    working directory. Does nothing beyond printing if no outcome is non-zero.
    """
    checkpoint_order = ["CAP-A", "CAP-B", "CAP-C", "CAP-D"]

    # Print results to terminal
    print("\n" + "="*80)
    print("FAILURE MODE ANALYSIS BY CHECKPOINT")
    print("="*80)
    for checkpoint in checkpoint_order:
        data = CHECKPOINT_DATA[checkpoint]
        total = data["total"]
        # Wide enough for any count up to the total, so rows stay aligned
        # (a fixed width of 3 misaligns 4-digit counts).
        count_width = max(3, len(str(total)))
        print(f"\n{checkpoint} (n={total}):")
        print("-" * 60)
        for outcome, count in data["outcomes"].items():
            percentage = (count / total * 100) if total > 0 else 0
            print(f" {outcome:40s}: {count:{count_width}d} ({percentage:5.1f}%)")
    print("\n" + "="*80 + "\n")

    outcome_order = [
        "Success",
        "Did not lift enough",
        "Object touched but not grasped",
        "Picked wrong object",
        "Empty Grasp",
        "Did not grasp",
    ]

    colors = {
        "Success": "#388038",  # Green
        "Did not lift enough": "#F7D45B",  # Yellow
        "Object touched but not grasped": "#66ACF7",  # Blue
        "Picked wrong object": "#F0529C",  # Pink
        "Empty Grasp": "#9B66BB",  # Purple
        "Did not grasp": "#870927",  # Dark red
    }

    def percentage_of(checkpoint, outcome):
        """Share (0-100) of *checkpoint*'s episodes that ended in *outcome*."""
        data = CHECKPOINT_DATA[checkpoint]
        total = data["total"] or 1  # avoid division by zero for an empty checkpoint
        return data["outcomes"].get(outcome, 0) / total * 100

    # Sum each failure mode's percentage across checkpoints. Ranking by the sum is
    # equivalent to ranking by the mean, since every mode covers the same checkpoints.
    failure_modes = [o for o in outcome_order if o != "Success"]
    outcome_totals = {
        outcome: sum(percentage_of(c, outcome) for c in checkpoint_order)
        for outcome in failure_modes
    }
    sorted_failures = sorted(failure_modes, key=outcome_totals.get, reverse=True)

    # Build stacking order: Success at bottom, then failures from largest to smallest going up
    stacking_order = ["Success"] + sorted_failures

    chart_data = []
    for checkpoint in checkpoint_order:
        for outcome in stacking_order:
            percentage = percentage_of(checkpoint, outcome)
            if percentage > 0:  # Only include non-zero values
                chart_data.append({
                    'Policy': checkpoint,
                    'Outcome': outcome,
                    'Percentage': percentage,
                    'Color': colors.get(outcome, "#999999"),
                })

    if not chart_data:
        # An empty frame has no 'Outcome' column; bail out instead of raising KeyError.
        print("No non-zero outcomes to plot!")
        return

    df = pd.DataFrame(chart_data)

    # Keep only outcomes that actually appear, preserving the stacking order.
    outcomes_in_data = set(df['Outcome'])
    filtered_stacking_order = [o for o in stacking_order if o in outcomes_in_data]

    # Add sort index to control stacking order
    outcome_to_index = {outcome: i for i, outcome in enumerate(filtered_stacking_order)}
    df['sort_index'] = df['Outcome'].map(outcome_to_index)

    # Create color scale only for outcomes in data
    color_scale = alt.Scale(
        domain=filtered_stacking_order,
        range=[colors[o] for o in filtered_stacking_order]
    )

    # Create stacked bar chart
    chart = alt.Chart(df).mark_bar(
        stroke='white',
        strokeWidth=1
    ).encode(
        x=alt.X('Policy:N', title=None, axis=alt.Axis(labelFontSize=18, labelAngle=0)),
        y=alt.Y('Percentage:Q', title='Share of Episodes (%)', axis=alt.Axis(labelFontSize=18, titleFontSize=20)).scale(domain=[0, 100]),
        color=alt.Color('Outcome:N', scale=color_scale, sort=filtered_stacking_order, legend=alt.Legend(
            title=None,
            labelFontSize=16,
            symbolSize=200,
            orient="bottom",
            direction="horizontal",
            labelLimit=0,
            columns=3
        )),
        order=alt.Order('sort_index:Q'),
        tooltip=['Policy', 'Outcome', alt.Tooltip('Percentage:Q', format='.1f')]
    ).properties(
        width=400,
        height=400,
        title={
            'text': ' EgoGym-Pick Failure Modes by Policy',
            'fontSize': 22,
            'anchor': 'start',
            'dx': 60,
            'dy': -20
        },
        padding={"left": 5, "right": 5, "top": 20, "bottom": 40}
    ).configure_view(
        strokeWidth=0
    )

    # Save chart
    chart.save("failure_modes_by_checkpoint.html")
    chart.save("failure_modes_by_checkpoint.pdf", scale_factor=3)
    chart.save("failure_modes_by_checkpoint.png", scale_factor=3)
    print("\nPlot saved to: failure_modes_by_checkpoint.html, failure_modes_by_checkpoint.pdf and failure_modes_by_checkpoint.png")


if __name__ == "__main__":
    plot_failure_modes()
@@ -0,0 +1,257 @@
1
+
2
+ import os
3
+ import numpy as np
4
+ import pandas as pd
5
+ import altair as alt
6
+
7
def _custom_theme():
    """Return an Altair theme config that applies the 'Produkt' font to all text."""
    font = 'Produkt'
    return {
        'config': {
            'title': {'font': font},
            'axis': {'labelFont': font, 'titleFont': font},
            'legend': {'labelFont': font, 'titleFont': font},
            'mark': {'font': font},
            'text': {'font': font},
        }
    }


# Register the custom font theme and make it active for every chart exported below.
alt.themes.register('custom_theme', _custom_theme)
alt.themes.enable('custom_theme')
18
+
19
# Root directory holding the per-model evaluation log folders.
BASE_DIR = "logs"
# An episode counts as "Success" when its max_reward is strictly above this.
REWARD_THRESHOLD = 0.03
# Minimum max_reward treated as a partial lift by the failure-mode decision tree.
PARTIAL_LIFT_THRESHOLD = 0.01

# Display name -> log sub-folder (joined under BASE_DIR) for each CAP + VLM variant.
MODELS = {
    f"CAP + {backend.capitalize()}": f"rum/{backend}"
    for backend in ("oracle", "gemini", "molmo", "moondream")
}
29
+
30
+
31
def _cell_text(value):
    """Return a CSV cell as text, mapping missing (NaN/None) cells to ''."""
    # str(nan) would be "nan", which can spuriously match via substring tests.
    return "" if pd.isna(value) else str(value)


def _cell_flag(value):
    """Interpret a CSV cell as a boolean; missing (NaN/None) cells are False."""
    # NaN is truthy in Python, so it must be handled before bool().
    if pd.isna(value):
        return False
    if isinstance(value, str):
        return value.strip().lower() in ("true", "1", "yes")
    return bool(value)


def compute_outcomes_from_csv(csv_path, reward_threshold=None, partial_lift_threshold=None):
    """Classify every episode in a tab-separated evaluation log by outcome.

    Args:
        csv_path: Path to a tab-separated log with a ``max_reward`` column.
            ``grasped_bodies``, ``object_name``, ``grasping_object`` and
            ``is_grasping`` are optional; absent or empty cells are treated
            as ''/False.
        reward_threshold: ``max_reward`` strictly above this is "Success".
            Defaults to the module-level ``REWARD_THRESHOLD``.
        partial_lift_threshold: a grasped target with ``max_reward`` at or
            above this is "Did not lift enough". Defaults to the module-level
            ``PARTIAL_LIFT_THRESHOLD``.

    Returns:
        ``(outcomes, total_episodes)``. ``outcomes`` maps each of "Success",
        "Did not lift enough", "Object touched but not grasped",
        "Picked wrong object", "Empty Grasp" and "Did not grasp" (in that
        order) to an episode count. Returns ``(None, None)`` if ``csv_path``
        does not exist.
    """
    if not os.path.exists(csv_path):
        return None, None
    if reward_threshold is None:
        reward_threshold = REWARD_THRESHOLD
    if partial_lift_threshold is None:
        partial_lift_threshold = PARTIAL_LIFT_THRESHOLD

    df = pd.read_csv(csv_path, sep="\t")

    def get_failure_mode(row):
        max_reward = row["max_reward"]
        if max_reward > reward_threshold:
            return "Success"

        bodies_contacted = _cell_text(row.get("grasped_bodies"))
        object_name = _cell_text(row.get("object_name"))
        grasping_object = _cell_flag(row.get("grasping_object"))  # target ever grasped
        is_grasping = _cell_flag(row.get("is_grasping"))  # final gripper state only

        # Substring checks against the grasped_bodies text. An empty target name
        # would match every string, so it never counts as target contact.
        has_target_contact = bool(object_name) and object_name in bodies_contacted
        has_gripper_contact = "left" in bodies_contacted or "right" in bodies_contacted
        has_wrong_object_contact = "object" in bodies_contacted and not has_target_contact

        # Decision tree, most specific first.
        if grasping_object:
            # Target grasped: a partial lift is "Did not lift enough"; a grasp
            # that barely lifted is grouped with mere touches.
            if max_reward >= partial_lift_threshold:
                return "Did not lift enough"
            return "Object touched but not grasped"
        if has_target_contact:
            return "Object touched but not grasped"
        if has_wrong_object_contact:
            return "Picked wrong object"
        if is_grasping or has_gripper_contact:
            return "Empty Grasp"
        return "Did not grasp"

    outcomes = dict.fromkeys((
        "Success",
        "Did not lift enough",
        "Object touched but not grasped",
        "Picked wrong object",
        "Empty Grasp",
        "Did not grasp",
    ), 0)
    for row in df.to_dict("records"):
        outcomes[get_failure_mode(row)] += 1

    return outcomes, len(df)
102
+
103
+
104
def _find_log_csv(model_folder):
    """Locate the 5-object (4 distractor) evaluation log for one model.

    Checks ``BASE_DIR/<model_folder>/5_objects/log.csv`` and
    ``BASE_DIR/<model_folder>/5_object/log.csv`` first, then any ``log.csv``
    one directory below those folders. Sub-directories are visited in sorted
    order because ``os.listdir`` order is arbitrary, which would otherwise make
    the chosen log nondeterministic.

    Returns:
        The path of the first log found, or None.
    """
    folder_names = ("5_objects", "5_object")
    for folder_name in folder_names:
        candidate = os.path.join(BASE_DIR, model_folder, folder_name, "log.csv")
        if os.path.exists(candidate):
            return candidate

    # Fall back to nested evaluation folders, e.g. 5_objects/<run>/log.csv.
    for folder_name in folder_names:
        folder_path = os.path.join(BASE_DIR, model_folder, folder_name)
        if not os.path.isdir(folder_path):
            continue
        for subdir in sorted(os.listdir(folder_path)):
            subdir_path = os.path.join(folder_path, subdir)
            candidate = os.path.join(subdir_path, "log.csv")
            if os.path.isdir(subdir_path) and os.path.exists(candidate):
                return candidate
    return None


def plot_failure_modes():
    """Print per-model outcome tables and save a wrong-object-rate bar chart.

    For every entry in ``MODELS``, locates the 5-object (4 distractor) log with
    ``_find_log_csv`` and classifies it with ``compute_outcomes_from_csv``.
    Models without a log are skipped with a warning. Writes
    ``failure_modes_by_model.html``, ``.png`` and ``.pdf`` to the current
    working directory.
    """
    outcome_data = {}
    episode_totals = {}

    for model_name, model_folder in MODELS.items():
        csv_path = _find_log_csv(model_folder)
        if csv_path is None:
            outcomes, total_episodes = None, None
        else:
            outcomes, total_episodes = compute_outcomes_from_csv(csv_path)
        if outcomes is None:
            print(f"Warning: Could not find CSV data for {model_name}")
            continue
        outcome_data[model_name] = outcomes
        episode_totals[model_name] = total_episodes

    if not outcome_data:
        print("No data found!")
        return

    model_names = list(outcome_data.keys())

    # Print results to terminal
    print("\n" + "="*80)
    print("WRONG OBJECT SELECTION RATE BY VLM (4 Distractors)")
    print("="*80)
    for model in model_names:
        total = episode_totals.get(model, 0)
        # Wide enough for any count up to the total, so rows stay aligned.
        count_width = max(3, len(str(total)))
        print(f"\n{model} (n={total}):")
        print("-" * 60)
        for outcome, count in outcome_data[model].items():
            percentage = (count / total * 100) if total > 0 else 0
            print(f" {outcome:40s}: {count:{count_width}d} ({percentage:5.1f}%)")
    print("\n" + "="*80 + "\n")

    # Only show "Picked wrong object" percentage
    chart_data = []
    for model in model_names:
        total = episode_totals.get(model, 0) or 1
        count = outcome_data[model].get("Picked wrong object", 0)
        proportion = count / total
        chart_data.append({
            'Model': model,
            'Percentage': proportion * 100,
            # Standard error of a binomial proportion, in percentage points.
            'SE': np.sqrt(proportion * (1 - proportion) / total) * 100,
        })

    # Sort by percentage (lowest to highest); that order drives the x-axis.
    df = pd.DataFrame(chart_data).sort_values('Percentage')
    model_order = df['Model'].tolist()

    # Purple shades, lightest (lowest wrong-object rate) to darkest. Models beyond
    # the palette length reuse the darkest shade instead of raising IndexError.
    purple_shades = ["#D6BAE2", "#9B66BB", "#7B4FA2", "#4B136D"]
    color_mapping = {
        model: purple_shades[min(i, len(purple_shades) - 1)]
        for i, model in enumerate(model_order)
    }
    df['Color'] = df['Model'].map(color_mapping)

    # Create simple bar chart with gradient colors
    bars = alt.Chart(df).mark_bar(
        stroke='white',
        strokeWidth=1,
        cornerRadiusTopLeft=8,
        cornerRadiusTopRight=8
    ).encode(
        x=alt.X('Model:N', title=None, sort=model_order, axis=alt.Axis(labelFontSize=18, labelAngle=0)),
        y=alt.Y('Percentage:Q', title=None, axis=alt.Axis(labelFontSize=18)),
        color=alt.Color('Model:N', scale=alt.Scale(domain=model_order, range=[color_mapping[m] for m in model_order]), legend=None),
        tooltip=['Model', alt.Tooltip('Percentage:Q', format='.1f')]
    )

    # Add error bars
    error_bars = alt.Chart(df).mark_errorbar(
        color='black',
        thickness=2,
        ticks=alt.MarkConfig(width=10)
    ).encode(
        x=alt.X('Model:N', sort=model_order),
        y=alt.Y('Percentage:Q'),
        yError=alt.YError('SE:Q')
    )

    # Add percentage labels on top of bars
    text = alt.Chart(df).mark_text(
        align='center',
        baseline='bottom',
        dy=-15,
        fontSize=16,
        fontWeight='bold'
    ).encode(
        x=alt.X('Model:N', sort=model_order),
        y=alt.Y('Percentage:Q'),
        text=alt.Text('Percentage:Q', format='.1f')
    )

    chart = (bars + error_bars + text).properties(
        width=600,
        height=400,
        title={
            'text': ' Wrong Object Selection Rate by VLM (4 Distractors)',
            'fontSize': 22,
            'anchor': 'start',
            'dx': 60,
            'dy': -20
        },
        padding={"left": 5, "right": 5, "top": 20, "bottom": 0}
    ).configure_view(
        strokeWidth=0
    )

    # Save chart
    chart.save("failure_modes_by_model.html")
    chart.save("failure_modes_by_model.png", scale_factor=3)
    chart.save("failure_modes_by_model.pdf")
    print("Saved: failure_modes_by_model.html, .png, .pdf")


if __name__ == "__main__":
    plot_failure_modes()
@@ -0,0 +1,177 @@
1
+
2
+ import numpy as np
3
+ import pandas as pd
4
+ import altair as alt
5
+
6
def _custom_theme():
    """Return an Altair theme config that applies the 'Produkt' font to all text."""
    font = 'Produkt'
    return {
        'config': {
            'title': {'font': font},
            'axis': {'labelFont': font, 'titleFont': font},
            'legend': {'labelFont': font, 'titleFont': font},
            'mark': {'font': font},
            'text': {'font': font},
        }
    }


# Register the custom font theme and make it active for every chart exported below.
alt.themes.register('custom_theme', _custom_theme)
alt.themes.enable('custom_theme')
17
+
18
# Hardcoded data from experiments: episode outcome counts per CAP + VLM variant.
# Each row below is (model, total episodes, counts in _VLM_OUTCOMES order).
# NOTE(review): "CAP + Oracle" matches CHECKPOINT_DATA["CAP-D"] in
# plot_failure_hardcoded.py on Success / Picked wrong object / Empty Grasp, but
# splits the remaining 501 episodes 89/412 between "Did not lift enough" and
# "Object touched but not grasped" (that table has 139/362). Confirm whether both
# describe the same run and which split is current.
_VLM_OUTCOMES = (
    "Success",
    "Did not lift enough",
    "Object touched but not grasped",
    "Picked wrong object",
    "Empty Grasp",
    "Did not grasp",
)
_VLM_ROWS = (
    ("CAP + Oracle", 5000, (3994, 89, 412, 59, 446, 0)),
    ("CAP + Gemini", 5000, (3441, 88, 389, 448, 634, 0)),
    ("CAP + Molmo", 5000, (3008, 94, 467, 877, 554, 0)),
    ("CAP + Moondream", 5000, (3033, 83, 342, 746, 796, 0)),
)
MODEL_DATA = {
    model: {"total": total, "outcomes": dict(zip(_VLM_OUTCOMES, counts))}
    for model, total, counts in _VLM_ROWS
}
65
+
66
+
67
def plot_failure_modes():
    """Print per-model outcome tables and save a wrong-object-rate bar chart.

    Reads the module-level ``MODEL_DATA`` and writes
    ``failure_modes_by_model.html``, ``.png`` and ``.pdf`` to the current
    working directory. Bars show the "Picked wrong object" share with binomial
    standard-error bars and a rounded "NN%" label.
    """
    model_names = list(MODEL_DATA.keys())

    # Print results to terminal
    print("\n" + "="*80)
    print("WRONG OBJECT SELECTION RATE BY VLM (4 Distractors)")
    print("="*80)
    for model in model_names:
        data = MODEL_DATA[model]
        total = data["total"]
        # Wide enough for any count up to the total, so rows stay aligned.
        count_width = max(3, len(str(total)))
        print(f"\n{model} (n={total}):")
        print("-" * 60)
        for outcome, count in data["outcomes"].items():
            percentage = (count / total * 100) if total > 0 else 0
            print(f" {outcome:40s}: {count:{count_width}d} ({percentage:5.1f}%)")
    print("\n" + "="*80 + "\n")

    # Only show "Picked wrong object" percentage
    chart_data = []
    for model in model_names:
        data = MODEL_DATA[model]
        total = data["total"] or 1
        count = data["outcomes"].get("Picked wrong object", 0)
        proportion = count / total
        chart_data.append({
            'Model': model,
            'Percentage': proportion * 100,
            # Standard error of a binomial proportion, in percentage points.
            'SE': np.sqrt(proportion * (1 - proportion) / total) * 100,
        })

    # Sort by percentage (lowest to highest); that order drives the x-axis.
    df = pd.DataFrame(chart_data).sort_values('Percentage')
    model_order = df['Model'].tolist()

    # Purple shades, lightest (lowest wrong-object rate) to darkest. Models beyond
    # the palette length reuse the darkest shade instead of raising IndexError.
    purple_shades = ["#D6BAE2", "#9B66BB", "#7B4FA2", "#4B136D"]
    color_mapping = {
        model: purple_shades[min(i, len(purple_shades) - 1)]
        for i, model in enumerate(model_order)
    }
    df['Color'] = df['Model'].map(color_mapping)

    # Create simple bar chart with gradient colors
    bars = alt.Chart(df).mark_bar(
        stroke='white',
        strokeWidth=1,
        cornerRadiusTopLeft=8,
        cornerRadiusTopRight=8
    ).encode(
        x=alt.X('Model:N', title=None, sort=model_order, axis=alt.Axis(labelFontSize=18, labelAngle=0)),
        y=alt.Y('Percentage:Q', title=None, axis=alt.Axis(labelFontSize=18)),
        color=alt.Color('Model:N', scale=alt.Scale(domain=model_order, range=[color_mapping[m] for m in model_order]), legend=None),
        tooltip=['Model', alt.Tooltip('Percentage:Q', format='.1f')]
    )

    # Add error bars
    error_bars = alt.Chart(df).mark_errorbar(
        color='black',
        thickness=2,
        ticks=alt.MarkConfig(width=10)
    ).encode(
        x=alt.X('Model:N', sort=model_order),
        y=alt.Y('Percentage:Q'),
        yError=alt.YError('SE:Q')
    )

    # Add percentage labels on top of bars. The calculated `label` field is what
    # the text channel renders; previously it was computed but never used, so
    # the intended "%" suffix never appeared.
    text = alt.Chart(df).mark_text(
        align='center',
        baseline='bottom',
        dy=-15,
        fontSize=16,
        fontWeight='bold'
    ).encode(
        x=alt.X('Model:N', sort=model_order),
        y=alt.Y('Percentage:Q'),
        text=alt.Text('label:N')
    ).transform_calculate(
        label='format(datum.Percentage, ".0f") + "%"'
    )

    chart = (bars + error_bars + text).properties(
        width=600,
        height=400,
        title={
            'text': ' Wrong Object Selection Rate by VLM (4 Distractors)',
            'fontSize': 22,
            'anchor': 'start',
            'dx': 60,
            'dy': -20
        },
        padding={"left": 5, "right": 5, "top": 20, "bottom": 0}
    ).configure_view(
        strokeWidth=0
    )

    # Save chart
    chart.save("failure_modes_by_model.html")
    chart.save("failure_modes_by_model.png", scale_factor=3)
    chart.save("failure_modes_by_model.pdf")
    print("Saved: failure_modes_by_model.html, .png, .pdf")


if __name__ == "__main__":
    plot_failure_modes()