egogym 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import altair as alt
|
|
5
|
+
|
|
6
|
+
def _custom_theme():
    """Build the Altair theme config that points every text element at 'Produkt'."""
    label_and_title = {'labelFont': 'Produkt', 'titleFont': 'Produkt'}
    font_only = {'font': 'Produkt'}
    return {
        'config': {
            'title': dict(font_only),
            'axis': dict(label_and_title),
            'legend': dict(label_and_title),
            'mark': dict(font_only),
            'text': dict(font_only),
        }
    }


# Register custom font for export, then make it the active theme
alt.themes.register('custom_theme', _custom_theme)
alt.themes.enable('custom_theme')
|
|
17
|
+
|
|
18
|
+
# Hardcoded data from experiments.
# Outcome labels, in the order the per-checkpoint count tuples below are written.
_OUTCOME_LABELS = (
    "Success",
    "Did not lift enough",
    "Object touched but not grasped",
    "Picked wrong object",
    "Empty Grasp",
    "Did not grasp",
)

# Episode counts per checkpoint, one column per entry of _OUTCOME_LABELS.
_CHECKPOINT_COUNTS = {
    "CAP-A": (1001, 677, 1295, 337, 1690, 0),
    "CAP-B": (1847, 995, 1213, 208, 737, 0),
    "CAP-C": (3137, 263, 597, 110, 893, 0),
    "CAP-D": (3994, 139, 362, 59, 446, 0),
}

# checkpoint name -> {"total": episode count, "outcomes": {label: count}}
CHECKPOINT_DATA = {
    checkpoint: {
        "total": 5000,
        "outcomes": dict(zip(_OUTCOME_LABELS, counts)),
    }
    for checkpoint, counts in _CHECKPOINT_COUNTS.items()
}
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def plot_failure_modes():
    """Print a per-checkpoint outcome summary and save a stacked bar chart.

    Reads the module-level ``CHECKPOINT_DATA`` table, prints every outcome's
    count and share of episodes for each checkpoint, then draws one stacked
    bar per checkpoint ("Success" at the bottom, failure modes above it from
    largest to smallest overall share). The chart is written to
    ``failure_modes_by_checkpoint.html``, ``.pdf`` and ``.png`` in the current
    working directory.
    """
    checkpoint_order = ["CAP-A", "CAP-B", "CAP-C", "CAP-D"]

    # Print results to terminal
    print("\n" + "="*80)
    print("FAILURE MODE ANALYSIS BY CHECKPOINT")
    print("="*80)
    for checkpoint in checkpoint_order:
        data = CHECKPOINT_DATA[checkpoint]
        total = data["total"]
        print(f"\n{checkpoint} (n={total}):")
        print("-" * 60)
        for outcome, count in data["outcomes"].items():
            percentage = (count / total * 100) if total > 0 else 0
            print(f" {outcome:40s}: {count:3d} ({percentage:5.1f}%)")
    print("\n" + "="*80 + "\n")

    outcome_order = [
        "Success",
        "Did not lift enough",
        "Object touched but not grasped",
        "Picked wrong object",
        "Empty Grasp",
        "Did not grasp",
    ]

    colors = {
        "Success": "#388038",  # Green
        "Did not lift enough": "#F7D45B",  # Yellow
        "Object touched but not grasped": "#66ACF7",  # Blue
        "Picked wrong object": "#F0529C",  # Pink
        "Empty Grasp": "#9B66BB",  # Purple
        "Did not grasp": "#870927",  # Dark red
    }

    def share(checkpoint, outcome):
        """Return the percentage of *checkpoint*'s episodes that ended in *outcome*."""
        data = CHECKPOINT_DATA[checkpoint]
        total = data["total"] or 1  # a zero total would otherwise divide by zero
        return data["outcomes"].get(outcome, 0) / total * 100

    # Sum each outcome's percentage over all checkpoints; the sum is only used
    # to rank failure modes by overall size.
    outcome_totals = {
        outcome: sum(share(checkpoint, outcome) for checkpoint in checkpoint_order)
        for outcome in outcome_order
    }

    # Sort outcomes by total percentage (largest first), excluding Success
    failure_modes = [o for o in outcome_order if o != "Success"]
    sorted_failures = sorted(failure_modes, key=lambda x: outcome_totals.get(x, 0), reverse=True)

    # Build stacking order: Success at bottom, then failures from largest to smallest going up
    stacking_order = ["Success"] + sorted_failures

    chart_data = []
    for checkpoint in checkpoint_order:
        for outcome in stacking_order:
            percentage = share(checkpoint, outcome)
            if percentage > 0:  # Only include non-zero values
                chart_data.append({
                    'Policy': checkpoint,
                    'Outcome': outcome,
                    'Percentage': percentage,
                    # Not read by any encoding below; kept in the exported dataset.
                    'Color': colors.get(outcome, "#999999")
                })

    if not chart_data:
        # An empty frame has no 'Outcome' column, so bail out before indexing it.
        print("No non-zero outcomes to plot!")
        return

    df = pd.DataFrame(chart_data)

    # Get only outcomes that actually appear in the data
    outcomes_in_data = df['Outcome'].unique().tolist()

    # Filter stacking_order to only include outcomes present in data
    filtered_stacking_order = [o for o in stacking_order if o in outcomes_in_data]

    # Add sort index to control stacking order
    outcome_to_index = {outcome: i for i, outcome in enumerate(filtered_stacking_order)}
    df['sort_index'] = df['Outcome'].map(outcome_to_index)

    # Create color scale only for outcomes in data
    color_scale = alt.Scale(
        domain=filtered_stacking_order,
        range=[colors[o] for o in filtered_stacking_order]
    )

    # Create stacked bar chart
    chart = alt.Chart(df).mark_bar(
        stroke='white',
        strokeWidth=1
    ).encode(
        x=alt.X('Policy:N', title=None, axis=alt.Axis(labelFontSize=18, labelAngle=0)),
        y=alt.Y('Percentage:Q', title='Share of Episodes (%)', axis=alt.Axis(labelFontSize=18, titleFontSize=20)).scale(domain=[0, 100]),
        color=alt.Color('Outcome:N', scale=color_scale, sort=filtered_stacking_order, legend=alt.Legend(
            title=None,
            labelFontSize=16,
            symbolSize=200,
            orient="bottom",
            direction="horizontal",
            labelLimit=0,
            columns=3
        )),
        order=alt.Order('sort_index:Q'),
        tooltip=['Policy', 'Outcome', alt.Tooltip('Percentage:Q', format='.1f')]
    ).properties(
        width=400,
        height=400,
        title={
            'text': ' EgoGym-Pick Failure Modes by Policy',
            'fontSize': 22,
            'anchor': 'start',
            'dx': 60,
            'dy': -20
        },
        padding={"left": 5, "right": 5, "top": 20, "bottom": 40}
    ).configure_view(
        strokeWidth=0
    )

    # Save chart
    chart.save("failure_modes_by_checkpoint.html")
    chart.save("failure_modes_by_checkpoint.pdf", scale_factor=3)
    chart.save("failure_modes_by_checkpoint.png", scale_factor=3)
    # Report every file that was written (the PDF used to be left out).
    print("\nPlot saved to: failure_modes_by_checkpoint.html, .pdf and .png")
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# Generate the plot only when this file is run as a script, not when imported.
if __name__ == "__main__":
    plot_failure_modes()
|
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
|
|
2
|
+
import os
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import altair as alt
|
|
6
|
+
|
|
7
|
+
def _custom_theme():
    """Build the Altair theme config that points every text element at 'Produkt'."""
    label_and_title = {'labelFont': 'Produkt', 'titleFont': 'Produkt'}
    font_only = {'font': 'Produkt'}
    return {
        'config': {
            'title': dict(font_only),
            'axis': dict(label_and_title),
            'legend': dict(label_and_title),
            'mark': dict(font_only),
            'text': dict(font_only),
        }
    }


# Register custom font for export, then make it the active theme
alt.themes.register('custom_theme', _custom_theme)
alt.themes.enable('custom_theme')
|
|
18
|
+
|
|
19
|
+
# Root directory that is searched for per-model evaluation logs.
BASE_DIR = "logs"
# An episode is labelled "Success" when its max_reward is strictly above this value.
REWARD_THRESHOLD = 0.03
# For failed episodes, max_reward at or above this value is treated as a partial lift.
# NOTE(review): the unit of max_reward is not visible in this file - confirm.
PARTIAL_LIFT_THRESHOLD = 0.01

# Display name -> log folder (joined onto BASE_DIR) for each evaluated model.
MODELS = {
    "CAP + Oracle": "rum/oracle",
    "CAP + Gemini": "rum/gemini",
    "CAP + Molmo": "rum/molmo",
    "CAP + Moondream": "rum/moondream",
}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _cell_text(value):
    """Return a scalar CSV cell as text, mapping a missing value (None/NaN) to ''."""
    return "" if pd.isna(value) else str(value)


def _cell_flag(value):
    """Return a scalar CSV cell as a bool, treating a missing value (None/NaN) as False."""
    # bool(float('nan')) is True, so a blank cell must be handled explicitly.
    return False if pd.isna(value) else bool(value)


def _classify_episode(row):
    """Return the outcome label for one episode record (a column -> value mapping).

    Contact checks are plain substring tests against the text of the
    ``grasped_bodies`` cell. When ``object_name`` is missing, neither target
    contact nor wrong-object contact can be established, so both are False.
    """
    max_reward = row["max_reward"]
    if max_reward > REWARD_THRESHOLD:
        return "Success"

    bodies_contacted = _cell_text(row.get("grasped_bodies"))
    object_name = _cell_text(row.get("object_name"))
    grasping_object = _cell_flag(row.get("grasping_object"))
    is_grasping = _cell_flag(row.get("is_grasping"))  # Final gripper state only

    # '"" in text' is always True, so only test for the target when its name is known.
    target_known = bool(object_name)
    has_target_contact = target_known and object_name in bodies_contacted
    has_gripper_contact = "left" in bodies_contacted or "right" in bodies_contacted
    has_wrong_object_contact = (
        target_known and "object" in bodies_contacted and not has_target_contact
    )

    # Decision tree (most specific to least specific)
    if grasping_object:
        # Target was grasped: either lifted part of the way, or barely moved.
        if max_reward >= PARTIAL_LIFT_THRESHOLD:
            return "Did not lift enough"
        return "Object touched but not grasped"
    if has_target_contact:
        # Made contact with the target but never achieved a grasp.
        return "Object touched but not grasped"
    if has_wrong_object_contact:
        # Contacted some other object (regardless of how far it was lifted).
        return "Picked wrong object"
    if is_grasping or has_gripper_contact:
        # Gripper closed at the end, or only gripper bodies were in contact.
        return "Empty Grasp"
    # Never made meaningful contact with anything.
    return "Did not grasp"


def compute_outcomes_from_csv(csv_path):
    """Count episode outcomes in a tab-separated evaluation log.

    Args:
        csv_path: Path to the log file. It is read with ``sep="\\t"`` and must
            contain a ``max_reward`` column; ``grasped_bodies``,
            ``object_name``, ``grasping_object`` and ``is_grasping`` are used
            when present.

    Returns:
        ``(outcomes, total_episodes)`` where ``outcomes`` maps each of the six
        outcome labels to its episode count and ``total_episodes`` is the
        number of rows, or ``(None, None)`` when ``csv_path`` does not exist.

    Raises:
        KeyError: If the log has rows but no ``max_reward`` column.
    """
    if not os.path.exists(csv_path):
        return None, None
    df = pd.read_csv(csv_path, sep="\t")

    outcomes = {
        "Success": 0,
        "Did not lift enough": 0,
        "Object touched but not grasped": 0,
        "Picked wrong object": 0,
        "Empty Grasp": 0,
        "Did not grasp": 0,
    }
    # Plain dict records avoid building a Series per row.
    for record in df.to_dict("records"):
        outcomes[_classify_episode(record)] += 1

    return outcomes, len(df)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _find_log_csv(model_folder):
    """Return the path of the 5-object evaluation log for *model_folder*, or None.

    Looks for ``log.csv`` directly inside ``5_objects`` / ``5_object`` first,
    then inside any immediate sub-directory of those folders.
    """
    folder_names = ("5_objects", "5_object")

    # Direct hit: <BASE_DIR>/<model_folder>/<folder_name>/log.csv
    for folder_name in folder_names:
        candidate = os.path.join(BASE_DIR, model_folder, folder_name, "log.csv")
        if os.path.exists(candidate):
            return candidate

    # Check for nested evaluation folders
    for folder_name in folder_names:
        folder_path = os.path.join(BASE_DIR, model_folder, folder_name)
        if not os.path.isdir(folder_path):
            continue
        # os.listdir() returns entries in arbitrary order; sort so that the
        # same log is picked on every run when several sub-folders exist.
        for subdir in sorted(os.listdir(folder_path)):
            subdir_path = os.path.join(folder_path, subdir)
            if not os.path.isdir(subdir_path):
                continue
            candidate = os.path.join(subdir_path, "log.csv")
            if os.path.exists(candidate):
                return candidate

    return None


def plot_failure_modes():
    """Print per-model outcome counts and plot the wrong-object rate per model.

    For every entry of ``MODELS`` the 5-object (4 distractor) log is located
    and classified with ``compute_outcomes_from_csv``; models without a log
    are skipped with a warning. The full outcome breakdown is printed, then a
    bar chart of the "Picked wrong object" percentage (with binomial
    standard-error bars) is saved as ``failure_modes_by_model.html``,
    ``.png`` and ``.pdf`` in the current working directory.
    """
    outcome_data = {}
    episode_totals = {}

    # Get data for 5 objects (4 distractors) for each model
    for model_name, model_folder in MODELS.items():
        csv_path = _find_log_csv(model_folder)
        if csv_path is None:
            print(f"Warning: Could not find CSV data for {model_name}")
            continue

        outcomes, total_episodes = compute_outcomes_from_csv(csv_path)
        if outcomes is None:
            print(f"Warning: Could not find CSV data for {model_name}")
            continue
        outcome_data[model_name] = outcomes
        episode_totals[model_name] = total_episodes

    if not outcome_data:
        print("No data found!")
        return

    model_names = list(outcome_data.keys())

    # Print results to terminal
    print("\n" + "="*80)
    print("WRONG OBJECT SELECTION RATE BY VLM (4 Distractors)")
    print("="*80)
    for model in model_names:
        total = episode_totals.get(model, 0)
        print(f"\n{model} (n={total}):")
        print("-" * 60)
        for outcome, count in outcome_data[model].items():
            percentage = (count / total * 100) if total > 0 else 0
            print(f" {outcome:40s}: {count:3d} ({percentage:5.1f}%)")
    print("\n" + "="*80 + "\n")

    # Only show "Picked wrong object" percentage
    chart_data = []
    for model in model_names:
        total = episode_totals.get(model, 0) or 1  # avoid dividing by zero
        count = outcome_data[model].get("Picked wrong object", 0)
        proportion = count / total
        percentage = proportion * 100

        # Calculate standard error for binomial proportion (in percentage points)
        se = np.sqrt(proportion * (1 - proportion) / total) * 100

        chart_data.append({
            'Model': model,
            'Percentage': percentage,
            'SE': se,
        })

    df = pd.DataFrame(chart_data)

    # Sort by percentage (lowest to highest)
    df = df.sort_values('Percentage')

    # Create ordered list for x-axis
    model_order = df['Model'].tolist()

    # Assign colors from purple family (lightest for best, darkest for worst)
    purple_shades = ["#D6BAE2", "#9B66BB", "#7B4FA2", "#4B136D"]  # Lightest to darkest
    darkest = len(purple_shades) - 1
    # Clamp the index so more models than shades reuse the darkest shade
    # instead of raising IndexError.
    color_mapping = {
        model: purple_shades[min(i, darkest)] for i, model in enumerate(model_order)
    }
    df['Color'] = df['Model'].map(color_mapping)

    # Create simple bar chart with gradient colors
    bars = alt.Chart(df).mark_bar(
        stroke='white',
        strokeWidth=1,
        cornerRadiusTopLeft=8,
        cornerRadiusTopRight=8
    ).encode(
        x=alt.X('Model:N', title=None, sort=model_order, axis=alt.Axis(labelFontSize=18, labelAngle=0)),
        y=alt.Y('Percentage:Q', title=None, axis=alt.Axis(labelFontSize=18)),
        color=alt.Color('Model:N', scale=alt.Scale(domain=model_order, range=[color_mapping[m] for m in model_order]), legend=None),
        tooltip=['Model', alt.Tooltip('Percentage:Q', format='.1f')]
    )

    # Add error bars
    error_bars = alt.Chart(df).mark_errorbar(
        color='black',
        thickness=2,
        ticks=alt.MarkConfig(width=10)
    ).encode(
        x=alt.X('Model:N', sort=model_order),
        y=alt.Y('Percentage:Q'),
        yError=alt.YError('SE:Q')
    )

    # Add percentage labels on top of bars
    text = alt.Chart(df).mark_text(
        align='center',
        baseline='bottom',
        dy=-15,
        fontSize=16,
        fontWeight='bold'
    ).encode(
        x=alt.X('Model:N', sort=model_order),
        y=alt.Y('Percentage:Q'),
        text=alt.Text('Percentage:Q', format='.1f')
    )

    chart = (bars + error_bars + text).properties(
        width=600,
        height=400,
        title={
            'text': ' Wrong Object Selection Rate by VLM (4 Distractors)',
            'fontSize': 22,
            'anchor': 'start',
            'dx': 60,
            'dy': -20
        },
        padding={"left": 5, "right": 5, "top": 20, "bottom": 0}
    ).configure_view(
        strokeWidth=0
    )

    # Save chart
    chart.save("failure_modes_by_model.html")
    chart.save("failure_modes_by_model.png", scale_factor=3)
    chart.save("failure_modes_by_model.pdf")
    print("Saved: failure_modes_by_model.html, .png, .pdf")
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
# Generate the plot only when this file is run as a script, not when imported.
if __name__ == "__main__":
    plot_failure_modes()
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import altair as alt
|
|
5
|
+
|
|
6
|
+
def _custom_theme():
    """Build the Altair theme config that points every text element at 'Produkt'."""
    label_and_title = {'labelFont': 'Produkt', 'titleFont': 'Produkt'}
    font_only = {'font': 'Produkt'}
    return {
        'config': {
            'title': dict(font_only),
            'axis': dict(label_and_title),
            'legend': dict(label_and_title),
            'mark': dict(font_only),
            'text': dict(font_only),
        }
    }


# Register custom font for export, then make it the active theme
alt.themes.register('custom_theme', _custom_theme)
alt.themes.enable('custom_theme')
|
|
17
|
+
|
|
18
|
+
# Hardcoded data from experiments.
# Outcome labels, in the order the per-model count tuples below are written.
_OUTCOME_LABELS = (
    "Success",
    "Did not lift enough",
    "Object touched but not grasped",
    "Picked wrong object",
    "Empty Grasp",
    "Did not grasp",
)

# Episode counts per model, one column per entry of _OUTCOME_LABELS.
_MODEL_COUNTS = {
    "CAP + Oracle": (3994, 89, 412, 59, 446, 0),
    "CAP + Gemini": (3441, 88, 389, 448, 634, 0),
    "CAP + Molmo": (3008, 94, 467, 877, 554, 0),
    "CAP + Moondream": (3033, 83, 342, 746, 796, 0),
}

# model name -> {"total": episode count, "outcomes": {label: count}}
MODEL_DATA = {
    model: {
        "total": 5000,
        "outcomes": dict(zip(_OUTCOME_LABELS, counts)),
    }
    for model, counts in _MODEL_COUNTS.items()
}
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def plot_failure_modes():
    """Print per-model outcome counts and plot the wrong-object rate per model.

    Reads the module-level ``MODEL_DATA`` table, prints every outcome's count
    and share of episodes for each model, then saves a bar chart of the
    "Picked wrong object" percentage (with binomial standard-error bars and a
    whole-number label above each bar) as ``failure_modes_by_model.html``,
    ``.png`` and ``.pdf`` in the current working directory.
    """
    model_names = list(MODEL_DATA.keys())

    # Print results to terminal
    print("\n" + "="*80)
    print("WRONG OBJECT SELECTION RATE BY VLM (4 Distractors)")
    print("="*80)
    for model in model_names:
        data = MODEL_DATA[model]
        total = data["total"]
        print(f"\n{model} (n={total}):")
        print("-" * 60)
        for outcome, count in data["outcomes"].items():
            percentage = (count / total * 100) if total > 0 else 0
            print(f" {outcome:40s}: {count:3d} ({percentage:5.1f}%)")
    print("\n" + "="*80 + "\n")

    # Only show "Picked wrong object" percentage
    chart_data = []
    for model in model_names:
        data = MODEL_DATA[model]
        total = data["total"] or 1  # avoid dividing by zero
        count = data["outcomes"].get("Picked wrong object", 0)
        proportion = count / total
        percentage = proportion * 100

        # Calculate standard error for binomial proportion (in percentage points)
        se = np.sqrt(proportion * (1 - proportion) / total) * 100

        chart_data.append({
            'Model': model,
            'Percentage': percentage,
            'SE': se,
        })

    df = pd.DataFrame(chart_data)

    # Sort by percentage (lowest to highest)
    df = df.sort_values('Percentage')

    # Create ordered list for x-axis
    model_order = df['Model'].tolist()

    # Assign colors from purple family (lightest for best, darkest for worst)
    purple_shades = ["#D6BAE2", "#9B66BB", "#7B4FA2", "#4B136D"]  # Lightest to darkest
    darkest = len(purple_shades) - 1
    # Clamp the index so more models than shades reuse the darkest shade
    # instead of raising IndexError.
    color_mapping = {
        model: purple_shades[min(i, darkest)] for i, model in enumerate(model_order)
    }
    df['Color'] = df['Model'].map(color_mapping)

    # Create simple bar chart with gradient colors
    bars = alt.Chart(df).mark_bar(
        stroke='white',
        strokeWidth=1,
        cornerRadiusTopLeft=8,
        cornerRadiusTopRight=8
    ).encode(
        x=alt.X('Model:N', title=None, sort=model_order, axis=alt.Axis(labelFontSize=18, labelAngle=0)),
        y=alt.Y('Percentage:Q', title=None, axis=alt.Axis(labelFontSize=18)),
        color=alt.Color('Model:N', scale=alt.Scale(domain=model_order, range=[color_mapping[m] for m in model_order]), legend=None),
        tooltip=['Model', alt.Tooltip('Percentage:Q', format='.1f')]
    )

    # Add error bars
    error_bars = alt.Chart(df).mark_errorbar(
        color='black',
        thickness=2,
        ticks=alt.MarkConfig(width=10)
    ).encode(
        x=alt.X('Model:N', sort=model_order),
        y=alt.Y('Percentage:Q'),
        yError=alt.YError('SE:Q')
    )

    # Add percentage labels on top of bars. The label is the formatted
    # 'Percentage' field itself; the previous transform_calculate() produced a
    # 'label' field that no encoding read, so it has been dropped.
    text = alt.Chart(df).mark_text(
        align='center',
        baseline='bottom',
        dy=-15,
        fontSize=16,
        fontWeight='bold'
    ).encode(
        x=alt.X('Model:N', sort=model_order),
        y=alt.Y('Percentage:Q'),
        text=alt.Text('Percentage:Q', format='.0f')
    )

    chart = (bars + error_bars + text).properties(
        width=600,
        height=400,
        title={
            'text': ' Wrong Object Selection Rate by VLM (4 Distractors)',
            'fontSize': 22,
            'anchor': 'start',
            'dx': 60,
            'dy': -20
        },
        padding={"left": 5, "right": 5, "top": 20, "bottom": 0}
    ).configure_view(
        strokeWidth=0
    )

    # Save chart
    chart.save("failure_modes_by_model.html")
    chart.save("failure_modes_by_model.png", scale_factor=3)
    chart.save("failure_modes_by_model.pdf")
    print("Saved: failure_modes_by_model.html, .png, .pdf")
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
# Generate the plot only when this file is run as a script, not when imported.
if __name__ == "__main__":
    plot_failure_modes()
|