egogym-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
egogym/scripts/plot.py
ADDED
@@ -0,0 +1,87 @@
import os
import pandas as pd
import matplotlib.pyplot as plt


def plot_success_rates(logs_dir):
    """
    Plot success rates from all CSV files found recursively in the logs directory.

    Args:
        logs_dir: Directory containing subdirectories with CSV files
    """
    success_rates = {}

    # Recursively find all CSV files
    for root, dirs, files in os.walk(logs_dir):
        for file in files:
            if file.endswith('.csv'):
                csv_path = os.path.join(root, file)

                # Get relative path from logs_dir
                rel_path = os.path.relpath(csv_path, logs_dir)

                try:
                    # Read the tab-separated log file
                    df = pd.read_csv(csv_path, sep="\t")

                    # Calculate success rate (an episode counts as a success
                    # when max_reward exceeds the 0.03 lift threshold)
                    success_rate = (df["max_reward"] > 0.03).mean() * 100

                    success_rates[rel_path] = success_rate
                    print(f"{rel_path}: {success_rate:.2f}%")

                except Exception as e:
                    print(f"Error processing {rel_path}: {e}")
                    continue

    if not success_rates:
        print("No valid CSV files found!")
        return

    # Sort by relative path
    sorted_items = sorted(success_rates.items())

    folders, rates = zip(*sorted_items)

    # Create bar plot
    plt.figure(figsize=(12, 6))
    bars = plt.bar(range(len(folders)), rates, color='steelblue', edgecolor='black')

    # Add value labels on top of bars
    for i, bar in enumerate(bars):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{height:.1f}%',
                 ha='center', va='bottom', fontsize=8)

    plt.xticks(range(len(folders)), folders, rotation=45, ha='right', fontsize=8)
    plt.xlabel('Log File', fontsize=12)
    plt.ylabel('Success Rate (%)', fontsize=12)
    plt.title('Success Rate by Log File', fontsize=14, fontweight='bold')
    plt.ylim(0, 100)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()

    # Save the plot
    output_path = os.path.join(logs_dir, 'success_rates.png')
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"\nPlot saved to: {output_path}")

    plt.show()


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        logs_directory = sys.argv[1]
    else:
        # Default to logs folder in the project
        logs_directory = "logs"

    if not os.path.exists(logs_directory):
        print(f"Error: Directory '{logs_directory}' not found!")
        sys.exit(1)

    plot_success_rates(logs_directory)
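
Note: plot.py takes the logs directory as its single CLI argument (python plot.py <logs_dir>, defaulting to logs/). A minimal programmatic sketch of the same entry point follows — a sketch only, since the wheel ships no egogym/scripts/__init__.py, so plot.py is assumed to be importable from its own directory:

    # Sketch only: drive plot_success_rates() from Python instead of the CLI.
    from plot import plot_success_rates  # assumes plot.py is on sys.path

    # Walks the tree, reads each *.csv with sep="\t", counts an episode as a
    # success when max_reward > 0.03, and writes success_rates.png to the root.
    plot_success_rates("logs")
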
egogym/scripts/plot_correlation.py
ADDED
@@ -0,0 +1,392 @@
import os
import numpy as np
import pandas as pd
import altair as alt
from scipy.stats import beta, pearsonr, gaussian_kde

BASE_DIR = "logs/5_objects"
REWARD_THRESHOLD = 0.03
PARTIAL_LIFT_THRESHOLD = 0.01

# Real-robot success rates (%) per checkpoint.
CHECKPOINT_REAL_SR = {
    "checkpoint_10": 24.0,
    "checkpoint_50": 38.8,
    "checkpoint_64": 67.5,
    "checkpoint_80": 83.2,
}

REAL_SAMPLES = 250
SIM_SAMPLES_PER_CHECKPOINT = None

USE_CSV_FILES = True


def compute_success_from_csv(csv_path):
    # Count episodes whose max_reward clears the lift threshold.
    if not os.path.exists(csv_path):
        return None, None
    df = pd.read_csv(csv_path, sep="\t")
    successes = (df["max_reward"] > REWARD_THRESHOLD).sum()
    total = len(df)
    return successes, total


def compute_failure_modes_from_csv(csv_path):
    # Bucket failed episodes into coarse failure modes using contact/grasp flags.
    if not os.path.exists(csv_path):
        return None, None, None
    df = pd.read_csv(csv_path, sep="\t")
    total_episodes = len(df)

    def get_failure_mode(row):
        if row["max_reward"] > REWARD_THRESHOLD:
            return None
        bodies_contacted = str(row.get("grasped_bodies", ""))
        object_name = str(row.get("object_name", ""))
        grasping_object = row.get("grasping_object", False)
        is_grasping = row.get("is_grasping", False)
        has_target_contact = object_name in bodies_contacted
        has_gripper_contact = "left" in bodies_contacted or "right" in bodies_contacted
        has_any_object_contact = "object" in bodies_contacted
        has_wrong_object_contact = has_any_object_contact and not has_target_contact
        if grasping_object:
            if row["max_reward"] >= PARTIAL_LIFT_THRESHOLD:
                return "did not lift enough"
            return "object fell/slipped"
        if has_target_contact:
            return "object fell/slipped"
        if has_wrong_object_contact and (is_grasping or has_gripper_contact):
            return "picked wrong object"
        if has_gripper_contact or is_grasping:
            return "Empty Grasp"
        return "did not grasp"

    df["failure_mode"] = df.apply(get_failure_mode, axis=1)
    failure_modes = {
        "Did not lift enough": 0,
        "Object touched but not grasped": 0,
        "Picked wrong object": 0,
        "Empty Grasp": 0,
        "Did not grasp": 0,
    }
    failure_condition = df["max_reward"] <= REWARD_THRESHOLD
    for _, row in df[failure_condition].iterrows():
        mode = row["failure_mode"]
        if mode == "did not lift enough":
            failure_modes["Did not lift enough"] += 1
        elif mode == "object fell/slipped":
            failure_modes["Object touched but not grasped"] += 1
        elif mode == "picked wrong object":
            failure_modes["Picked wrong object"] += 1
        elif mode == "Empty Grasp":
            failure_modes["Empty Grasp"] += 1
        elif mode == "did not grasp":
            failure_modes["Did not grasp"] += 1
    total_failures = sum(failure_modes.values())
    return failure_modes, total_failures, total_episodes


def plot_bayesian_correlation():
    # Set random seed for reproducibility
    np.random.seed(42)

    # Register and enable custom font theme
    alt.themes.register('custom_theme', lambda: {
        'config': {
            'title': {'font': 'Produkt'},
            'axis': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
            'legend': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
            'mark': {'font': 'Produkt'},
            'text': {'font': 'Produkt'},
        }
    })
    alt.themes.enable('custom_theme')

    stats = {}
    checkpoint_order = ["checkpoint_10", "checkpoint_50", "checkpoint_64", "checkpoint_80"]
    for checkpoint in checkpoint_order:
        if checkpoint not in CHECKPOINT_REAL_SR:
            print(f"Warning: {checkpoint} not in CHECKPOINT_REAL_SR")
            continue
        csv_path = os.path.join(BASE_DIR, checkpoint, "log.csv")
        success, total = compute_success_from_csv(csv_path)
        if success is None:
            print(f"Warning: Could not find CSV data for {checkpoint}")
            continue
        real_sr = CHECKPOINT_REAL_SR[checkpoint]
        # Beta(1 + successes, 1 + failures) posterior; plotted rate is its mean.
        a_sim = 1 + success
        b_sim = 1 + (total - success)
        sim_mean = 100 * a_sim / (a_sim + b_sim)
        real_successes = int(real_sr / 100 * REAL_SAMPLES)
        a_real = 1 + real_successes
        b_real = 1 + (REAL_SAMPLES - real_successes)
        real_mean = 100 * a_real / (a_real + b_real)
        print(f"{checkpoint}: sim={success}/{total} ({sim_mean:.1f}%), real={real_sr:.1f}% ({real_successes}/{REAL_SAMPLES})")
        stats[checkpoint] = {
            "sim_mean": sim_mean,
            "sim_alpha": a_sim,
            "sim_beta": b_sim,
            "real_mean": real_mean,
            "real_alpha": a_real,
            "real_beta": b_real,
        }
    if not stats:
        print("No data found!")
        return
    checkpoint_names = list(stats.keys())
    sim_rates = np.array([stats[c]["sim_mean"] for c in checkpoint_names])
    real_rates = np.array([stats[c]["real_mean"] for c in checkpoint_names])

    r, p_value = pearsonr(sim_rates, real_rates)

    # Compute confidence intervals and violin plot data
    violin_data = []
    point_data = []
    error_data = []

    for i, checkpoint in enumerate(checkpoint_names):
        sim_lo = beta.ppf(0.025, stats[checkpoint]["sim_alpha"], stats[checkpoint]["sim_beta"]) * 100
        sim_hi = beta.ppf(0.975, stats[checkpoint]["sim_alpha"], stats[checkpoint]["sim_beta"]) * 100

        # Generate violin data
        y_samples = beta.rvs(stats[checkpoint]["real_alpha"], stats[checkpoint]["real_beta"], size=10000) * 100

        xc = sim_rates[i]
        yc = real_rates[i]

        # Store point data
        point_data.append({
            'checkpoint': checkpoint,
            'sim_rate': xc,
            'real_rate': yc,
            'sim_lo': sim_lo,
            'sim_hi': sim_hi
        })

        # Store violin data for density transform
        for y_val in y_samples:
            violin_data.append({
                'checkpoint': checkpoint,
                'sim_rate': xc,
                'real_rate': y_val
            })

    # Compute bootstrap correlation confidence interval
    n_samples = 1000
    r_samples = []
    for _ in range(n_samples):
        sim_sample = [beta.rvs(stats[c]["sim_alpha"], stats[c]["sim_beta"]) * 100 for c in checkpoint_names]
        real_sample = [beta.rvs(stats[c]["real_alpha"], stats[c]["real_beta"]) * 100 for c in checkpoint_names]
        r_sample, _ = pearsonr(sim_sample, real_sample)
        r_samples.append(r_sample)
    r_samples = np.array(r_samples)

    # Convert to r² (coefficient of determination)
    r_squared = r ** 2
    r_squared_samples = r_samples ** 2
    r_squared_lo = np.percentile(r_squared_samples, 2.5)
    r_squared_hi = np.percentile(r_squared_samples, 97.5)

    # Create DataFrames
    point_df = pd.DataFrame(point_data)

    # For violins, we need much less data - just sample points
    violin_sample_data = []
    for checkpoint in checkpoint_names:
        # Use only 500 samples per checkpoint to avoid memory issues
        y_samples = beta.rvs(stats[checkpoint]["real_alpha"], stats[checkpoint]["real_beta"], size=500) * 100
        xc = point_df[point_df['checkpoint'] == checkpoint]['sim_rate'].values[0]
        for y_val in y_samples:
            violin_sample_data.append({
                'checkpoint': checkpoint,
                'sim_rate': xc,
                'real_rate': y_val
            })

    violin_df = pd.DataFrame(violin_sample_data)

    # Compute regression line
    z = np.polyfit(sim_rates, real_rates, 1)
    p = np.poly1d(z)
    xs = np.linspace(sim_rates.min() - 5, sim_rates.max() + 5, 200)
    regression_df = pd.DataFrame({'sim_rate': xs, 'real_rate': p(xs)})

    # Manually create violin shapes positioned at sim_rate coordinates
    violin_width_scale = 2.5
    violin_polygon_data = []

    # Color mapping from lowest to highest checkpoint
    checkpoint_colors = {
        "checkpoint_10": "#F8F0FA",  # very light purple
        "checkpoint_50": "#D6BAE2",  # medium purple
        "checkpoint_64": "#9B66BB",  # medium-dark purple
        "checkpoint_80": "#4B136D",  # deep purple
    }

    for checkpoint in checkpoint_names:
        # Get samples and compute KDE
        y_samples = beta.rvs(stats[checkpoint]["real_alpha"], stats[checkpoint]["real_beta"], size=2000) * 100
        kde = gaussian_kde(y_samples)

        # Create density curve points
        y_points = np.linspace(y_samples.min(), y_samples.max(), 100)
        densities = kde(y_points)
        densities = densities / densities.max() * violin_width_scale

        xc = point_df[point_df['checkpoint'] == checkpoint]['sim_rate'].values[0]
        color = checkpoint_colors[checkpoint]

        # Create closed polygon: left side up, then right side down
        for i, (y, d) in enumerate(zip(y_points, densities)):
            violin_polygon_data.append({
                'checkpoint': checkpoint,
                'x': xc - d,
                'y': y,
                'order': i,
                'color': color
            })
        # Right side going back down
        for i, (y, d) in enumerate(zip(y_points[::-1], densities[::-1])):
            violin_polygon_data.append({
                'checkpoint': checkpoint,
                'x': xc + d,
                'y': y,
                'order': len(y_points) + i,
                'color': color
            })

    violin_polygon_df = pd.DataFrame(violin_polygon_data)

    # Create violin shapes with color encoding
    violins = alt.Chart(violin_polygon_df).mark_line(
        fillOpacity=0.6,
        stroke='#8B4789',
        strokeWidth=1.1,
        interpolate='linear',
        filled=True
    ).encode(
        x=alt.X('x:Q', title='EgoGym Performance (%)').scale(zero=False),
        y=alt.Y('y:Q', title='Real Performance (%)').scale(zero=False),
        order='order:Q',
        detail='checkpoint:N',
        fill=alt.Fill('color:N', scale=None, legend=None)
    )

    # Create regression line
    regression_line = alt.Chart(regression_df).mark_line(
        strokeDash=[5, 5],
        color='black',
        opacity=0.5,
        size=1.5
    ).encode(
        x=alt.X('sim_rate:Q').scale(zero=False),
        y=alt.Y('real_rate:Q').scale(zero=False)
    )

    # Create error bars for simulation
    error_bars = alt.Chart(point_df).mark_errorbar(ticks=True, thickness=1.5).encode(
        x=alt.X('sim_lo:Q', title=''),
        x2=alt.X2('sim_hi:Q'),
        y='real_rate:Q'
    )

    # Create horizontal lines at mean (for each violin)
    mean_line_data = []
    for checkpoint in checkpoint_names:
        xc = point_df[point_df['checkpoint'] == checkpoint]['sim_rate'].values[0]
        yc = point_df[point_df['checkpoint'] == checkpoint]['real_rate'].values[0]
        mean_line_data.append({
            'checkpoint': checkpoint,
            'x_left': xc - 1.0,
            'x_right': xc + 1.0,
            'y': yc
        })
    mean_line_df = pd.DataFrame(mean_line_data)

    mean_lines = alt.Chart(mean_line_df).mark_rule(color='black', size=1.2).encode(
        x=alt.X('x_left:Q'),
        x2=alt.X2('x_right:Q'),
        y='y:Q'
    )

    # Create central points
    points = alt.Chart(point_df).mark_point(
        filled=True,
        size=100,
        color='#A894DB',
        stroke='#8B4789',
        strokeWidth=1.5,
        shape='diamond'
    ).encode(
        x=alt.X('sim_rate:Q').scale(zero=False),
        y=alt.Y('real_rate:Q').scale(zero=False),
        tooltip=['checkpoint', 'sim_rate', 'real_rate']
    )

    # Combine layers (violins, regression line, error bars, mean lines and points on top)
    chart = (violins + regression_line + mean_lines + error_bars + points).properties(
        title={
            'text': ' Blind EgoGym-Pick Sim-to-Real Correlation',
            'fontSize': 28,
            'anchor': 'start',
            'dx': 25,
            'dy': -10
        },
        width=800,
        height=600
    )

    # Add background box for correlation text (wider and thinner)
    correlation_box = alt.Chart(pd.DataFrame([{
        'x': point_df['sim_rate'].max() - 18,
        'y': point_df['real_rate'].min() - 10,
        'x2': point_df['sim_rate'].max() + 4,
        'y2': point_df['real_rate'].min() - 4
    }])).mark_rect(
        fill='#F7D45B',
        stroke='gray',
        strokeWidth=1.5,
        opacity=0.9
    ).encode(
        x=alt.X('x:Q'),
        y=alt.Y('y:Q'),
        x2='x2:Q',
        y2='y2:Q'
    )

    # Add correlation text annotation in box
    correlation_text = alt.Chart(pd.DataFrame([{
        'x': point_df['sim_rate'].max() - 7,
        'y': point_df['real_rate'].min() - 7,
        'text': f'95% CI r²: {r_squared_lo:.3f}, {r_squared_hi:.3f}'
    }])).mark_text(
        align='center',
        baseline='middle',
        fontSize=20,
        font='Produkt'
    ).encode(
        x=alt.X('x:Q'),
        y=alt.Y('y:Q'),
        text='text:N'
    )

    # Combine with correlation box and text, then apply the top-level config;
    # Altair requires configure_* to come after all layers are assembled.
    chart = (chart + correlation_box + correlation_text).configure_axis(
        labelFontSize=20,
        titleFontSize=24,
        titleFontStyle='normal',
        grid=True,
        gridOpacity=0.3,
        tickCount=8
    ).configure_view(
        strokeWidth=0
    )

    # Save chart
    chart.save("checkpoint_correlation.html")
    chart.save("checkpoint_correlation.png", scale_factor=2.0)
    chart.save("checkpoint_correlation.pdf", scale_factor=2.0)
    print(f"\nCorrelation: r = {r:.3f}, r² = {r_squared:.3f}, p = {p_value:.4f}")
    print("\nPlot saved to: checkpoint_correlation.html and checkpoint_correlation.png")


if __name__ == "__main__":
    plot_bayesian_correlation()
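
For reference, the uncertainty model used by plot_correlation.py above: each success rate gets a Beta(1 + s, 1 + (n - s)) posterior (a uniform Beta(1, 1) prior updated with s successes in n trials), the plotted rate is the posterior mean, and the 95% r² interval comes from re-sampling both axes from their posteriors and recomputing the squared Pearson correlation. A self-contained sketch of that estimator, with made-up counts standing in for the log.csv and CHECKPOINT_REAL_SR data:

    import numpy as np
    from scipy.stats import beta, pearsonr

    rng = np.random.default_rng(42)

    # Hypothetical (successes, trials) per checkpoint; sim and real axes.
    sim = [(12, 50), (19, 50), (34, 50), (42, 50)]
    real = [(60, 250), (97, 250), (169, 250), (208, 250)]

    def posterior(s, n):
        # Uniform Beta(1, 1) prior + binomial likelihood -> Beta(1 + s, 1 + n - s).
        return 1 + s, 1 + (n - s)

    sim_ab = [posterior(s, n) for s, n in sim]
    real_ab = [posterior(s, n) for s, n in real]

    # Posterior mean of the first sim checkpoint: the quantity the script plots.
    a0, b0 = sim_ab[0]
    print(f"posterior mean: {100 * a0 / (a0 + b0):.1f}%")

    # Re-sample both axes from their posteriors and collect squared Pearson r,
    # mirroring the script's bootstrap loop.
    r2_samples = []
    for _ in range(1000):
        x = [beta.rvs(a, b, random_state=rng) for a, b in sim_ab]
        y = [beta.rvs(a, b, random_state=rng) for a, b in real_ab]
        r2_samples.append(pearsonr(x, y)[0] ** 2)

    lo, hi = np.percentile(r2_samples, [2.5, 97.5])
    print(f"95% CI for r²: [{lo:.3f}, {hi:.3f}]")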