egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
egogym/scripts/plot.py ADDED
@@ -0,0 +1,87 @@
1
+ import os
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+
5
+
6
def plot_success_rates(logs_dir, reward_threshold=0.03):
    """
    Plot success rates from all CSV files found recursively in the logs directory.

    Args:
        logs_dir: Directory containing subdirectories with CSV files.
        reward_threshold: An episode counts as a success when its
            ``max_reward`` exceeds this value. Defaults to 0.03, matching
            the other analysis scripts in this package.
    """
    success_rates = {}

    # Recursively find all CSV files
    for root, _dirs, files in os.walk(logs_dir):
        for file in files:
            if not file.endswith('.csv'):
                continue
            csv_path = os.path.join(root, file)

            # Key results by path relative to logs_dir so bar labels stay short
            rel_path = os.path.relpath(csv_path, logs_dir)

            try:
                # Log files are tab-separated despite the .csv extension
                df = pd.read_csv(csv_path, sep="\t")

                # Success rate in percent: fraction of episodes whose
                # max_reward clears the threshold
                success_rate = (df["max_reward"] > reward_threshold).mean() * 100

                success_rates[rel_path] = success_rate
                print(f"{rel_path}: {success_rate:.2f}%")

            except Exception as e:
                # Best-effort: report unreadable/malformed files and keep going
                print(f"Error processing {rel_path}: {e}")
                continue

    if not success_rates:
        print("No valid CSV files found!")
        return

    # Sort by relative path for a deterministic bar order
    sorted_items = sorted(success_rates.items())

    folders, rates = zip(*sorted_items)

    # Create bar plot
    plt.figure(figsize=(12, 6))
    bars = plt.bar(range(len(folders)), rates, color='steelblue', edgecolor='black')

    # Add value labels on top of bars
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{height:.1f}%',
                 ha='center', va='bottom', fontsize=8)

    plt.xticks(range(len(folders)), folders, rotation=45, ha='right', fontsize=8)
    plt.xlabel('Log File', fontsize=12)
    plt.ylabel('Success Rate (%)', fontsize=12)
    plt.title('Success Rate by Log File', fontsize=14, fontweight='bold')
    plt.ylim(0, 100)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()

    # Save the plot alongside the logs it summarizes
    output_path = os.path.join(logs_dir, 'success_rates.png')
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"\nPlot saved to: {output_path}")

    plt.show()
72
+
73
+
74
if __name__ == "__main__":
    import sys

    # Take the logs directory from the command line, defaulting to the
    # project-local "logs" folder when no argument is given.
    logs_directory = sys.argv[1] if len(sys.argv) > 1 else "logs"

    if not os.path.exists(logs_directory):
        print(f"Error: Directory '{logs_directory}' not found!")
        sys.exit(1)

    plot_success_rates(logs_directory)
@@ -0,0 +1,392 @@
1
+
2
+ import os
3
+ import numpy as np
4
+ import pandas as pd
5
+ import altair as alt
6
+ from scipy.stats import beta, pearsonr, gaussian_kde
7
+
8
# Root directory holding one subdirectory of simulation logs per checkpoint.
BASE_DIR = "logs/5_objects"
# An episode counts as a (sim) success when its max_reward exceeds this value.
REWARD_THRESHOLD = 0.03
# Failures at/above this reward are classified as "did not lift enough"
# (the object was grasped and partially lifted).
PARTIAL_LIFT_THRESHOLD = 0.01

# Real-robot success rates (%) measured per checkpoint; keys must match the
# subdirectory names under BASE_DIR.
CHECKPOINT_REAL_SR = {
    "checkpoint_10": 24.0,
    "checkpoint_50": 38.8,
    "checkpoint_64": 67.5,
    "checkpoint_80": 83.2,
}

# Number of real-robot trials behind each rate in CHECKPOINT_REAL_SR; used to
# derive the pseudo-counts of the real-world Beta posterior.
REAL_SAMPLES = 250
# NOTE(review): read nowhere in this file — presumably a leftover config knob.
SIM_SAMPLES_PER_CHECKPOINT = None

# NOTE(review): read nowhere in this file — presumably a leftover config knob.
USE_CSV_FILES = True
23
+
24
+
25
def compute_success_from_csv(csv_path):
    """Count successes in a tab-separated log file.

    Returns ``(success_count, episode_count)``, where a success is any row
    whose ``max_reward`` exceeds ``REWARD_THRESHOLD``. Returns
    ``(None, None)`` when the file does not exist.
    """
    if os.path.exists(csv_path):
        episodes = pd.read_csv(csv_path, sep="\t")
        num_success = (episodes["max_reward"] > REWARD_THRESHOLD).sum()
        return num_success, len(episodes)
    return None, None
32
+
33
+
34
def compute_failure_modes_from_csv(csv_path):
    """
    Categorize every failed episode in a log file into a coarse failure mode.

    An episode fails when its ``max_reward`` is at or below
    ``REWARD_THRESHOLD``; the mode is inferred from the contact/grasp columns
    of the row.

    Args:
        csv_path: Path to a tab-separated episode log.

    Returns:
        ``(failure_modes, total_failures, total_episodes)`` where
        ``failure_modes`` maps display labels to counts, or
        ``(None, None, None)`` when the file does not exist.
    """
    if not os.path.exists(csv_path):
        # Bug fix: this used to return a 2-tuple (None, None) while the
        # success path returns a 3-tuple, crashing callers that unpack
        # three values.
        return None, None, None
    df = pd.read_csv(csv_path, sep="\t")
    total_episodes = len(df)

    def get_failure_mode(row):
        # Successful episodes carry no failure mode.
        if row["max_reward"] > REWARD_THRESHOLD:
            return None
        bodies_contacted = str(row.get("grasped_bodies", ""))
        object_name = str(row.get("object_name", ""))
        grasping_object = row.get("grasping_object", False)
        is_grasping = row.get("is_grasping", False)
        # Substring checks against the contacted-bodies string: the target
        # object, either gripper finger, and any object at all.
        has_target_contact = object_name in bodies_contacted
        has_gripper_contact = "left" in bodies_contacted or "right" in bodies_contacted
        has_any_object_contact = "object" in bodies_contacted
        has_wrong_object_contact = has_any_object_contact and not has_target_contact
        if grasping_object:
            # Grasped the target: distinguish a partial lift from a drop.
            if row["max_reward"] >= PARTIAL_LIFT_THRESHOLD:
                return "did not lift enough"
            return "object fell/slipped"
        if has_target_contact:
            return "object fell/slipped"
        if has_wrong_object_contact and (is_grasping or has_gripper_contact):
            return "picked wrong object"
        if has_gripper_contact or is_grasping:
            return "Empty Grasp"
        return "did not grasp"

    df["failure_mode"] = df.apply(get_failure_mode, axis=1)
    # Display-label buckets, initialized so every mode appears even at zero.
    failure_modes = {
        "Did not lift enough": 0,
        "Object touched but not grasped": 0,
        "Picked wrong object": 0,
        "Empty Grasp": 0,
        "Did not grasp": 0,
    }
    # Map the internal mode strings to their display labels.
    label_map = {
        "did not lift enough": "Did not lift enough",
        "object fell/slipped": "Object touched but not grasped",
        "picked wrong object": "Picked wrong object",
        "Empty Grasp": "Empty Grasp",
        "did not grasp": "Did not grasp",
    }
    failure_condition = df["max_reward"] <= REWARD_THRESHOLD
    for _, row in df[failure_condition].iterrows():
        label = label_map.get(row["failure_mode"])
        if label is not None:
            failure_modes[label] += 1
    total_failures = sum(failure_modes.values())
    return failure_modes, total_failures, total_episodes
84
+
85
+
86
+ def plot_bayesian_correlation():
87
+ # Set random seed for reproducibility
88
+ np.random.seed(42)
89
+
90
+ # Register and enable custom font theme
91
+ alt.themes.register('custom_theme', lambda: {
92
+ 'config': {
93
+ 'title': {'font': 'Produkt'},
94
+ 'axis': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
95
+ 'legend': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
96
+ 'mark': {'font': 'Produkt'},
97
+ 'text': {'font': 'Produkt'},
98
+ }
99
+ })
100
+ alt.themes.enable('custom_theme')
101
+
102
+ stats = {}
103
+ checkpoint_order = ["checkpoint_10", "checkpoint_50", "checkpoint_64", "checkpoint_80"]
104
+ for checkpoint in checkpoint_order:
105
+ if checkpoint not in CHECKPOINT_REAL_SR:
106
+ print(f"Warning: {checkpoint} not in CHECKPOINT_REAL_SR")
107
+ continue
108
+ csv_path = os.path.join(BASE_DIR, checkpoint, "log.csv")
109
+ success, total = compute_success_from_csv(csv_path)
110
+ if success is None:
111
+ print(f"Warning: Could not find CSV data for {checkpoint}")
112
+ continue
113
+ real_sr = CHECKPOINT_REAL_SR[checkpoint]
114
+ a_sim = 1 + success
115
+ b_sim = 1 + (total - success)
116
+ sim_mean = 100 * a_sim / (a_sim + b_sim)
117
+ real_successes = int(real_sr / 100 * REAL_SAMPLES)
118
+ a_real = 1 + real_successes
119
+ b_real = 1 + (REAL_SAMPLES - real_successes)
120
+ real_mean = 100 * a_real / (a_real + b_real)
121
+ print(f"{checkpoint}: sim={success}/{total} ({sim_mean:.1f}%), real={real_sr:.1f}% ({real_successes}/{REAL_SAMPLES})")
122
+ stats[checkpoint] = {
123
+ "sim_mean": sim_mean,
124
+ "sim_alpha": a_sim,
125
+ "sim_beta": b_sim,
126
+ "real_mean": real_mean,
127
+ "real_alpha": a_real,
128
+ "real_beta": b_real,
129
+ }
130
+ if not stats:
131
+ print("No data found!")
132
+ return
133
+ checkpoint_names = list(stats.keys())
134
+ sim_rates = np.array([stats[c]["sim_mean"] for c in checkpoint_names])
135
+ real_rates = np.array([stats[c]["real_mean"] for c in checkpoint_names])
136
+
137
+ r, p_value = pearsonr(sim_rates, real_rates)
138
+
139
+ # Compute confidence intervals and violin plot data
140
+ violin_data = []
141
+ point_data = []
142
+ error_data = []
143
+
144
+ for i, checkpoint in enumerate(checkpoint_names):
145
+ sim_lo = beta.ppf(0.025, stats[checkpoint]["sim_alpha"], stats[checkpoint]["sim_beta"]) * 100
146
+ sim_hi = beta.ppf(0.975, stats[checkpoint]["sim_alpha"], stats[checkpoint]["sim_beta"]) * 100
147
+
148
+ # Generate violin data
149
+ y_samples = beta.rvs(stats[checkpoint]["real_alpha"], stats[checkpoint]["real_beta"], size=10000) * 100
150
+
151
+ xc = sim_rates[i]
152
+ yc = real_rates[i]
153
+
154
+ # Store point data
155
+ point_data.append({
156
+ 'checkpoint': checkpoint,
157
+ 'sim_rate': xc,
158
+ 'real_rate': yc,
159
+ 'sim_lo': sim_lo,
160
+ 'sim_hi': sim_hi
161
+ })
162
+
163
+ # Store violin data for density transform
164
+ for y_val in y_samples:
165
+ violin_data.append({
166
+ 'checkpoint': checkpoint,
167
+ 'sim_rate': xc,
168
+ 'real_rate': y_val
169
+ })
170
+
171
+ # Compute bootstrap correlation confidence interval
172
+ n_samples = 1000
173
+ r_samples = []
174
+ for _ in range(n_samples):
175
+ sim_sample = [beta.rvs(stats[c]["sim_alpha"], stats[c]["sim_beta"]) * 100 for c in checkpoint_names]
176
+ real_sample = [beta.rvs(stats[c]["real_alpha"], stats[c]["real_beta"]) * 100 for c in checkpoint_names]
177
+ r_sample, _ = pearsonr(sim_sample, real_sample)
178
+ r_samples.append(r_sample)
179
+ r_samples = np.array(r_samples)
180
+
181
+ # Convert to r² (coefficient of determination)
182
+ r_squared = r ** 2
183
+ r_squared_samples = r_samples ** 2
184
+ r_squared_lo = np.percentile(r_squared_samples, 2.5)
185
+ r_squared_hi = np.percentile(r_squared_samples, 97.5)
186
+
187
+ # Create DataFrames
188
+ point_df = pd.DataFrame(point_data)
189
+
190
+ # For violins, we need much less data - just sample points
191
+ violin_sample_data = []
192
+ for checkpoint in checkpoint_names:
193
+ # Use only 500 samples per checkpoint to avoid memory issues
194
+ y_samples = beta.rvs(stats[checkpoint]["real_alpha"], stats[checkpoint]["real_beta"], size=500) * 100
195
+ xc = point_df[point_df['checkpoint'] == checkpoint]['sim_rate'].values[0]
196
+ for y_val in y_samples:
197
+ violin_sample_data.append({
198
+ 'checkpoint': checkpoint,
199
+ 'sim_rate': xc,
200
+ 'real_rate': y_val
201
+ })
202
+
203
+ violin_df = pd.DataFrame(violin_sample_data)
204
+
205
+ # Compute regression line
206
+ z = np.polyfit(sim_rates, real_rates, 1)
207
+ p = np.poly1d(z)
208
+ xs = np.linspace(sim_rates.min() - 5, sim_rates.max() + 5, 200)
209
+ regression_df = pd.DataFrame({'sim_rate': xs, 'real_rate': p(xs)})
210
+
211
+ # Manually create violin shapes positioned at sim_rate coordinates
212
+ violin_width_scale = 2.5
213
+ violin_polygon_data = []
214
+
215
+ # Color mapping from lowest to highest checkpoint
216
+ checkpoint_colors = {
217
+ "checkpoint_10": "#F8F0FA", # very light purple
218
+ "checkpoint_50": "#D6BAE2", # medium purple
219
+ "checkpoint_64": "#9B66BB", # medium-dark purple
220
+ "checkpoint_80": "#4B136D" # deep purple
221
+ }
222
+
223
+ for checkpoint in checkpoint_names:
224
+ # Get samples and compute KDE
225
+ y_samples = beta.rvs(stats[checkpoint]["real_alpha"], stats[checkpoint]["real_beta"], size=2000) * 100
226
+ kde = gaussian_kde(y_samples)
227
+
228
+ # Create density curve points
229
+ y_points = np.linspace(y_samples.min(), y_samples.max(), 100)
230
+ densities = kde(y_points)
231
+ densities = densities / densities.max() * violin_width_scale
232
+
233
+ xc = point_df[point_df['checkpoint'] == checkpoint]['sim_rate'].values[0]
234
+ color = checkpoint_colors[checkpoint]
235
+
236
+ # Create closed polygon: left side up, then right side down
237
+ for i, (y, d) in enumerate(zip(y_points, densities)):
238
+ violin_polygon_data.append({
239
+ 'checkpoint': checkpoint,
240
+ 'x': xc - d,
241
+ 'y': y,
242
+ 'order': i,
243
+ 'color': color
244
+ })
245
+ # Right side going back down
246
+ for i, (y, d) in enumerate(zip(y_points[::-1], densities[::-1])):
247
+ violin_polygon_data.append({
248
+ 'checkpoint': checkpoint,
249
+ 'x': xc + d,
250
+ 'y': y,
251
+ 'order': len(y_points) + i,
252
+ 'color': color
253
+ })
254
+
255
+ violin_polygon_df = pd.DataFrame(violin_polygon_data)
256
+
257
+ # Create violin shapes with color encoding
258
+ violins = alt.Chart(violin_polygon_df).mark_line(
259
+ fillOpacity=0.6,
260
+ stroke='#8B4789',
261
+ strokeWidth=1.1,
262
+ interpolate='linear',
263
+ filled=True
264
+ ).encode(
265
+ x=alt.X('x:Q', title='EgoGym Performance (%)').scale(zero=False),
266
+ y=alt.Y('y:Q', title='Real Performance (%)').scale(zero=False),
267
+ order='order:Q',
268
+ detail='checkpoint:N',
269
+ fill=alt.Fill('color:N', scale=None, legend=None)
270
+ )
271
+
272
+
273
+ # Create regression line
274
+ regression_line = alt.Chart(regression_df).mark_line(
275
+ strokeDash=[5, 5],
276
+ color='black',
277
+ opacity=0.5,
278
+ size=1.5
279
+ ).encode(
280
+ x=alt.X('sim_rate:Q').scale(zero=False),
281
+ y=alt.Y('real_rate:Q').scale(zero=False)
282
+ )
283
+
284
+ # Create error bars for simulation
285
+ error_bars = alt.Chart(point_df).mark_errorbar(ticks=True, thickness=1.5).encode(
286
+ x=alt.X('sim_lo:Q', title=''),
287
+ x2=alt.X2('sim_hi:Q'),
288
+ y='real_rate:Q'
289
+ )
290
+
291
+ # Create horizontal lines at mean (for each violin)
292
+ mean_line_data = []
293
+ for checkpoint in checkpoint_names:
294
+ xc = point_df[point_df['checkpoint'] == checkpoint]['sim_rate'].values[0]
295
+ yc = point_df[point_df['checkpoint'] == checkpoint]['real_rate'].values[0]
296
+ mean_line_data.append({
297
+ 'checkpoint': checkpoint,
298
+ 'x_left': xc - 1.0,
299
+ 'x_right': xc + 1.0,
300
+ 'y': yc
301
+ })
302
+ mean_line_df = pd.DataFrame(mean_line_data)
303
+
304
+ mean_lines = alt.Chart(mean_line_df).mark_rule(color='black', size=1.2).encode(
305
+ x=alt.X('x_left:Q'),
306
+ x2=alt.X2('x_right:Q'),
307
+ y='y:Q'
308
+ )
309
+
310
+ # Create central points
311
+ points = alt.Chart(point_df).mark_point(
312
+ filled=True,
313
+ size=100,
314
+ color='#A894DB',
315
+ stroke='#8B4789',
316
+ strokeWidth=1.5,
317
+ shape='diamond'
318
+ ).encode(
319
+ x=alt.X('sim_rate:Q').scale(zero=False),
320
+ y=alt.Y('real_rate:Q').scale(zero=False),
321
+ tooltip=['checkpoint', 'sim_rate', 'real_rate']
322
+ )
323
+
324
+ # Combine layers (violins, regression line, error bars, mean lines and points on top)
325
+ chart = (violins + regression_line + mean_lines + error_bars + points).properties(
326
+ title={
327
+ 'text': ' Blind EgoGym-Pick Sim-to-Real Correlation',
328
+ 'fontSize': 28,
329
+ 'anchor': 'start',
330
+ 'dx': 25,
331
+ 'dy': -10
332
+ },
333
+ width=800,
334
+ height=600
335
+ ).configure_axis(
336
+ labelFontSize=20,
337
+ titleFontSize=24,
338
+ titleFontStyle='normal',
339
+ grid=True,
340
+ gridOpacity=0.3,
341
+ tickCount=8
342
+ ).configure_view(
343
+ strokeWidth=0
344
+ )
345
+
346
+ # Add background box for correlation text (wider and thinner)
347
+ correlation_box = alt.Chart(pd.DataFrame([{
348
+ 'x': point_df['sim_rate'].max() - 18,
349
+ 'y': point_df['real_rate'].min() - 10,
350
+ 'x2': point_df['sim_rate'].max() + 4,
351
+ 'y2': point_df['real_rate'].min() - 4
352
+ }])).mark_rect(
353
+ fill='#F7D45B',
354
+ stroke='gray',
355
+ strokeWidth=1.5,
356
+ opacity=0.9
357
+ ).encode(
358
+ x=alt.X('x:Q'),
359
+ y=alt.Y('y:Q'),
360
+ x2='x2:Q',
361
+ y2='y2:Q'
362
+ )
363
+
364
+ # Add correlation text annotation in box
365
+ correlation_text = alt.Chart(pd.DataFrame([{
366
+ 'x': point_df['sim_rate'].max() - 7,
367
+ 'y': point_df['real_rate'].min() - 7,
368
+ 'text': f'95% CI r²: {r_squared_lo:.3f}, {r_squared_hi:.3f}'
369
+ }])).mark_text(
370
+ align='center',
371
+ baseline='middle',
372
+ fontSize=20,
373
+ font='Produkt'
374
+ ).encode(
375
+ x=alt.X('x:Q'),
376
+ y=alt.Y('y:Q'),
377
+ text='text:N'
378
+ )
379
+
380
+ # Combine with correlation box and text
381
+ chart = chart + correlation_box + correlation_text
382
+
383
+ # Save chart
384
+ chart.save("checkpoint_correlation.html")
385
+ chart.save("checkpoint_correlation.png", scale_factor=2.0)
386
+ chart.save("checkpoint_correlation.pdf", scale_factor=2.0)
387
+ print(f"\nCorrelation: r = {r:.3f}, r² = {r_squared:.3f}, p = {p_value:.4f}")
388
+ print("\nPlot saved to: checkpoint_correlation.html and checkpoint_correlation.png")
389
+
390
+
391
if __name__ == "__main__":
    # Script entry point: build and save the sim-to-real correlation chart.
    plot_bayesian_correlation()