egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,834 @@
1
import html
import os
import re
import sys
import webbrowser
from datetime import datetime

import numpy as np
import pandas as pd
import yaml
9
+
10
+
11
def determine_grasp_type(gripper_pose_str):
    """Classify a logged gripper pose as a top-down or side grasp.

    The pose is interpreted as a 4x4 homogeneous transform. The third column of
    its rotation block (the gripper's local z-axis) is normalised and compared
    with ``[0, 0, 1]``: a cosine above 0.7 (within roughly 45 degrees) is
    reported as ``"Top-down"``, anything else as ``"Side"``. (This presumes the
    pose frame's +z is the relevant vertical; confirm against the simulator.)

    Args:
        gripper_pose_str: The pose as a NumPy array, a nested list/tuple, or
            the string form of such a list as read back from a CSV log. Any
            value holding exactly 16 numbers is accepted.

    Returns:
        ``"Top-down"`` or ``"Side"``; ``"Unknown"`` when the value cannot be
        turned into a 4x4 matrix or its z-axis has zero / non-finite length.
    """
    try:
        if isinstance(gripper_pose_str, str):
            # SECURITY(review): eval() runs arbitrary expressions - only use this
            # on logs you produced yourself. ast.literal_eval would be safer but
            # rejects any non-literal text (e.g. "np.float64(...)" calls), so
            # confirm the logged format before switching.
            raw_pose = eval(gripper_pose_str)
        elif isinstance(gripper_pose_str, (np.ndarray, list, tuple)):
            raw_pose = gripper_pose_str
        else:
            return "Unknown"
        gripper_pose = np.asarray(raw_pose, dtype=float).reshape(4, 4)
    except Exception:
        # Deliberate best-effort: malformed or wrongly sized pose data is
        # reported as "Unknown" instead of aborting the whole analysis.
        return "Unknown"

    z_axis = gripper_pose[:3, 2]
    norm = np.linalg.norm(z_axis)
    # A degenerate axis would otherwise become NaN and silently count as "Side".
    if not np.isfinite(norm) or norm == 0:
        return "Unknown"

    cos_angle = float(np.dot(z_axis / norm, np.array([0.0, 0.0, 1.0])))
    return "Top-down" if cos_angle > 0.7 else "Side"
32
+
33
+
34
# Stylesheet shared by both report variants (plain string: braces are literal CSS).
_REPORT_CSS = """
        body {
            font-family: 'Verdana', 'Arial', sans-serif;
            background-color: #ffffff;
            color: #000000;
            max-width: 800px;
            margin: 0 auto;
            padding: 10px;
            font-size: 12px;
            line-height: 1.3;
        }
        .download-btn {
            display: inline-block;
            padding: 1px 3px;
            background-color: transparent;
            color: #555;
            font-size: 10px;
            text-decoration: underline;
            cursor: pointer;
            font-weight: normal;
            text-align: left;
            margin-top: 2px;
            margin-bottom: 10px;
        }
        .download-btn:hover {
            color: #000;
        }
        h1 {
            color: #000080;
            font-size: 18px;
            border-bottom: 1px solid #cccccc;
            padding-bottom: 5px;
            margin-top: 15px;
        }
        h2 {
            color: #000080;
            font-size: 16px;
            margin-top: 15px;
            border-bottom: 1px solid #eeeeee;
            padding-bottom: 3px;
        }
        h3 {
            color: #000080;
            font-size: 14px;
            margin-top: 10px;
        }
        .timestamp {
            color: #666666;
            font-size: 11px;
            margin-bottom: 15px;
        }
        .success-rate {
            font-size: 32px;
            font-weight: bold;
            color: #008000;
            text-align: center;
            margin: 15px 0;
        }
        .threshold {
            text-align: center;
            color: #666666;
            margin-bottom: 15px;
            font-size: 11px;
        }
        table {
            border-collapse: collapse;
            width: 100%;
            border: 1px solid #cccccc;
            margin: 15px 0;
        }
        th, td {
            text-align: left;
            padding: 5px;
            border: 1px solid #cccccc;
            font-size: 12px;
        }
        th {
            background-color: #eeeeee;
        }
        ul {
            margin: 10px 0;
            padding-left: 20px;
            list-style-type: square;
        }
        li {
            margin-bottom: 5px;
        }
        .info-box {
            border: 1px solid #cccccc;
            background-color: #f5f5f5;
            padding: 10px;
            margin: 15px 0;
        }
        .section {
            margin-bottom: 20px;
        }
        .histogram {
            display: flex;
            align-items: flex-end;
            height: 120px;
            gap: 8px;
            padding: 10px 20px;
        }
        .histogram-bar-label {
            text-align: center;
            font-size: 10px;
            margin-top: 5px;
            color: #333;
        }
        .histogram-bar-count {
            text-align: center;
            font-size: 10px;
            margin-bottom: 3px;
            color: #333;
            font-weight: bold;
        }
        .histogram-container {
            margin: 20px 0;
        }
        .histogram-title {
            text-align: center;
            margin-top: 15px;
            font-size: 11px;
            color: #666;
        }
        .histogram-bar-wrapper {
            display: flex;
            flex-direction: column;
            align-items: center;
            flex: 1;
        }
        .histogram-bar-inner {
            width: 100%;
            background-color: #4a90d9;
        }
        .histogram-bar-inner.success {
            background-color: #2e7d32;
        }
"""

# Client-side helper used by every "Save CSV" link in the report.
_DOWNLOAD_JS = """
        function downloadCSV(filename, csvContent) {
            // Wrap the CSV text in a Blob and trigger a download via a temporary link.
            const blob = new Blob([csvContent], { type: 'text/csv;charset=utf-8;' });
            const url = URL.createObjectURL(blob);
            const link = document.createElement('a');
            link.href = url;
            link.setAttribute('download', filename);
            document.body.appendChild(link);
            link.click();
            document.body.removeChild(link);
            URL.revokeObjectURL(url);
        }
"""

# Internal failure-mode labels (stored in the exported CSV) -> display labels.
_FAILURE_MODE_LABELS = {
    "did not lift enough": "Did not lift enough",
    "object fell/slipped": "Object grasped but slipped",
    "picked wrong object": "Picked wrong object",
    "Empty Grasp": "Empty Grasp",
    "did not grasp": "Did not grasp",
}


def _find_log_file(log_folder):
    """Return the path of the run's log file (log.csv, then logs.txt), or None."""
    for candidate in ("log.csv", "logs.txt"):
        path = os.path.join(log_folder, candidate)
        if os.path.exists(path):
            return path
    return None


def _parse_run_date(run_folder_name):
    """Return YYYY-MM-DD parsed from a run_*/evaluation_* folder name, else ""."""
    if not run_folder_name.startswith(("run_", "evaluation_")):
        return ""
    # A standalone 10-digit run is a Unix timestamp. Check it first: an 8-digit
    # pattern would otherwise match its first eight digits as a bogus date.
    match = re.search(r"(?<!\d)\d{10}(?!\d)", run_folder_name)
    if match:
        try:
            return datetime.fromtimestamp(int(match.group(0))).strftime("%Y-%m-%d")
        except (OverflowError, OSError, ValueError):
            return match.group(0)
    match = re.search(r"\d{8}", run_folder_name)
    if match:
        raw = match.group(0)
        return f"{raw[0:4]}-{raw[4:6]}-{raw[6:8]}"
    return ""


def _csv_for_onclick(csv_str):
    """Encode CSV text as a JS single-quoted string body safe inside onclick="..."."""
    js_literal = (
        csv_str.replace("\\", "\\\\")
        .replace("'", "\\'")
        .replace("\n", "\\n")
        .replace("\r", "\\r")
    )
    # The literal lives inside a double-quoted HTML attribute. pandas quotes CSV
    # fields containing commas with '"', which would otherwise end the attribute
    # early; the browser decodes these entities before running the handler.
    return html.escape(js_literal, quote=True)


def _download_link(filename, csv_str, label="Save CSV", style=""):
    """Render a clickable span that downloads `csv_str` as `filename`."""
    onclick = f"downloadCSV('{filename}', '{_csv_for_onclick(csv_str)}')"
    style_attr = f' style="{style}"' if style else ""
    return f'<span class="download-btn"{style_attr} onclick="{onclick}">{label}</span>'


def _success_table(df, success, key):
    """Per-value success rate (%) and episode count for column `key`, best first."""
    grouped = success.groupby(df[key])
    table = pd.DataFrame({"success_rate": grouped.mean() * 100, "count": grouped.size()})
    table.index.name = key
    return table.reset_index().sort_values("success_rate", ascending=False, kind="mergesort")


def _rate_rows(table, key):
    """Turn a `_success_table` result into escaped [value, rate, count] cell rows."""
    return [
        [html.escape(str(value)), f"{rate:.2f}", str(count)]
        for value, rate, count in table[[key, "success_rate", "count"]].itertuples(index=False)
    ]


def _section_html(heading, headers, rows, download_html=""):
    """Render one report section: heading, table, and an optional download link."""
    header_cells = "".join(f"<th>{label}</th>" for label in headers)
    body_rows = "".join(
        "<tr>" + "".join(f"<td>{cell}</td>" for cell in row) + "</tr>" for row in rows
    )
    return f"""
    <div class="section">
        <h2>{heading}</h2>
        <table>
            <tr>{header_cells}</tr>
            {body_rows}
        </table>
        {download_html}
    </div>
"""


def _parse_joint_value(val):
    """Parse a logged joint value that may be a quoted/bracketed string; 0.0 if unparseable."""
    try:
        if isinstance(val, str):
            val = val.strip('"[]')
        return float(val)
    except (TypeError, ValueError, OverflowError):
        return 0.0


def _open_task_sections(df):
    """Build the "open"-task report sections; adds derived columns to `df`."""
    success = df["success"]
    joint_values = df["max_normalized_joint"].apply(_parse_joint_value)
    # Stored on df so the parsed value is part of the exported results CSV.
    df["max_normalized_joint_parsed"] = joint_values

    rate_headers = ["Success Rate (%)", "Count"]
    handle_data = _success_table(df, success, "handle_type")
    sections = [
        _section_html(
            "Success Rate by Handle Type",
            ["Handle Type", *rate_headers],
            _rate_rows(handle_data, "handle_type"),
            _download_link("handle_types.csv", handle_data.to_csv(index=False)),
        )
    ]

    if "is_double_door" in df.columns:
        df["door_type"] = df["is_double_door"].apply(lambda x: "Double Door" if x else "Single Door")
        door_data = _success_table(df, success, "door_type")
        sections.append(
            _section_html("Success Rate by Door Type", ["Door Type", *rate_headers], _rate_rows(door_data, "door_type"))
        )

    if "hinge_side" in df.columns:
        hinge_data = _success_table(df, success, "hinge_side")
        sections.append(
            _section_html("Success Rate by Hinge Side", ["Hinge Side", *rate_headers], _rate_rows(hinge_data, "hinge_side"))
        )

    bins = [0, 0.1, 0.2, 0.3, 0.4, 0.5]
    # Values at or above 0.5 are clipped into the last bin, which is drawn green.
    hist_counts, _ = np.histogram(joint_values.clip(0, 0.5), bins=bins)
    max_count = int(hist_counts.max())
    bars = []
    for low, high, count in zip(bins[:-1], bins[1:], hist_counts.tolist()):
        bar_class = "histogram-bar-inner success" if high >= 0.5 else "histogram-bar-inner"
        height = (count / max_count * 100) if max_count > 0 else 0
        bars.append(
            f'<div class="histogram-bar-wrapper"><div class="histogram-bar-count">{count}</div>'
            f'<div class="{bar_class}" style="height: {height}px; max-height: 100px;"></div>'
            f'<div class="histogram-bar-label">{low:.1f}-{high:.1f}</div></div>'
        )
    bars_html = "".join(bars)
    sections.append(f"""
    <div class="section">
        <h2>Max Normalized Joint Distribution</h2>
        <div class="histogram-container">
            <div class="histogram">
                {bars_html}
            </div>
            <div class="histogram-title">Max Normalized Joint Value (green = success at 0.5)</div>
        </div>
    </div>
""")
    return "".join(sections)


def _classify_failure(row, lift_threshold):
    """Heuristically label why a pick episode failed; None when it succeeded."""
    if row["max_reward"] > lift_threshold:
        return None

    bodies_contacted = str(row.get("grasped_bodies", ""))
    object_name = str(row.get("object_name", ""))
    grasping_object = row.get("grasping_object", False)
    is_grasping = row.get("is_grasping", False)
    # Body names containing "left"/"right" are treated as finger contacts.
    finger_contact = "left" in bodies_contacted or "right" in bodies_contacted

    if object_name in bodies_contacted:
        return "did not lift enough" if grasping_object else "object fell/slipped"
    if "object" in bodies_contacted and is_grasping and not finger_contact:
        return "picked wrong object"
    if finger_contact:
        return "Empty Grasp"
    return "did not grasp"


def _failure_mode_table(failure_modes):
    """Count failure modes (in display labels) with their share of all failures."""
    counts = failure_modes.value_counts()
    table = pd.DataFrame(
        {
            "failure_mode": list(_FAILURE_MODE_LABELS.values()),
            "count": [int(counts.get(mode, 0)) for mode in _FAILURE_MODE_LABELS],
        }
    ).sort_values("count", ascending=False, kind="mergesort")
    total = int(table["count"].sum())
    table["percentage"] = table["count"] / total * 100 if total > 0 else 0
    return table


def _add_grasp_percentages(category_data, successful_df):
    """Attach, per category, the % of successful episodes using each grasp type."""
    merged = category_data
    if not successful_df.empty:
        grasp_pct = (
            pd.crosstab(successful_df["object_category"], successful_df["grasp_type"], normalize="index") * 100
        ).reset_index()
        merged = category_data.merge(grasp_pct, on="object_category", how="left")
    for grasp_type in ("Side", "Top-down", "Unknown"):
        if grasp_type not in merged.columns:
            merged[grasp_type] = 0.0
    # Categories with no successful episode get 0% instead of NaN from the left merge.
    grasp_columns = [c for c in merged.columns if c not in ("object_category", "success_rate", "count")]
    merged[grasp_columns] = merged[grasp_columns].fillna(0)
    return merged


def _grasp_cell(top_down, side):
    """Render the grasp-type cell, greying out whichever share is smaller."""
    top_color = "#000000" if top_down >= side else "#888888"
    side_color = "#000000" if side >= top_down else "#888888"
    return (
        f'<span style="color: {top_color}">Top-down ({top_down:.1f})</span> / '
        f'<span style="color: {side_color}">Side ({side:.1f})</span>'
    )


def _pick_task_sections(df, lift_threshold):
    """Compute pick-task statistics; return (overall success %, sections HTML).

    Adds the derived columns object_category, grasp_type and failure_mode to
    `df` so they appear in the exported results CSV.
    """
    success = df["max_reward"] > lift_threshold
    lift_success = success.mean() * 100

    # Category = object name with every non-letter character removed.
    df["object_category"] = df["object_name"].apply(lambda name: re.sub(r"[^a-zA-Z]", "", str(name)))
    category_data = _success_table(df, success, "object_category")
    print(f"success rate all: {lift_success:.2f}%")

    print("Determining grasp types...")
    df["grasp_type"] = df["gripper_current_position"].apply(determine_grasp_type)
    category_data = _add_grasp_percentages(category_data, df[success])

    texture_data = _success_table(df, success, "texture_name")
    object_data = _success_table(df, success, "object_name")

    df["failure_mode"] = df.apply(_classify_failure, axis=1, args=(lift_threshold,))
    # `<=` rather than `~success`: rows with a missing max_reward are not counted.
    failure_modes_data = _failure_mode_table(df.loc[df["max_reward"] <= lift_threshold, "failure_mode"])

    print("Texture data:")
    print(texture_data.head())
    print(f"Texture data shape: {texture_data.shape}")
    print("Failure modes:")
    print(failure_modes_data)

    failure_rows = []
    if int(failure_modes_data["count"].sum()) > 0:
        failure_rows = [
            [html.escape(mode), str(count), f"{pct:.2f}%"]
            for mode, count, pct in failure_modes_data[["failure_mode", "count", "percentage"]].itertuples(index=False)
        ]
    category_rows = [
        [html.escape(str(category)), f"{rate:.2f}", str(count), _grasp_cell(top_down, side)]
        for category, rate, count, top_down, side in zip(
            category_data["object_category"],
            category_data["success_rate"],
            category_data["count"],
            category_data["Top-down"],
            category_data["Side"],
        )
    ]

    sections = [
        _section_html(
            "Failure Modes",
            ["Failure Mode", "Count", "Percentage of Failures"],
            failure_rows,
            _download_link("failure_modes.csv", failure_modes_data.to_csv(index=False)),
        ),
        _section_html(
            "Success Rate by Table Texture",
            ["Table Texture", "Success Rate (%)", "Count"],
            _rate_rows(texture_data, "texture_name"),
            _download_link("textures.csv", texture_data.to_csv(index=False)),
        ),
        _section_html(
            "Success Rate by Object Category",
            ["Object Category", "Success Rate (%)", "Count", "Grasp Types (%)"],
            category_rows,
            _download_link("object_categories.csv", category_data.to_csv(index=False)),
        ),
        _section_html(
            "Success Rate by Individual Object",
            ["Object Name", "Success Rate (%)", "Count"],
            _rate_rows(object_data, "object_name"),
            _download_link("individual_objects.csv", object_data.to_csv(index=False)),
        ),
    ]
    return lift_success, "".join(sections)


def _render_report(title, headline_rate, threshold_text, total_episodes, sections_html, results_csv):
    """Assemble the complete HTML report page."""
    safe_title = html.escape(title)
    generated = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    download_all = _download_link(
        "all_results.csv",
        results_csv,
        label="Download All Results as CSV",
        style="margin: 0; text-align: center; font-size: 12px; padding: 2px 4px;",
    )
    return f"""
<!DOCTYPE html>
<html>
<head>
    <title>{safe_title}</title>
    <meta charset="utf-8">
    <script>
{_DOWNLOAD_JS}
    </script>
    <style>
{_REPORT_CSS}
    </style>
</head>
<body>
    <h1>{safe_title}</h1>
    <div class="timestamp">Generated on {generated}</div>

    <div class="section">
        <div class="success-rate">{headline_rate:.2f}%</div>
        <div class="threshold">{html.escape(threshold_text)}</div>

        <div class="info-box">
            <h3>Summary Statistics</h3>
            <ul>
                <li>Total episodes: {total_episodes}</li>
            </ul>
        </div>
    </div>

{sections_html}

    <div style="text-align: center; margin-top: 20px; margin-bottom: 10px;">
        {download_all}
    </div>
</body>
</html>
"""


def analyze_logs(log_folder, task_name="pick", *, open_browser=True, lift_threshold=0.03):
    """Summarise an evaluation run's episode log as an HTML report.

    Reads ``log.csv`` (or, failing that, ``logs.txt``) from ``log_folder`` as a
    tab-separated table and writes ``analysis/analysis_report.html`` inside
    that folder. Unless disabled, it then opens the report in a web browser.
    Problems (missing or unreadable log, no episodes, missing columns) are
    printed and the function returns early.

    Args:
        log_folder: Run directory containing the log file. A trailing path
            separator is tolerated when extracting the run date from its name.
        task_name: Task the run evaluated, also used in the report title.
            ``"open"`` (case-insensitive) produces the handle/door report
            driven by the ``success`` and ``max_normalized_joint`` columns;
            anything else produces the pick report driven by ``max_reward``.
        open_browser: Open the generated report with :mod:`webbrowser`.
        lift_threshold: Pick tasks only. An episode succeeds when its
            ``max_reward`` exceeds this value. The report shows it as
            ``value * 100`` cm, so it is presumably metres of lift; confirm
            against the logger.

    Returns:
        None.
    """
    log_file_path = _find_log_file(log_folder)
    if log_file_path is None:
        print(f"Log file not found: expected log.csv or logs.txt in {log_folder}")
        return

    analysis_folder = os.path.join(log_folder, "analysis")
    os.makedirs(analysis_folder, exist_ok=True)

    print(f"Reading logs from {log_file_path}...")
    try:
        df = pd.read_csv(log_file_path, sep="\t", on_bad_lines="skip")
    except (pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        print(f"Error reading log file: {e}")
        return
    if df.empty:
        print("Log file contains no episodes; nothing to analyze.")
        return

    is_open_task = task_name.lower() == "open"
    if is_open_task:
        required_columns = ("success", "max_normalized_joint", "handle_type")
    else:
        required_columns = ("max_reward", "object_name", "texture_name", "gripper_current_position")
    missing_columns = [column for column in required_columns if column not in df.columns]
    if missing_columns:
        print(f"Log file is missing required columns: {', '.join(missing_columns)}")
        return

    date_str = _parse_run_date(os.path.basename(os.path.normpath(log_folder)))
    title = f"RUM-{task_name} Analysis {date_str}"

    if is_open_task:
        headline_rate = df["success"].mean() * 100
        threshold_text = "Opened more than 50%"
        sections_html = _open_task_sections(df)
    else:
        headline_rate, sections_html = _pick_task_sections(df, lift_threshold)
        threshold_text = f"Lifted above {lift_threshold * 100:.0f} cm"

    # Exported only now: the section builders add derived columns to df.
    html_content = _render_report(
        title, headline_rate, threshold_text, len(df), sections_html, df.to_csv(index=False)
    )

    html_path = os.path.join(analysis_folder, "analysis_report.html")
    with open(html_path, "w", encoding="utf-8") as f:
        f.write(html_content)
    print(f"HTML report generated at {html_path}")

    if open_browser:
        print("Opening report in browser...")
        webbrowser.open("file://" + os.path.abspath(html_path))
822
+
823
+
824
if __name__ == "__main__":
    # Command-line entry: expects exactly one argument, the run folder to analyze.
    # Usage and errors go to stderr with exit status 1.
    if len(sys.argv) != 2:
        print(f"Usage: python {os.path.basename(sys.argv[0])} <log_folder>", file=sys.stderr)
        sys.exit(1)

    log_folder = sys.argv[1]
    # isdir rather than exists: a regular file path cannot contain the log file.
    if not os.path.isdir(log_folder):
        print(f"Log folder not found: {log_folder}", file=sys.stderr)
        sys.exit(1)

    analyze_logs(log_folder)