kinfer 0.5.4__cp312-cp312-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/__init__.py +16 -0
- kinfer/export/__init__.py +3 -0
- kinfer/export/jax.py +55 -0
- kinfer/export/pytorch.py +53 -0
- kinfer/export/serialize.py +84 -0
- kinfer/py.typed +0 -0
- kinfer/requirements.txt +13 -0
- kinfer/rust/Cargo.toml +36 -0
- kinfer/rust/src/lib.rs +8 -0
- kinfer/rust/src/logger.rs +141 -0
- kinfer/rust/src/model.rs +354 -0
- kinfer/rust/src/runtime.rs +107 -0
- kinfer/rust/src/types.rs +96 -0
- kinfer/rust_bindings/Cargo.toml +26 -0
- kinfer/rust_bindings/pyproject.toml +7 -0
- kinfer/rust_bindings/rust_bindings.pyi +46 -0
- kinfer/rust_bindings/src/bin/stub_gen.rs +7 -0
- kinfer/rust_bindings/src/lib.rs +486 -0
- kinfer/rust_bindings.cpython-312-x86_64-linux-gnu.so +0 -0
- kinfer/rust_bindings.pyi +46 -0
- kinfer/scripts/plot_ndjson.py +177 -0
- kinfer-0.5.4.dist-info/METADATA +63 -0
- kinfer-0.5.4.dist-info/RECORD +26 -0
- kinfer-0.5.4.dist-info/WHEEL +5 -0
- kinfer-0.5.4.dist-info/licenses/LICENSE +21 -0
- kinfer-0.5.4.dist-info/top_level.txt +1 -0
kinfer/rust/src/model.rs
ADDED
@@ -0,0 +1,354 @@
|
|
1
|
+
use crate::logger::StepLogger;
|
2
|
+
use crate::types::{InputType, ModelMetadata};
|
3
|
+
use async_trait::async_trait;
|
4
|
+
use chrono;
|
5
|
+
use flate2::read::GzDecoder;
|
6
|
+
use ndarray::{Array, IxDyn};
|
7
|
+
use ort::session::Session;
|
8
|
+
use ort::value::Value;
|
9
|
+
use std::collections::HashMap;
|
10
|
+
use std::io::Read;
|
11
|
+
use std::path::Path;
|
12
|
+
use std::sync::Arc;
|
13
|
+
use tar::Archive;
|
14
|
+
use tokio::fs::File;
|
15
|
+
use tokio::io::AsyncReadExt;
|
16
|
+
|
17
|
+
#[derive(Debug, thiserror::Error)]
|
18
|
+
pub enum ModelError {
|
19
|
+
#[error("IO error: {0}")]
|
20
|
+
Io(#[from] std::io::Error),
|
21
|
+
#[error("Provider error: {0}")]
|
22
|
+
Provider(String),
|
23
|
+
}
|
24
|
+
|
25
|
+
#[async_trait]
|
26
|
+
pub trait ModelProvider: Send + Sync {
|
27
|
+
async fn get_inputs(
|
28
|
+
&self,
|
29
|
+
input_types: &[InputType],
|
30
|
+
metadata: &ModelMetadata,
|
31
|
+
) -> Result<HashMap<InputType, Array<f32, IxDyn>>, ModelError>;
|
32
|
+
|
33
|
+
async fn take_action(
|
34
|
+
&self,
|
35
|
+
action: Array<f32, IxDyn>,
|
36
|
+
metadata: &ModelMetadata,
|
37
|
+
) -> Result<(), ModelError>;
|
38
|
+
}
|
39
|
+
|
40
|
+
pub struct ModelRunner {
|
41
|
+
init_session: Session,
|
42
|
+
step_session: Session,
|
43
|
+
metadata: ModelMetadata,
|
44
|
+
provider: Arc<dyn ModelProvider>,
|
45
|
+
logger: Option<Arc<StepLogger>>,
|
46
|
+
}
|
47
|
+
|
48
|
+
impl ModelRunner {
|
49
|
+
pub async fn new<P: AsRef<Path>>(
|
50
|
+
model_path: P,
|
51
|
+
input_provider: Arc<dyn ModelProvider>,
|
52
|
+
) -> Result<Self, Box<dyn std::error::Error>> {
|
53
|
+
let mut file = File::open(model_path).await?;
|
54
|
+
|
55
|
+
// Read entire file into memory
|
56
|
+
let mut buffer = Vec::new();
|
57
|
+
file.read_to_end(&mut buffer).await?;
|
58
|
+
|
59
|
+
// Decompress and read the tar archive from memory
|
60
|
+
let gz = GzDecoder::new(&buffer[..]);
|
61
|
+
let mut archive = Archive::new(gz);
|
62
|
+
|
63
|
+
// Extract and validate joint names
|
64
|
+
let mut metadata: Option<String> = None;
|
65
|
+
let mut init_fn: Option<Vec<u8>> = None;
|
66
|
+
let mut step_fn: Option<Vec<u8>> = None;
|
67
|
+
|
68
|
+
for entry in archive.entries()? {
|
69
|
+
let mut entry = entry?;
|
70
|
+
let path = entry.path()?;
|
71
|
+
let path_str = path.to_string_lossy();
|
72
|
+
|
73
|
+
match path_str.as_ref() {
|
74
|
+
"metadata.json" => {
|
75
|
+
let mut contents = String::new();
|
76
|
+
entry.read_to_string(&mut contents)?;
|
77
|
+
metadata = Some(contents);
|
78
|
+
}
|
79
|
+
"init_fn.onnx" => {
|
80
|
+
let size = entry.size() as usize;
|
81
|
+
let mut contents = vec![0u8; size];
|
82
|
+
entry.read_exact(&mut contents)?;
|
83
|
+
assert_eq!(contents.len(), entry.size() as usize);
|
84
|
+
init_fn = Some(contents);
|
85
|
+
}
|
86
|
+
"step_fn.onnx" => {
|
87
|
+
let size = entry.size() as usize;
|
88
|
+
let mut contents = vec![0u8; size];
|
89
|
+
entry.read_exact(&mut contents)?;
|
90
|
+
assert_eq!(contents.len(), entry.size() as usize);
|
91
|
+
step_fn = Some(contents);
|
92
|
+
}
|
93
|
+
_ => return Err("Unknown entry".into()),
|
94
|
+
}
|
95
|
+
}
|
96
|
+
|
97
|
+
// Reads the files.
|
98
|
+
let metadata = ModelMetadata::model_validate_json(
|
99
|
+
metadata.ok_or("metadata.json not found in archive")?,
|
100
|
+
)?;
|
101
|
+
let init_session = Session::builder()?
|
102
|
+
.commit_from_memory(&init_fn.ok_or("init_fn.onnx not found in archive")?)?;
|
103
|
+
let step_session = Session::builder()?
|
104
|
+
.commit_from_memory(&step_fn.ok_or("step_fn.onnx not found in archive")?)?;
|
105
|
+
|
106
|
+
// Validate init_fn has no inputs and one output
|
107
|
+
if !init_session.inputs.is_empty() {
|
108
|
+
return Err("init_fn should not have any inputs".into());
|
109
|
+
}
|
110
|
+
if init_session.outputs.len() != 1 {
|
111
|
+
return Err("init_fn should have exactly one output".into());
|
112
|
+
}
|
113
|
+
|
114
|
+
// Get carry shape from init_fn output
|
115
|
+
let carry_shape = init_session.outputs[0]
|
116
|
+
.output_type
|
117
|
+
.tensor_dimensions()
|
118
|
+
.ok_or("Missing tensor type")?
|
119
|
+
.to_vec();
|
120
|
+
|
121
|
+
// Validate step_fn inputs and outputs
|
122
|
+
Self::validate_step_fn(&step_session, &metadata, &carry_shape)?;
|
123
|
+
|
124
|
+
let logger = if let Ok(log_dir) = std::env::var("KINFER_LOG_PATH") {
|
125
|
+
let log_dir_path = std::path::Path::new(&log_dir);
|
126
|
+
|
127
|
+
// Create the directory if it doesn't exist
|
128
|
+
if !log_dir_path.exists() {
|
129
|
+
std::fs::create_dir_all(log_dir_path)?;
|
130
|
+
}
|
131
|
+
|
132
|
+
// Use uuid if found, otherwise timestamp
|
133
|
+
let log_name = std::env::var("KINFER_LOG_UUID").unwrap_or_else(|_| {
|
134
|
+
chrono::Utc::now().format("%Y-%m-%d_%H-%M-%S").to_string()
|
135
|
+
});
|
136
|
+
|
137
|
+
let log_file_path = log_dir_path.join(format!("{}.ndjson", log_name));
|
138
|
+
|
139
|
+
Some(StepLogger::new(log_file_path).map(Arc::new)?)
|
140
|
+
} else {
|
141
|
+
None
|
142
|
+
};
|
143
|
+
|
144
|
+
Ok(Self {
|
145
|
+
init_session,
|
146
|
+
step_session,
|
147
|
+
metadata,
|
148
|
+
provider: input_provider,
|
149
|
+
logger,
|
150
|
+
})
|
151
|
+
}
|
152
|
+
|
153
|
+
fn validate_step_fn(
|
154
|
+
session: &Session,
|
155
|
+
metadata: &ModelMetadata,
|
156
|
+
carry_shape: &[i64],
|
157
|
+
) -> Result<(), Box<dyn std::error::Error>> {
|
158
|
+
// Validate inputs
|
159
|
+
for input in &session.inputs {
|
160
|
+
let dims = input.input_type.tensor_dimensions().ok_or(format!(
|
161
|
+
"Input {} is not a tensor with known dimensions",
|
162
|
+
input.name
|
163
|
+
))?;
|
164
|
+
|
165
|
+
let input_type = InputType::from_name(&input.name)?;
|
166
|
+
let expected_shape = input_type.get_shape(metadata);
|
167
|
+
let expected_shape_i64: Vec<i64> = expected_shape.iter().map(|&x| x as i64).collect();
|
168
|
+
if *dims != expected_shape_i64 {
|
169
|
+
return Err(format!(
|
170
|
+
"Expected input shape {:?}, got {:?}",
|
171
|
+
expected_shape_i64, dims
|
172
|
+
)
|
173
|
+
.into());
|
174
|
+
}
|
175
|
+
}
|
176
|
+
|
177
|
+
// Validate outputs
|
178
|
+
if session.outputs.len() != 2 {
|
179
|
+
return Err("Step function must have exactly 2 outputs".into());
|
180
|
+
}
|
181
|
+
|
182
|
+
let output_shape = session.outputs[0]
|
183
|
+
.output_type
|
184
|
+
.tensor_dimensions()
|
185
|
+
.ok_or("Missing tensor type")?;
|
186
|
+
let num_joints = metadata.joint_names.len();
|
187
|
+
if *output_shape != vec![num_joints as i64] {
|
188
|
+
return Err(format!(
|
189
|
+
"Expected output shape [{num_joints}], got {:?}",
|
190
|
+
output_shape
|
191
|
+
)
|
192
|
+
.into());
|
193
|
+
}
|
194
|
+
|
195
|
+
let infered_carry_shape = session.outputs[1]
|
196
|
+
.output_type
|
197
|
+
.tensor_dimensions()
|
198
|
+
.ok_or("Missing tensor type")?;
|
199
|
+
if *infered_carry_shape != *carry_shape {
|
200
|
+
return Err(format!(
|
201
|
+
"Expected carry shape {:?}, got {:?}",
|
202
|
+
carry_shape, infered_carry_shape
|
203
|
+
)
|
204
|
+
.into());
|
205
|
+
}
|
206
|
+
|
207
|
+
Ok(())
|
208
|
+
}
|
209
|
+
|
210
|
+
pub async fn get_inputs(
|
211
|
+
&self,
|
212
|
+
input_types: &[InputType],
|
213
|
+
) -> Result<HashMap<InputType, Array<f32, IxDyn>>, ModelError> {
|
214
|
+
self.provider.get_inputs(input_types, &self.metadata).await
|
215
|
+
}
|
216
|
+
|
217
|
+
pub async fn init(&self) -> Result<Array<f32, IxDyn>, Box<dyn std::error::Error>> {
|
218
|
+
let input_values: Vec<(&str, Value)> = Vec::new();
|
219
|
+
let outputs = self.init_session.run(input_values)?;
|
220
|
+
let output_tensor = outputs[0].try_extract_tensor::<f32>()?;
|
221
|
+
Ok(output_tensor.view().to_owned())
|
222
|
+
}
|
223
|
+
|
224
|
+
pub async fn step(
|
225
|
+
&self,
|
226
|
+
carry: Array<f32, IxDyn>,
|
227
|
+
) -> Result<(Array<f32, IxDyn>, Array<f32, IxDyn>), Box<dyn std::error::Error>> {
|
228
|
+
// Gets the model input names.
|
229
|
+
let input_names: Vec<String> = self
|
230
|
+
.step_session
|
231
|
+
.inputs
|
232
|
+
.iter()
|
233
|
+
.map(|i| i.name.clone())
|
234
|
+
.collect();
|
235
|
+
|
236
|
+
// Calls the relevant getter methods in parallel.
|
237
|
+
let mut input_types = Vec::new();
|
238
|
+
let mut inputs = HashMap::new();
|
239
|
+
for name in &input_names {
|
240
|
+
match name.as_str() {
|
241
|
+
"joint_angles" => {
|
242
|
+
input_types.push(InputType::JointAngles);
|
243
|
+
}
|
244
|
+
"joint_angular_velocities" => {
|
245
|
+
input_types.push(InputType::JointAngularVelocities);
|
246
|
+
}
|
247
|
+
"initial_heading" => {
|
248
|
+
input_types.push(InputType::InitialHeading);
|
249
|
+
}
|
250
|
+
"quaternion" => {
|
251
|
+
input_types.push(InputType::Quaternion);
|
252
|
+
}
|
253
|
+
"projected_gravity" => {
|
254
|
+
input_types.push(InputType::ProjectedGravity);
|
255
|
+
}
|
256
|
+
"accelerometer" => {
|
257
|
+
input_types.push(InputType::Accelerometer);
|
258
|
+
}
|
259
|
+
"gyroscope" => {
|
260
|
+
input_types.push(InputType::Gyroscope);
|
261
|
+
}
|
262
|
+
"command" => {
|
263
|
+
input_types.push(InputType::Command);
|
264
|
+
}
|
265
|
+
"time" => {
|
266
|
+
input_types.push(InputType::Time);
|
267
|
+
}
|
268
|
+
"carry" => {
|
269
|
+
inputs.insert(InputType::Carry, carry.clone());
|
270
|
+
}
|
271
|
+
_ => return Err(format!("Unknown input name: {}", name).into()),
|
272
|
+
}
|
273
|
+
}
|
274
|
+
|
275
|
+
// Gets the input values.
|
276
|
+
let result = self
|
277
|
+
.provider
|
278
|
+
.get_inputs(&input_types, &self.metadata)
|
279
|
+
.await?;
|
280
|
+
|
281
|
+
// Adds the input values to the input map.
|
282
|
+
inputs.extend(result);
|
283
|
+
|
284
|
+
// Convert inputs to ONNX values
|
285
|
+
let mut input_values: Vec<(&str, Value)> = Vec::new();
|
286
|
+
for input in &self.step_session.inputs {
|
287
|
+
let input_type = InputType::from_name(&input.name)?;
|
288
|
+
let input_data = inputs
|
289
|
+
.get(&input_type)
|
290
|
+
.ok_or_else(|| format!("Missing input: {}", input.name))?;
|
291
|
+
let input_value = Value::from_array(input_data.view())?.into_dyn();
|
292
|
+
input_values.push((input.name.as_str(), input_value));
|
293
|
+
}
|
294
|
+
|
295
|
+
// Run the model
|
296
|
+
let outputs = self.step_session.run(input_values)?;
|
297
|
+
let output_tensor = outputs[0].try_extract_tensor::<f32>()?;
|
298
|
+
let carry_tensor = outputs[1].try_extract_tensor::<f32>()?;
|
299
|
+
|
300
|
+
// Log the step if needed
|
301
|
+
if let Some(lg) = &self.logger {
|
302
|
+
let joint_angles_opt = inputs
|
303
|
+
.get(&InputType::JointAngles)
|
304
|
+
.map(|a| a.as_slice().unwrap());
|
305
|
+
let joint_vels_opt = inputs
|
306
|
+
.get(&InputType::JointAngularVelocities)
|
307
|
+
.map(|a| a.as_slice().unwrap());
|
308
|
+
let initial_heading_opt = inputs
|
309
|
+
.get(&InputType::InitialHeading)
|
310
|
+
.map(|a| a.as_slice().unwrap());
|
311
|
+
let quaternion_opt = inputs
|
312
|
+
.get(&InputType::Quaternion)
|
313
|
+
.map(|a| a.as_slice().unwrap());
|
314
|
+
let projected_g_opt = inputs
|
315
|
+
.get(&InputType::ProjectedGravity)
|
316
|
+
.map(|a| a.as_slice().unwrap());
|
317
|
+
let accel_opt = inputs
|
318
|
+
.get(&InputType::Accelerometer)
|
319
|
+
.map(|a| a.as_slice().unwrap());
|
320
|
+
let gyro_opt = inputs
|
321
|
+
.get(&InputType::Gyroscope)
|
322
|
+
.map(|a| a.as_slice().unwrap());
|
323
|
+
let command_opt = inputs
|
324
|
+
.get(&InputType::Command)
|
325
|
+
.map(|a| a.as_slice().unwrap());
|
326
|
+
let output_opt = Some(output_tensor.as_slice().unwrap());
|
327
|
+
|
328
|
+
lg.log_step(
|
329
|
+
joint_angles_opt,
|
330
|
+
joint_vels_opt,
|
331
|
+
initial_heading_opt,
|
332
|
+
quaternion_opt,
|
333
|
+
projected_g_opt,
|
334
|
+
accel_opt,
|
335
|
+
gyro_opt,
|
336
|
+
command_opt,
|
337
|
+
output_opt,
|
338
|
+
);
|
339
|
+
}
|
340
|
+
|
341
|
+
Ok((
|
342
|
+
output_tensor.view().to_owned(),
|
343
|
+
carry_tensor.view().to_owned(),
|
344
|
+
))
|
345
|
+
}
|
346
|
+
|
347
|
+
pub async fn take_action(
|
348
|
+
&self,
|
349
|
+
action: Array<f32, IxDyn>,
|
350
|
+
) -> Result<(), Box<dyn std::error::Error>> {
|
351
|
+
self.provider.take_action(action, &self.metadata).await?;
|
352
|
+
Ok(())
|
353
|
+
}
|
354
|
+
}
|
kinfer/rust/src/runtime.rs
ADDED
@@ -0,0 +1,107 @@
|
|
1
|
+
use std::sync::atomic::{AtomicBool, Ordering};
|
2
|
+
use std::sync::Arc;
|
3
|
+
use tokio::runtime::Runtime;
|
4
|
+
|
5
|
+
use crate::model::{ModelError, ModelRunner};
|
6
|
+
use crate::types::InputType;
|
7
|
+
use std::time::Duration;
|
8
|
+
use tokio::time::interval;
|
9
|
+
|
10
|
+
pub struct ModelRuntime {
|
11
|
+
model_runner: Arc<ModelRunner>,
|
12
|
+
dt: Duration,
|
13
|
+
slowdown_factor: i32,
|
14
|
+
magnitude_factor: f32,
|
15
|
+
running: Arc<AtomicBool>,
|
16
|
+
runtime: Option<Runtime>,
|
17
|
+
}
|
18
|
+
|
19
|
+
impl ModelRuntime {
|
20
|
+
pub fn new(model_runner: Arc<ModelRunner>, dt: u64) -> Self {
|
21
|
+
Self {
|
22
|
+
model_runner,
|
23
|
+
dt: Duration::from_millis(dt),
|
24
|
+
slowdown_factor: 1,
|
25
|
+
magnitude_factor: 1.0,
|
26
|
+
running: Arc::new(AtomicBool::new(false)),
|
27
|
+
runtime: None,
|
28
|
+
}
|
29
|
+
}
|
30
|
+
|
31
|
+
pub fn set_slowdown_factor(&mut self, slowdown_factor: i32) {
|
32
|
+
assert!(slowdown_factor >= 1);
|
33
|
+
self.slowdown_factor = slowdown_factor;
|
34
|
+
}
|
35
|
+
|
36
|
+
pub fn set_magnitude_factor(&mut self, magnitude_factor: f32) {
|
37
|
+
assert!(magnitude_factor >= 0.0);
|
38
|
+
assert!(magnitude_factor <= 1.0);
|
39
|
+
self.magnitude_factor = magnitude_factor;
|
40
|
+
}
|
41
|
+
|
42
|
+
pub fn start(&mut self) -> Result<(), ModelError> {
|
43
|
+
if self.running.load(Ordering::Relaxed) {
|
44
|
+
return Ok(());
|
45
|
+
}
|
46
|
+
|
47
|
+
let running = self.running.clone();
|
48
|
+
let model_runner = self.model_runner.clone();
|
49
|
+
let dt = self.dt;
|
50
|
+
let slowdown_factor = self.slowdown_factor;
|
51
|
+
let magnitude_factor = self.magnitude_factor;
|
52
|
+
|
53
|
+
let runtime = Runtime::new()?;
|
54
|
+
running.store(true, Ordering::Relaxed);
|
55
|
+
|
56
|
+
runtime.spawn(async move {
|
57
|
+
let mut carry = model_runner
|
58
|
+
.init()
|
59
|
+
.await
|
60
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
61
|
+
|
62
|
+
let model_inputs = model_runner
|
63
|
+
.get_inputs(&[InputType::JointAngles])
|
64
|
+
.await
|
65
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
66
|
+
let mut joint_positions = model_inputs[&InputType::JointAngles].clone();
|
67
|
+
|
68
|
+
// Wait for the first tick, since it happens immediately.
|
69
|
+
let mut interval = interval(dt);
|
70
|
+
interval.tick().await;
|
71
|
+
|
72
|
+
while running.load(Ordering::Relaxed) {
|
73
|
+
let (output, next_carry) = model_runner
|
74
|
+
.step(carry)
|
75
|
+
.await
|
76
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
77
|
+
carry = next_carry;
|
78
|
+
|
79
|
+
for i in 1..(slowdown_factor + 1) {
|
80
|
+
if !running.load(Ordering::Relaxed) {
|
81
|
+
break;
|
82
|
+
}
|
83
|
+
let t = i as f32 / slowdown_factor as f32;
|
84
|
+
let interp_joint_positions = &joint_positions * (1.0 - t) + &output * t;
|
85
|
+
model_runner
|
86
|
+
.take_action(interp_joint_positions * magnitude_factor)
|
87
|
+
.await
|
88
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
89
|
+
interval.tick().await;
|
90
|
+
}
|
91
|
+
|
92
|
+
joint_positions = output;
|
93
|
+
}
|
94
|
+
Ok::<(), ModelError>(())
|
95
|
+
});
|
96
|
+
|
97
|
+
self.runtime = Some(runtime);
|
98
|
+
Ok(())
|
99
|
+
}
|
100
|
+
|
101
|
+
pub fn stop(&mut self) {
|
102
|
+
self.running.store(false, Ordering::Relaxed);
|
103
|
+
if let Some(runtime) = self.runtime.take() {
|
104
|
+
runtime.shutdown_background();
|
105
|
+
}
|
106
|
+
}
|
107
|
+
}
|
kinfer/rust/src/types.rs
ADDED
@@ -0,0 +1,96 @@
|
|
1
|
+
use serde::Deserialize;
|
2
|
+
use serde::Serialize;
|
3
|
+
|
4
|
+
#[derive(Debug, Deserialize, Serialize, Clone)]
|
5
|
+
pub struct ModelMetadata {
|
6
|
+
pub joint_names: Vec<String>,
|
7
|
+
pub num_commands: Option<usize>,
|
8
|
+
pub carry_size: Vec<usize>,
|
9
|
+
}
|
10
|
+
|
11
|
+
impl ModelMetadata {
|
12
|
+
pub fn model_validate_json(json: String) -> Result<Self, Box<dyn std::error::Error>> {
|
13
|
+
Ok(serde_json::from_str(&json)?)
|
14
|
+
}
|
15
|
+
|
16
|
+
pub fn to_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
17
|
+
Ok(serde_json::to_string(self)?)
|
18
|
+
}
|
19
|
+
}
|
20
|
+
|
21
|
+
/// The kinds of input a kinfer step function may declare. Each corresponds to
/// an ONNX input name such as `joint_angles`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum InputType {
    /// Joint positions, one value per joint.
    JointAngles,
    /// Joint velocities, one value per joint.
    JointAngularVelocities,
    /// A single heading value.
    InitialHeading,
    /// Orientation as four quaternion components.
    Quaternion,
    /// A 3-component gravity vector.
    ProjectedGravity,
    /// A 3-axis accelerometer reading.
    Accelerometer,
    /// A 3-axis gyroscope reading.
    Gyroscope,
    /// The command vector, `num_commands` values long.
    Command,
    /// A single time value.
    Time,
    /// The recurrent state passed from one step to the next.
    Carry,
}
|
34
|
+
|
35
|
+
impl InputType {
|
36
|
+
pub fn get_name(&self) -> &str {
|
37
|
+
match self {
|
38
|
+
InputType::JointAngles => "joint_angles",
|
39
|
+
InputType::JointAngularVelocities => "joint_angular_velocities",
|
40
|
+
InputType::InitialHeading => "initial_heading",
|
41
|
+
InputType::Quaternion => "quaternion",
|
42
|
+
InputType::ProjectedGravity => "projected_gravity",
|
43
|
+
InputType::Accelerometer => "accelerometer",
|
44
|
+
InputType::Gyroscope => "gyroscope",
|
45
|
+
InputType::Command => "command",
|
46
|
+
InputType::Time => "time",
|
47
|
+
InputType::Carry => "carry",
|
48
|
+
}
|
49
|
+
}
|
50
|
+
|
51
|
+
pub fn get_shape(&self, metadata: &ModelMetadata) -> Vec<usize> {
|
52
|
+
match self {
|
53
|
+
InputType::JointAngles => vec![metadata.joint_names.len()],
|
54
|
+
InputType::JointAngularVelocities => vec![metadata.joint_names.len()],
|
55
|
+
InputType::InitialHeading => vec![1],
|
56
|
+
InputType::Quaternion => vec![4],
|
57
|
+
InputType::ProjectedGravity => vec![3],
|
58
|
+
InputType::Accelerometer => vec![3],
|
59
|
+
InputType::Gyroscope => vec![3],
|
60
|
+
InputType::Command => vec![metadata.num_commands.unwrap_or(0)],
|
61
|
+
InputType::Time => vec![1],
|
62
|
+
InputType::Carry => metadata.carry_size.clone(),
|
63
|
+
}
|
64
|
+
}
|
65
|
+
|
66
|
+
pub fn from_name(name: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
67
|
+
match name {
|
68
|
+
"joint_angles" => Ok(InputType::JointAngles),
|
69
|
+
"joint_angular_velocities" => Ok(InputType::JointAngularVelocities),
|
70
|
+
"initial_heading" => Ok(InputType::InitialHeading),
|
71
|
+
"quaternion" => Ok(InputType::Quaternion),
|
72
|
+
"projected_gravity" => Ok(InputType::ProjectedGravity),
|
73
|
+
"accelerometer" => Ok(InputType::Accelerometer),
|
74
|
+
"gyroscope" => Ok(InputType::Gyroscope),
|
75
|
+
"command" => Ok(InputType::Command),
|
76
|
+
"time" => Ok(InputType::Time),
|
77
|
+
"carry" => Ok(InputType::Carry),
|
78
|
+
_ => Err(format!("Unknown input type: {}", name).into()),
|
79
|
+
}
|
80
|
+
}
|
81
|
+
|
82
|
+
pub fn get_names() -> Vec<&'static str> {
|
83
|
+
vec![
|
84
|
+
"joint_angles",
|
85
|
+
"joint_angular_velocities",
|
86
|
+
"initial_heading",
|
87
|
+
"quaternion",
|
88
|
+
"projected_gravity",
|
89
|
+
"accelerometer",
|
90
|
+
"gyroscope",
|
91
|
+
"command",
|
92
|
+
"time",
|
93
|
+
"carry",
|
94
|
+
]
|
95
|
+
}
|
96
|
+
}
|
kinfer/rust_bindings/Cargo.toml
ADDED
@@ -0,0 +1,26 @@
|
|
1
|
+
[package]
|
2
|
+
|
3
|
+
name = "rust_bindings"
|
4
|
+
version.workspace = true
|
5
|
+
edition.workspace = true
|
6
|
+
description.workspace = true
|
7
|
+
authors.workspace = true
|
8
|
+
repository.workspace = true
|
9
|
+
license.workspace = true
|
10
|
+
readme.workspace = true
|
11
|
+
|
12
|
+
[lib]
|
13
|
+
|
14
|
+
name = "rust_bindings"
|
15
|
+
crate-type = ["cdylib", "rlib"]
|
16
|
+
|
17
|
+
[dependencies]
|
18
|
+
|
19
|
+
async-trait = "0.1"
|
20
|
+
kinfer = { path = "../rust" }
|
21
|
+
ndarray = "0.16.1"
|
22
|
+
numpy = ">= 0.24.0"
|
23
|
+
pyo3 = { version = ">= 0.24.0", features = ["extension-module"] }
|
24
|
+
pyo3-async-runtimes = { version = ">= 0.24.0", features = ["attributes", "tokio-runtime"] }
|
25
|
+
pyo3-stub-gen = ">= 0.6.0"
|
26
|
+
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
|
@@ -0,0 +1,46 @@
|
|
1
|
+
# This file is automatically generated by pyo3_stub_gen
|
2
|
+
# ruff: noqa: E501, F401
|
3
|
+
|
4
|
+
import builtins
|
5
|
+
import numpy
|
6
|
+
import numpy.typing
|
7
|
+
import typing
|
8
|
+
|
9
|
+
class ModelProviderABC:
    # Interface for Python-side providers passed to `PyModelRunner`: it
    # supplies model inputs and receives the actions the model produces.

    def __new__(cls) -> ModelProviderABC: ...

    # Returns a dict mapping input names to float32 arrays for the
    # requested `input_types`.
    def get_inputs(
        self,
        input_types: typing.Sequence[builtins.str],
        metadata: PyModelMetadata,
    ) -> builtins.dict[builtins.str, numpy.typing.NDArray[numpy.float32]]: ...

    # Receives a float32 action array produced by the model.
    def take_action(
        self,
        action: numpy.typing.NDArray[numpy.float32],
        metadata: PyModelMetadata,
    ) -> None: ...
|
13
|
+
|
14
|
+
class PyInputType:
    # A model input kind, built from its name (e.g. "joint_angles");
    # presumably wraps the Rust `InputType` — verify in rust_bindings/src/lib.rs.

    def __new__(cls, input_type: builtins.str) -> PyInputType: ...

    # The input's canonical name.
    def get_name(self) -> builtins.str: ...

    # Expected shape of this input for a model with the given metadata.
    def get_shape(self, metadata: PyModelMetadata) -> builtins.list[builtins.int]: ...

    def __repr__(self) -> builtins.str: ...

    def __eq__(self, other: typing.Any) -> builtins.bool: ...
|
20
|
+
|
21
|
+
class PyModelMetadata:
    # Model metadata: joint names, optional command length and carry shape.

    # `__new__` receives the class, so its first parameter is `cls` (as in
    # every other stub here); naming it `self` is flagged by type checkers.
    # NOTE(review): this file is generated by pyo3_stub_gen — apply the same
    # fix at the binding/generator side or it will be overwritten.
    def __new__(
        cls,
        joint_names: typing.Sequence[builtins.str],
        num_commands: typing.Optional[builtins.int],
        carry_size: typing.Sequence[builtins.int],
    ) -> PyModelMetadata: ...

    # Serializes the metadata to a JSON string.
    def to_json(self) -> builtins.str: ...

    def __repr__(self) -> builtins.str: ...

    def __eq__(self, other: typing.Any) -> builtins.bool: ...
|
26
|
+
|
27
|
+
class PyModelProvider:
    # The stub declares no Python-visible members for this type.
    ...
|
29
|
+
|
30
|
+
class PyModelRunner:
    # Loads a kinfer model archive and runs it against a Python provider;
    # presumably wraps the Rust `ModelRunner` — verify in rust_bindings/src/lib.rs.

    def __new__(
        cls, model_path: builtins.str, provider: ModelProviderABC
    ) -> PyModelRunner: ...

    # Runs the model's init function and returns the initial carry.
    def init(self) -> numpy.typing.NDArray[numpy.float32]: ...

    # Advances one step; returns (output, next_carry).
    def step(
        self, carry: numpy.typing.NDArray[numpy.float32]
    ) -> tuple[numpy.typing.NDArray[numpy.float32], numpy.typing.NDArray[numpy.float32]]: ...

    # Forwards an action array to the provider.
    def take_action(self, action: numpy.typing.NDArray[numpy.float32]) -> None: ...
|
35
|
+
|
36
|
+
class PyModelRuntime:
    # Background control loop around a `PyModelRunner`; presumably wraps the
    # Rust `ModelRuntime` — verify in rust_bindings/src/lib.rs.

    # `dt` is presumably the action period in milliseconds — TODO confirm.
    def __new__(cls, model_runner: PyModelRunner, dt: builtins.int) -> PyModelRuntime: ...

    def set_slowdown_factor(self, slowdown_factor: builtins.int) -> None: ...

    def set_magnitude_factor(self, magnitude_factor: builtins.float) -> None: ...

    def start(self) -> None: ...

    def stop(self) -> None: ...
|
42
|
+
|
43
|
+
def get_version() -> builtins.str:
    # Version string of the bindings (presumably the package version).
    ...
|
44
|
+
|
45
|
+
def metadata_from_json(json: builtins.str) -> PyModelMetadata:
    # Parses a metadata JSON string into a `PyModelMetadata`.
    ...
|
46
|
+
|