roboreason 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- roboreason-0.1.0/PKG-INFO +305 -0
- roboreason-0.1.0/README.md +263 -0
- roboreason-0.1.0/pyproject.toml +81 -0
- roboreason-0.1.0/roboreason/__init__.py +8 -0
- roboreason-0.1.0/roboreason/api_models.py +249 -0
- roboreason-0.1.0/roboreason/core.py +617 -0
- roboreason-0.1.0/roboreason/robometer/__init__.py +2 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/__init__.py +1 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/data_scripts/__init__.py +1 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/data_scripts/agibot/__init__.py +21 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/data_scripts/agibot/agibot_helper.py +170 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/data_scripts/agibot/download_task_jsons.py +42 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/data_scripts/libero/rerender_libero.py +322 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_helpers/generate_soar_labels_vlm.py +619 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_helpers/oxe_helper.py +815 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/agibotworld_loader.py +619 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/autoeval_loader.py +85 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/egocot_loader.py +188 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/egodex_loader.py +146 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/epic_loader.py +268 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/failsafe_loader.py +202 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/fino_net_loader.py +422 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/galaxea_loader.py +302 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/h2r_loader.py +380 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/hand_paired_loader.py +168 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/humanoid_everyday_loader.py +346 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/libero_loader.py +158 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/mit_franka_prank_loader.py +216 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/molmoact_loader.py +244 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/motif_loader.py +131 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/mw_collected_loader.py +161 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/mw_task_annotations.py +205 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/new_mit_franka_loader.py +262 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/oxe_loader.py +457 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/ph2d_loader.py +104 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/racer_loader.py +189 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/roboarena_loader.py +123 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/robofac_loader.py +452 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/robofail_loader.py +191 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/roboreward_loader.py +170 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/soar_loader.py +235 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/usc_franka_policy_ranking_loader.py +148 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/usc_koch_human_robot_paired_loader.py +466 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/usc_koch_p_ranking_loader.py +274 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/usc_xarm_policy_ranking_loader.py +164 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/utd_so101_clean_policy_ranking_loader.py +276 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/dataset_loaders/utd_so101_loader.py +158 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/generate_hf_dataset.py +1083 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/helpers.py +404 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/validate_dataset.py +196 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/video_helpers.py +238 -0
- roboreason-0.1.0/roboreason/robometer/dataset_upload/visualize_dataset.py +273 -0
- roboreason-0.1.0/roboreason/robometer/robometer/configs/eval_configs.py +274 -0
- roboreason-0.1.0/roboreason/robometer/robometer/configs/experiment_configs.py +562 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/collators/__init__.py +6 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/collators/base.py +86 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/collators/rbm_heads.py +694 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/collators/rewind.py +164 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/collators/utils.py +197 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/dataset_category.py +558 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/dataset_types.py +68 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/datasets/__init__.py +13 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/datasets/base.py +774 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/datasets/custom_eval.py +67 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/datasets/helpers.py +703 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/datasets/name_mapping.py +96 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/datasets/name_mapping_final.py +96 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/datasets/rbm_data.py +253 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/datasets/repeated_dataset.py +24 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/datasets/strategy_first_dataset.py +657 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/samplers/__init__.py +19 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/samplers/base.py +782 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/samplers/eval/base_pref.py +73 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/samplers/eval/confusion_matrix.py +300 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/samplers/eval/progress_policy_ranking.py +235 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/samplers/eval/quality_preference.py +219 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/samplers/eval/reward_alignment.py +174 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/samplers/eval/roboarena_quality_preference.py +121 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/samplers/pref.py +370 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/samplers/progress.py +176 -0
- roboreason-0.1.0/roboreason/robometer/robometer/data/scripts/preprocess_datasets.py +1091 -0
- roboreason-0.1.0/roboreason/robometer/robometer/evals/eval_server.py +802 -0
- roboreason-0.1.0/roboreason/robometer/robometer/evals/eval_utils.py +455 -0
- roboreason-0.1.0/roboreason/robometer/robometer/evals/eval_viz_utils.py +205 -0
- roboreason-0.1.0/roboreason/robometer/robometer/models/__init__.py +8 -0
- roboreason-0.1.0/roboreason/robometer/robometer/models/heads.py +101 -0
- roboreason-0.1.0/roboreason/robometer/robometer/models/rbm.py +993 -0
- roboreason-0.1.0/roboreason/robometer/robometer/models/rewind_transformer.py +396 -0
- roboreason-0.1.0/roboreason/robometer/robometer/models/utils.py +77 -0
- roboreason-0.1.0/roboreason/robometer/robometer/utils/__init__.py +0 -0
- roboreason-0.1.0/roboreason/robometer/robometer/utils/config_utils.py +41 -0
- roboreason-0.1.0/roboreason/robometer/robometer/utils/distributed.py +273 -0
- roboreason-0.1.0/roboreason/robometer/robometer/utils/embedding_utils.py +100 -0
- roboreason-0.1.0/roboreason/robometer/robometer/utils/logger.py +441 -0
- roboreason-0.1.0/roboreason/robometer/robometer/utils/metrics.py +210 -0
- roboreason-0.1.0/roboreason/robometer/robometer/utils/save.py +973 -0
- roboreason-0.1.0/roboreason/robometer/robometer/utils/setup_utils.py +1203 -0
- roboreason-0.1.0/roboreason/robometer/robometer/utils/tensor_utils.py +27 -0
- roboreason-0.1.0/roboreason/robometer/robometer/utils/timer.py +40 -0
- roboreason-0.1.0/roboreason/robometer/robometer/utils/upload_to_hub.py +241 -0
- roboreason-0.1.0/roboreason/robometer/robometer/utils/video_utils.py +603 -0
- roboreason-0.1.0/roboreason/robometer/roboreason_robometer.py +224 -0
- roboreason-0.1.0/roboreason/robometer/scripts/count_trajectories.py +370 -0
- roboreason-0.1.0/roboreason/robometer/scripts/example_inference.py +482 -0
- roboreason-0.1.0/roboreason/robometer/scripts/example_inference_local.py +201 -0
- roboreason-0.1.0/roboreason/robometer/scripts/example_libero_robometer_wrapper.py +736 -0
- roboreason-0.1.0/roboreason/robometer/scripts/robotics_demo_video_scraper.py +1039 -0
- roboreason-0.1.0/roboreason/robometer/train.py +358 -0
- roboreason-0.1.0/roboreason/roboreward.py +283 -0
- roboreason-0.1.0/roboreason/sole.py +746 -0
- roboreason-0.1.0/roboreason/topreward.py +479 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothBCOTrainer.py +2200 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothCPOTrainer.py +1980 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothDPOTrainer.py +2918 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothGKDTrainer.py +1331 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothGRPOTrainer.py +4230 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothKTOTrainer.py +2397 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothNashMDTrainer.py +1384 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothORPOTrainer.py +1904 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +2487 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothPPOTrainer.py +1678 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothPRMTrainer.py +1153 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothRLOOTrainer.py +2848 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothRewardTrainer.py +1371 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothSFTTrainer.py +1632 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/UnslothXPOTrainer.py +1429 -0
- roboreason-0.1.0/roboreason/unsloth_compiled_cache/moe_utils.py +1320 -0
- roboreason-0.1.0/roboreason/utils/__init__.py +2 -0
- roboreason-0.1.0/roboreason/utils/model_utils.py +48 -0
- roboreason-0.1.0/roboreason.egg-info/PKG-INFO +305 -0
- roboreason-0.1.0/roboreason.egg-info/SOURCES.txt +133 -0
- roboreason-0.1.0/roboreason.egg-info/dependency_links.txt +1 -0
- roboreason-0.1.0/roboreason.egg-info/requires.txt +37 -0
- roboreason-0.1.0/roboreason.egg-info/top_level.txt +1 -0
- roboreason-0.1.0/setup.cfg +4 -0
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: roboreason
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Roboreason package
|
|
5
|
+
Author: Philip
|
|
6
|
+
Requires-Python: >=3.8
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Requires-Dist: transformers==4.57.3
|
|
9
|
+
Requires-Dist: vllm==0.12.0
|
|
10
|
+
Requires-Dist: wandb>=0.19.1
|
|
11
|
+
Requires-Dist: pillow
|
|
12
|
+
Requires-Dist: accelerate>=1.2.1
|
|
13
|
+
Requires-Dist: qwen_vl_utils
|
|
14
|
+
Requires-Dist: datasets>=3.2.0
|
|
15
|
+
Requires-Dist: trl
|
|
16
|
+
Requires-Dist: imageio
|
|
17
|
+
Requires-Dist: matplotlib
|
|
18
|
+
Requires-Dist: av>=16.0.0
|
|
19
|
+
Requires-Dist: openai
|
|
20
|
+
Requires-Dist: google-genai
|
|
21
|
+
Provides-Extra: robometer
|
|
22
|
+
Requires-Dist: omegaconf>=2.3.0; extra == "robometer"
|
|
23
|
+
Requires-Dist: hydra-core; extra == "robometer"
|
|
24
|
+
Requires-Dist: tensorboard; extra == "robometer"
|
|
25
|
+
Requires-Dist: peft; extra == "robometer"
|
|
26
|
+
Requires-Dist: codetiming>=1.4.0; extra == "robometer"
|
|
27
|
+
Requires-Dist: unsloth>=2025.10; extra == "robometer"
|
|
28
|
+
Requires-Dist: sentence-transformers>=2.0.0; extra == "robometer"
|
|
29
|
+
Requires-Dist: decord>=0.6.0; extra == "robometer"
|
|
30
|
+
Provides-Extra: lerobot
|
|
31
|
+
Requires-Dist: lerobot; extra == "lerobot"
|
|
32
|
+
Provides-Extra: all
|
|
33
|
+
Requires-Dist: omegaconf>=2.3.0; extra == "all"
|
|
34
|
+
Requires-Dist: hydra-core; extra == "all"
|
|
35
|
+
Requires-Dist: tensorboard; extra == "all"
|
|
36
|
+
Requires-Dist: peft; extra == "all"
|
|
37
|
+
Requires-Dist: codetiming>=1.4.0; extra == "all"
|
|
38
|
+
Requires-Dist: unsloth>=2025.10; extra == "all"
|
|
39
|
+
Requires-Dist: sentence-transformers>=2.0.0; extra == "all"
|
|
40
|
+
Requires-Dist: decord>=0.6.0; extra == "all"
|
|
41
|
+
Requires-Dist: lerobot; extra == "all"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
<div align="center">
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# RoboReason
|
|
48
|
+
|
|
49
|
+
**RoboReason** is a Python package that makes it easy to apply any ***reward model*** or ***video-language reasoning model*** to your robot videos.
|
|
50
|
+
|
|
51
|
+
</div>
|
|
52
|
+
|
|
53
|
+
## Supported Models
|
|
54
|
+
- Robometer (https://robometer.github.io)
|
|
55
|
+
- TOPReward (https://topreward.github.io/webpage/)
|
|
56
|
+
- RoboReward (https://arxiv.org/abs/2601.00675)
|
|
57
|
+
- SOLE-R1 (https://philipmit.github.io/sole-r1/)
|
|
58
|
+
- OpenAI models (e.g., `"gpt-5"`)
|
|
59
|
+
- Google models (e.g., `"gemini-3-pro-preview"`)
|
|
60
|
+
|
|
61
|
+
## ToDos
|
|
62
|
+
- [ ] Enable fine-tuning of reward models on custom datasets
|
|
63
|
+
|
|
64
|
+
## 📦 File Structure
|
|
65
|
+
|
|
66
|
+
```
|
|
67
|
+
roboreason/
|
|
68
|
+
├── roboreason/ # Main package
|
|
69
|
+
│ ├── robometer/ # Robometer code
|
|
70
|
+
│ ├── sole.py # SOLE-R1 code
|
|
71
|
+
│ ├── roboreward.py # RoboReward code
|
|
72
|
+
│ ├── topreward.py # TOPReward code
|
|
73
|
+
│ └── api_models.py # OpenAI and Gemini APIs
|
|
74
|
+
├── test_videos/ # Example videos to test
|
|
75
|
+
├── model_outputs/ # Videos showing model outputs
|
|
76
|
+
└── pyproject.toml # Dependencies (uv)
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
## Install
|
|
81
|
+
### Option 1: quick pip install
|
|
82
|
+
```bash
|
|
83
|
+
pip install -U roboreason
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
### Option 2: use [uv](https://github.com/astral-sh/uv) for dependency management
|
|
87
|
+
|
|
88
|
+
#### 1. Clone the repository:
|
|
89
|
+
```bash
|
|
90
|
+
git clone https://github.com/philipmit/roboreason && cd roboreason
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
#### 2. Install `uv`
|
|
94
|
+
```bash
|
|
95
|
+
pip install uv
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
#### 3. Sync environment
|
|
99
|
+
```bash
|
|
100
|
+
uv sync
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
#### 4. Activate environment
|
|
104
|
+
```bash
|
|
105
|
+
source .venv/bin/activate
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
---
|
|
109
|
+
|
|
110
|
+
## Download model checkpoints
|
|
111
|
+
```bash
|
|
112
|
+
|
|
113
|
+
# SOLE-R1 (8B)
|
|
114
|
+
python -c "from roboreason.utils.model_utils import get_model_dir; get_model_dir('sole')"
|
|
115
|
+
|
|
116
|
+
# Robometer (4B)
|
|
117
|
+
python -c "from roboreason.utils.model_utils import get_model_dir; get_model_dir('robometer')"
|
|
118
|
+
|
|
119
|
+
# TOPReward (based on Qwen3-VL-8B)
|
|
120
|
+
python -c "from roboreason.utils.model_utils import get_model_dir; get_model_dir('topreward')"
|
|
121
|
+
|
|
122
|
+
# RoboReward (8B)
|
|
123
|
+
python -c "from roboreason.utils.model_utils import get_model_dir; get_model_dir('roboreward')"
|
|
124
|
+
|
|
125
|
+
```
|
|
126
|
+
|
|
127
|
+
---
|
|
128
|
+
## Quick start: Example reward generation and plotting
|
|
129
|
+
```python
|
|
130
|
+
# pip install -U roboreason
|
|
131
|
+
import roboreason as rr
|
|
132
|
+
|
|
133
|
+
video_paths = ['../test_videos/robosuite/robosuite_lift_example_00.mp4']
|
|
134
|
+
task_description="Pick up the cube from the table."
|
|
135
|
+
|
|
136
|
+
# Robometer
|
|
137
|
+
rewards, success_probs = rr.generate(model="robometer", task_description=task_description, video_paths=video_paths, view_type_per_video=['external'])
|
|
138
|
+
output_robometer = {"model": "robometer", "rewards": rewards[0]}
|
|
139
|
+
|
|
140
|
+
# SOLE-R1
|
|
141
|
+
rewards, reasoning_traces = rr.generate(model="sole-r1", task_description=task_description, video_paths=video_paths, view_type_per_video=['external and wrist'])
|
|
142
|
+
output_sole = {"model": "sole-r1", "rewards": rewards[0], "reasoning_traces": reasoning_traces[0]}
|
|
143
|
+
|
|
144
|
+
rr.video_plot(outputs=[output_sole, output_robometer], plot_save_path='../model_outputs/combined/robosuite_lift_example_00.mp4', video_path = video_paths[0])
|
|
145
|
+
|
|
146
|
+
```
|
|
147
|
+
|
|
148
|
+
---
|
|
149
|
+
## Examples for generating across all models
|
|
150
|
+
|
|
151
|
+
### Robometer
|
|
152
|
+
```python
|
|
153
|
+
|
|
154
|
+
import roboreason as rr
|
|
155
|
+
|
|
156
|
+
rewards, success_probs = rr.generate(
|
|
157
|
+
model="robometer",
|
|
158
|
+
task_description="Pick up the cube from the table.",
|
|
159
|
+
video_paths=['../test_videos/robosuite/robosuite_lift_example_00.mp4'],
|
|
160
|
+
view_type_per_video=['external']
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
```
|
|
164
|
+
|
|
165
|
+
### SOLE-R1
|
|
166
|
+
```python
|
|
167
|
+
|
|
168
|
+
import roboreason as rr
|
|
169
|
+
|
|
170
|
+
rewards, reasoning_traces = rr.generate(
|
|
171
|
+
model="sole-r1",
|
|
172
|
+
task_description="Pick up the cube from the table.",
|
|
173
|
+
video_paths=['../test_videos/robosuite/robosuite_lift_example_00.mp4'],
|
|
174
|
+
view_type_per_video=['external and wrist']
|
|
175
|
+
)
|
|
176
|
+
```
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
### TOPReward
|
|
180
|
+
```python
|
|
181
|
+
|
|
182
|
+
import roboreason as rr
|
|
183
|
+
|
|
184
|
+
rewards = rr.generate(
|
|
185
|
+
model="topreward",
|
|
186
|
+
task_description="Pick up the cube from the table.",
|
|
187
|
+
video_paths=['../test_videos/robosuite/robosuite_lift_example_00.mp4'],
|
|
188
|
+
view_type_per_video=['external']
|
|
189
|
+
)
|
|
190
|
+
|
|
191
|
+
```
|
|
192
|
+
|
|
193
|
+
### RoboReward
|
|
194
|
+
```python
|
|
195
|
+
|
|
196
|
+
import roboreason as rr
|
|
197
|
+
|
|
198
|
+
rewards = rr.generate(
|
|
199
|
+
model="roboreward",
|
|
200
|
+
task_description="Pick up the cube from the table.",
|
|
201
|
+
video_paths=['../test_videos/robosuite/robosuite_lift_example_00.mp4'],
|
|
202
|
+
view_type_per_video=['external']
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
```
|
|
206
|
+
|
|
207
|
+
### GPT-5 (and other OpenAI models)
|
|
208
|
+
```python
|
|
209
|
+
|
|
210
|
+
import roboreason as rr
|
|
211
|
+
|
|
212
|
+
# requires OpenAI API key: https://developers.openai.com/api/docs/quickstart
|
|
213
|
+
API_KEY = "..."
|
|
214
|
+
|
|
215
|
+
rewards, reasoning_traces = rr.generate(
|
|
216
|
+
model="gpt-5",
|
|
217
|
+
task_description="Pick up the cube from the table.",
|
|
218
|
+
video_paths=['../test_videos/robosuite/robosuite_lift_example_00.mp4'],
|
|
219
|
+
view_type_per_video=['external'],
|
|
220
|
+
key=API_KEY
|
|
221
|
+
)
|
|
222
|
+
```
|
|
223
|
+
|
|
224
|
+
### Gemini-3-Pro (and other Google models)
|
|
225
|
+
```python
|
|
226
|
+
|
|
227
|
+
import roboreason as rr
|
|
228
|
+
|
|
229
|
+
# requires Gemini API key: https://ai.google.dev/gemini-api/docs/api-key
|
|
230
|
+
API_KEY = "..."
|
|
231
|
+
|
|
232
|
+
rewards, reasoning_traces = rr.generate(
|
|
233
|
+
model="gemini-3-pro-preview",
|
|
234
|
+
task_description="Pick up the cube from the table.",
|
|
235
|
+
video_paths=['../test_videos/robosuite/robosuite_lift_example_00.mp4'],
|
|
236
|
+
view_type_per_video=['external'],
|
|
237
|
+
key=API_KEY
|
|
238
|
+
)
|
|
239
|
+
```
|
|
240
|
+
|
|
241
|
+
## Video plotting
|
|
242
|
+
```python
|
|
243
|
+
|
|
244
|
+
import roboreason as rr
|
|
245
|
+
|
|
246
|
+
# Robometer (reuses `task_description` and `video_paths` from the Quick start example)
|
|
247
|
+
rewards, success_probs = rr.generate(model="robometer", task_description=task_description, video_paths=video_paths, view_type_per_video=['external'])
|
|
248
|
+
output_robometer = {"model": "robometer", "rewards": rewards[0]}
|
|
249
|
+
|
|
250
|
+
# SOLE-R1
|
|
251
|
+
rewards, reasoning_traces = rr.generate(model="sole-r1", task_description=task_description, video_paths=video_paths, view_type_per_video=['external and wrist'])
|
|
252
|
+
output_sole = {"model": "sole-r1", "rewards": rewards[0], "reasoning_traces": reasoning_traces[0]}
|
|
253
|
+
|
|
254
|
+
rr.video_plot(
|
|
255
|
+
outputs=[output_sole, output_robometer],
|
|
256
|
+
plot_save_path='../model_outputs/combined/robosuite_lift_example_00.mp4',
|
|
257
|
+
video_path = '../test_videos/robosuite/robosuite_lift_example_00.mp4'
|
|
258
|
+
)
|
|
259
|
+
```
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
---
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
## rr.generate
|
|
266
|
+
|
|
267
|
+
| Argument | Type | Required | Description |
|
|
268
|
+
| --------------------- | ----------- | -------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
|
269
|
+
| `model` | `str` | ✅ | Name of the model to use. Options include: `"robometer"`, `"sole-r1"`, `"topreward"`, `"roboreward"`, OpenAI models (e.g., `"gpt-5"`), Google models (e.g., `"gemini-3-pro-preview"`) |
|
|
270
|
+
| `task_description` | `str` | ✅ | Natural language description of the task the robot is performing. |
|
|
271
|
+
| `video_paths` | `List[str]` | ✅ | List of paths to input video files. |
|
|
272
|
+
| `view_type_per_video` | `List[str]` | ✅ | List specifying the camera view(s) used for reward reasoning for each video (e.g., `"external"`, `"wrist"`, or `"external and wrist"`). |
|
|
273
|
+
| `key` | `str` | ❌ | API key required for external models (e.g., OpenAI or Gemini). Not needed for local models. |
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
| Model Type | Return Values |
|
|
277
|
+
| ---------------------- | --------------------------- |
|
|
278
|
+
| SOLE-R1 / GPT / Gemini | `rewards, reasoning_traces` |
|
|
279
|
+
| Robometer | `rewards, success_probs` |
|
|
280
|
+
| TOPReward / RoboReward | `rewards` |
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
## rr.video_plot
|
|
284
|
+
|
|
285
|
+
| Argument | Type | Required | Description |
|
|
286
|
+
| ----------------------- | ------------ | -------- | ----------------------------------------------------------------------------------------- |
|
|
287
|
+
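| `outputs` | `List[dict]` | ❌* | List of output dicts to visualize together, each built from an `rr.generate` result (e.g., `{"model": "sole-r1", "rewards": rewards[0], "reasoning_traces": reasoning_traces[0]}`). |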
|
|
288
|
+
| `plot_save_path` | `str` | ❌ | Path where the output video with overlays will be saved. |
|
|
289
|
+
| `video_path` | `str` | ❌ | Path to the original video file being visualized. |
|
|
290
|
+
| `view_type` | `str` | ❌ | View type used for visualization (e.g., `"external"`, `"wrist"`, `"external and wrist"`). |
|
|
291
|
+
| `show_reasoning_traces` | `bool` | ❌ | Whether to overlay reasoning traces on the video. Default: `False`. |
|
|
292
|
+
| `show_all_frames` | `bool` | ❌ | Whether to render all frames instead of sampled frames. Default: `False`. |
|
|
293
|
+
| `model` | `str` | ❌** | Model name (used when calling `video_plot` directly instead of passing `outputs`). |
|
|
294
|
+
| `task_description` | `str` | ❌** | Task description (used in direct-call mode). |
|
|
295
|
+
| `video_paths` | `List[str]` | ❌** | Input videos (used in direct-call mode). |
|
|
296
|
+
| `view_type_per_video` | `List[str]` | ❌** | View types per video (used in direct-call mode). |
|
|
297
|
+
| `key` | `str` | ❌** | API key (if required for model). |
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
|
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
|
|
2
|
+
<div align="center">
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# RoboReason
|
|
6
|
+
|
|
7
|
+
**RoboReason** is a Python package that makes it easy to apply any ***reward model*** or ***video-language reasoning model*** to your robot videos.
|
|
8
|
+
|
|
9
|
+
</div>
|
|
10
|
+
|
|
11
|
+
## Supported Models
|
|
12
|
+
- Robometer (https://robometer.github.io)
|
|
13
|
+
- TOPReward (https://topreward.github.io/webpage/)
|
|
14
|
+
- RoboReward (https://arxiv.org/abs/2601.00675)
|
|
15
|
+
- SOLE-R1 (https://philipmit.github.io/sole-r1/)
|
|
16
|
+
- OpenAI models (e.g., `"gpt-5"`)
|
|
17
|
+
- Google models (e.g., `"gemini-3-pro-preview"`)
|
|
18
|
+
|
|
19
|
+
## ToDos
|
|
20
|
+
- [ ] Enable fine-tuning of reward models on custom datasets
|
|
21
|
+
|
|
22
|
+
## 📦 File Structure
|
|
23
|
+
|
|
24
|
+
```
|
|
25
|
+
roboreason/
|
|
26
|
+
├── roboreason/ # Main package
|
|
27
|
+
│ ├── robometer/ # Robometer code
|
|
28
|
+
│ ├── sole.py # SOLE-R1 code
|
|
29
|
+
│ ├── roboreward.py # RoboReward code
|
|
30
|
+
│ ├── topreward.py # TOPReward code
|
|
31
|
+
│ └── api_models.py # OpenAI and Gemini APIs
|
|
32
|
+
├── test_videos/ # Example videos to test
|
|
33
|
+
├── model_outputs/ # Videos showing model outputs
|
|
34
|
+
└── pyproject.toml # Dependencies (uv)
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
## Install
|
|
39
|
+
### Option 1: quick pip install
|
|
40
|
+
```bash
|
|
41
|
+
pip install -U roboreason
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
### Option 2: use [uv](https://github.com/astral-sh/uv) for dependency management
|
|
45
|
+
|
|
46
|
+
#### 1. Clone the repository:
|
|
47
|
+
```bash
|
|
48
|
+
git clone https://github.com/philipmit/roboreason && cd roboreason
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
#### 2. Install `uv`
|
|
52
|
+
```bash
|
|
53
|
+
pip install uv
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
#### 3. Sync environment
|
|
57
|
+
```bash
|
|
58
|
+
uv sync
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
#### 4. Activate environment
|
|
62
|
+
```bash
|
|
63
|
+
source .venv/bin/activate
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
---
|
|
67
|
+
|
|
68
|
+
## Download model checkpoints
|
|
69
|
+
```bash
|
|
70
|
+
|
|
71
|
+
# SOLE-R1 (8B)
|
|
72
|
+
python -c "from roboreason.utils.model_utils import get_model_dir; get_model_dir('sole')"
|
|
73
|
+
|
|
74
|
+
# Robometer (4B)
|
|
75
|
+
python -c "from roboreason.utils.model_utils import get_model_dir; get_model_dir('robometer')"
|
|
76
|
+
|
|
77
|
+
# TOPReward (based on Qwen3-VL-8B)
|
|
78
|
+
python -c "from roboreason.utils.model_utils import get_model_dir; get_model_dir('topreward')"
|
|
79
|
+
|
|
80
|
+
# RoboReward (8B)
|
|
81
|
+
python -c "from roboreason.utils.model_utils import get_model_dir; get_model_dir('roboreward')"
|
|
82
|
+
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
---
|
|
86
|
+
## Quick start: Example reward generation and plotting
|
|
87
|
+
```python
|
|
88
|
+
# pip install -U roboreason
|
|
89
|
+
import roboreason as rr
|
|
90
|
+
|
|
91
|
+
video_paths = ['../test_videos/robosuite/robosuite_lift_example_00.mp4']
|
|
92
|
+
task_description="Pick up the cube from the table."
|
|
93
|
+
|
|
94
|
+
# Robometer
|
|
95
|
+
rewards, success_probs = rr.generate(model="robometer", task_description=task_description, video_paths=video_paths, view_type_per_video=['external'])
|
|
96
|
+
output_robometer = {"model": "robometer", "rewards": rewards[0]}
|
|
97
|
+
|
|
98
|
+
# SOLE-R1
|
|
99
|
+
rewards, reasoning_traces = rr.generate(model="sole-r1", task_description=task_description, video_paths=video_paths, view_type_per_video=['external and wrist'])
|
|
100
|
+
output_sole = {"model": "sole-r1", "rewards": rewards[0], "reasoning_traces": reasoning_traces[0]}
|
|
101
|
+
|
|
102
|
+
rr.video_plot(outputs=[output_sole, output_robometer], plot_save_path='../model_outputs/combined/robosuite_lift_example_00.mp4', video_path = video_paths[0])
|
|
103
|
+
|
|
104
|
+
```
|
|
105
|
+
|
|
106
|
+
---
|
|
107
|
+
## Examples for generating across all models
|
|
108
|
+
|
|
109
|
+
### Robometer
|
|
110
|
+
```python
|
|
111
|
+
|
|
112
|
+
import roboreason as rr
|
|
113
|
+
|
|
114
|
+
rewards, success_probs = rr.generate(
|
|
115
|
+
model="robometer",
|
|
116
|
+
task_description="Pick up the cube from the table.",
|
|
117
|
+
video_paths=['../test_videos/robosuite/robosuite_lift_example_00.mp4'],
|
|
118
|
+
view_type_per_video=['external']
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
```
|
|
122
|
+
|
|
123
|
+
### SOLE-R1
|
|
124
|
+
```python
|
|
125
|
+
|
|
126
|
+
import roboreason as rr
|
|
127
|
+
|
|
128
|
+
rewards, reasoning_traces = rr.generate(
|
|
129
|
+
model="sole-r1",
|
|
130
|
+
task_description="Pick up the cube from the table.",
|
|
131
|
+
video_paths=['../test_videos/robosuite/robosuite_lift_example_00.mp4'],
|
|
132
|
+
view_type_per_video=['external and wrist']
|
|
133
|
+
)
|
|
134
|
+
```
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
### TOPReward
|
|
138
|
+
```python
|
|
139
|
+
|
|
140
|
+
import roboreason as rr
|
|
141
|
+
|
|
142
|
+
rewards = rr.generate(
|
|
143
|
+
model="topreward",
|
|
144
|
+
task_description="Pick up the cube from the table.",
|
|
145
|
+
video_paths=['../test_videos/robosuite/robosuite_lift_example_00.mp4'],
|
|
146
|
+
view_type_per_video=['external']
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
```
|
|
150
|
+
|
|
151
|
+
### RoboReward
|
|
152
|
+
```python
|
|
153
|
+
|
|
154
|
+
import roboreason as rr
|
|
155
|
+
|
|
156
|
+
rewards = rr.generate(
|
|
157
|
+
model="roboreward",
|
|
158
|
+
task_description="Pick up the cube from the table.",
|
|
159
|
+
video_paths=['../test_videos/robosuite/robosuite_lift_example_00.mp4'],
|
|
160
|
+
view_type_per_video=['external']
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
```
|
|
164
|
+
|
|
165
|
+
### GPT-5 (and other OpenAI models)
|
|
166
|
+
```python
|
|
167
|
+
|
|
168
|
+
import roboreason as rr
|
|
169
|
+
|
|
170
|
+
# requires OpenAI API key: https://developers.openai.com/api/docs/quickstart
|
|
171
|
+
API_KEY = "..."
|
|
172
|
+
|
|
173
|
+
rewards, reasoning_traces = rr.generate(
|
|
174
|
+
model="gpt-5",
|
|
175
|
+
task_description="Pick up the cube from the table.",
|
|
176
|
+
video_paths=['../test_videos/robosuite/robosuite_lift_example_00.mp4'],
|
|
177
|
+
view_type_per_video=['external'],
|
|
178
|
+
key=API_KEY
|
|
179
|
+
)
|
|
180
|
+
```
|
|
181
|
+
|
|
182
|
+
### Gemini-3-Pro (and other Google models)
|
|
183
|
+
```python
|
|
184
|
+
|
|
185
|
+
import roboreason as rr
|
|
186
|
+
|
|
187
|
+
# requires Gemini API key: https://ai.google.dev/gemini-api/docs/api-key
|
|
188
|
+
API_KEY = "..."
|
|
189
|
+
|
|
190
|
+
rewards, reasoning_traces = rr.generate(
|
|
191
|
+
model="gemini-3-pro-preview",
|
|
192
|
+
task_description="Pick up the cube from the table.",
|
|
193
|
+
video_paths=['../test_videos/robosuite/robosuite_lift_example_00.mp4'],
|
|
194
|
+
view_type_per_video=['external'],
|
|
195
|
+
key=API_KEY
|
|
196
|
+
)
|
|
197
|
+
```
|
|
198
|
+
|
|
199
|
+
## Video plotting
|
|
200
|
+
```python
|
|
201
|
+
|
|
202
|
+
import roboreason as rr
|
|
203
|
+
|
|
204
|
+
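task_description = "Pick up the cube from the table."
video_paths = ['../test_videos/robosuite/robosuite_lift_example_00.mp4']

# Robometer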
|
|
205
|
+
rewards, success_probs = rr.generate(model="robometer", task_description=task_description, video_paths=video_paths, view_type_per_video=['external'])
|
|
206
|
+
output_robometer = {"model": "robometer", "rewards": rewards[0]}
|
|
207
|
+
|
|
208
|
+
# SOLE-R1
|
|
209
|
+
rewards, reasoning_traces = rr.generate(model="sole-r1", task_description=task_description, video_paths=video_paths, view_type_per_video=['external and wrist'])
|
|
210
|
+
output_sole = {"model": "sole-r1", "rewards": rewards[0], "reasoning_traces": reasoning_traces[0]}
|
|
211
|
+
|
|
212
|
+
rr.video_plot(
|
|
213
|
+
outputs=[output_sole, output_robometer],
|
|
214
|
+
plot_save_path='../model_outputs/combined/robosuite_lift_example_00.mp4',
|
|
215
|
+
video_path='../test_videos/robosuite/robosuite_lift_example_00.mp4',
|
|
216
|
+
)
|
|
217
|
+
```
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
---
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
## rr.generate
|
|
224
|
+
|
|
225
|
+
| Argument | Type | Required | Description |
|
|
226
|
+
| --------------------- | ----------- | -------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
|
227
|
+
| `model`               | `str`       | ✅        | Name of the model to use. Options include: `"robometer"`, `"sole-r1"`, `"topreward"`, `"roboreward"`, OpenAI models (e.g., `"gpt-5"`), and Google models (e.g., `"gemini-3-pro-preview"`). |
|
|
228
|
+
| `task_description` | `str` | ✅ | Natural language description of the task the robot is performing. |
|
|
229
|
+
| `video_paths` | `List[str]` | ✅ | List of paths to input video files. |
|
|
230
|
+
| `view_type_per_video` | `List[str]` | ✅ | List specifying the camera view(s) used for reward reasoning for each video (e.g., `"external"`, `"wrist"`, or `"external and wrist"`). |
|
|
231
|
+
| `key` | `str` | ❌ | API key required for external models (e.g., OpenAI or Gemini). Not needed for local models. |
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
| Model Type | Return Values |
|
|
235
|
+
| ---------------------- | --------------------------- |
|
|
236
|
+
| SOLE-R1 / GPT / Gemini | `rewards, reasoning_traces` |
|
|
237
|
+
| Robometer | `rewards, success_probs` |
|
|
238
|
+
| TOPReward / RoboReward | `rewards` |
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
## rr.video_plot
|
|
242
|
+
|
|
243
|
+
| Argument | Type | Required | Description |
|
|
244
|
+
| ----------------------- | ------------ | -------- | ----------------------------------------------------------------------------------------- |
|
|
245
|
+
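| `outputs`               | `List[dict]` | ❌*      | List of output dicts to visualize together, one per model. Each dict has `"model"` and `"rewards"` keys and optionally `"reasoning_traces"`; build them from `rr.generate` results as in the example above. \*Provide either `outputs` or the direct-call arguments marked ❌\*\*. |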
|
|
246
|
+
| `plot_save_path` | `str` | ❌ | Path where the output video with overlays will be saved. |
|
|
247
|
+
| `video_path` | `str` | ❌ | Path to the original video file being visualized. |
|
|
248
|
+
| `view_type` | `str` | ❌ | View type used for visualization (e.g., `"external"`, `"wrist"`, `"external and wrist"`). |
|
|
249
|
+
| `show_reasoning_traces` | `bool` | ❌ | Whether to overlay reasoning traces on the video. Default: `False`. |
|
|
250
|
+
| `show_all_frames` | `bool` | ❌ | Whether to render all frames instead of sampled frames. Default: `False`. |
|
|
251
|
+
| `model` | `str` | ❌** | Model name (used when calling `video_plot` directly instead of passing `outputs`). |
|
|
252
|
+
| `task_description` | `str` | ❌** | Task description (used in direct-call mode). |
|
|
253
|
+
| `video_paths` | `List[str]` | ❌** | Input videos (used in direct-call mode). |
|
|
254
|
+
| `view_type_per_video` | `List[str]` | ❌** | View types per video (used in direct-call mode). |
|
|
255
|
+
| `key` | `str` | ❌** | API key (if required for model). |
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
|
|
2
|
+
[build-system]
|
|
3
|
+
requires = ["setuptools>=61"]
|
|
4
|
+
build-backend = "setuptools.build_meta"
|
|
5
|
+
|
|
6
|
+
[project]
|
|
7
|
+
name = "roboreason"
|
|
8
|
+
version = "0.1.0"
|
|
9
|
+
description = "Generate rewards and reasoning traces for robot task videos"
|
|
10
|
+
readme = "README.md"
|
|
11
|
+
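requires-python = ">=3.8"  # NOTE(review): pinned deps (transformers==4.57.3, vllm==0.12.0) require a newer Python than 3.8 — confirm and raise this floor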
|
|
12
|
+
|
|
13
|
+
authors = [
|
|
14
|
+
{ name = "Philip" }
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
dependencies = [
|
|
18
|
+
# Core ML stack, plus video I/O, plotting, and experiment logging
|
|
19
|
+
"transformers==4.57.3",
|
|
20
|
+
"vllm==0.12.0",
|
|
21
|
+
"wandb>=0.19.1",
|
|
22
|
+
"pillow",
|
|
23
|
+
"accelerate>=1.2.1",
|
|
24
|
+
"qwen_vl_utils",
|
|
25
|
+
"datasets>=3.2.0",
|
|
26
|
+
"trl",
|
|
27
|
+
"imageio",
|
|
28
|
+
"matplotlib",
|
|
29
|
+
"av>=16.0.0",
|
|
30
|
+
|
|
31
|
+
# API clients for hosted models (OpenAI, Google Gemini)
|
|
32
|
+
"openai",
|
|
33
|
+
"google-genai",
|
|
34
|
+
]
|
|
35
|
+
|
|
36
|
+
# Optional dependency groups
|
|
37
|
+
[project.optional-dependencies]
|
|
38
|
+
|
|
39
|
+
robometer = [
|
|
40
|
+
"omegaconf>=2.3.0",
|
|
41
|
+
"hydra-core",
|
|
42
|
+
"tensorboard",
|
|
43
|
+
"peft",
|
|
44
|
+
"codetiming>=1.4.0",
|
|
45
|
+
"unsloth>=2025.10",
|
|
46
|
+
"sentence-transformers>=2.0.0",
|
|
47
|
+
"decord>=0.6.0",
|
|
48
|
+
]
|
|
49
|
+
|
|
50
|
+
lerobot = [
|
|
51
|
+
"lerobot",
|
|
52
|
+
]
|
|
53
|
+
|
|
54
|
+
all = [
|
|
55
|
+
"omegaconf>=2.3.0",
|
|
56
|
+
"hydra-core",
|
|
57
|
+
"tensorboard",
|
|
58
|
+
"peft",
|
|
59
|
+
"codetiming>=1.4.0",
|
|
60
|
+
"unsloth>=2025.10",
|
|
61
|
+
"sentence-transformers>=2.0.0",
|
|
62
|
+
"decord>=0.6.0",
|
|
63
|
+
"lerobot",
|
|
64
|
+
]
|
|
65
|
+
# -------------------------
|
|
66
|
+
# setuptools config
|
|
67
|
+
# -------------------------
|
|
68
|
+
[tool.setuptools.packages.find]
|
|
69
|
+
where = ["."]
|
|
70
|
+
include = ["roboreason*"]
|
|
71
|
+
|
|
72
|
+
# -------------------------
|
|
73
|
+
# UV-specific config
|
|
74
|
+
# -------------------------
|
|
75
|
+
[tool.uv]
|
|
76
|
+
|
|
77
|
+
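# Extra index for PyTorch CUDA wheels. NOTE(review): the cu116 index only hosts torch 1.x builds, older than the pinned vllm/transformers need — confirm the intended CUDA tag (e.g. cu12x)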
|
|
78
|
+
extra-index-url = [
|
|
79
|
+
"https://download.pytorch.org/whl/cu116"
|
|
80
|
+
]
|
|
81
|
+
|