kinfer 0.3.3__cp311-cp311-macosx_11_0_arm64.whl → 0.4.0__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/__init__.py +0 -5
- kinfer/common/__init__.py +0 -0
- kinfer/common/types.py +11 -0
- kinfer/export/common.py +35 -0
- kinfer/export/jax.py +51 -0
- kinfer/export/pytorch.py +42 -110
- kinfer/export/serialize.py +86 -0
- kinfer/requirements.txt +3 -4
- kinfer/rust/Cargo.toml +8 -6
- kinfer/rust/src/lib.rs +2 -11
- kinfer/rust/src/model.rs +271 -121
- kinfer/rust/src/runtime.rs +104 -0
- kinfer/rust_bindings/Cargo.toml +8 -1
- kinfer/rust_bindings/rust_bindings.pyi +35 -0
- kinfer/rust_bindings/src/lib.rs +310 -1
- kinfer/rust_bindings.cpython-311-darwin.so +0 -0
- kinfer/rust_bindings.pyi +29 -1
- kinfer-0.4.0.dist-info/METADATA +55 -0
- kinfer-0.4.0.dist-info/RECORD +26 -0
- {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info}/WHEEL +2 -1
- kinfer/inference/__init__.py +0 -2
- kinfer/inference/base.py +0 -64
- kinfer/inference/python.py +0 -66
- kinfer/proto/__init__.py +0 -40
- kinfer/proto/kinfer_pb2.py +0 -103
- kinfer/proto/kinfer_pb2.pyi +0 -1097
- kinfer/requirements-dev.txt +0 -8
- kinfer/rust/build.rs +0 -16
- kinfer/rust/src/kinfer_proto.rs +0 -14
- kinfer/rust/src/main.rs +0 -6
- kinfer/rust/src/onnx_serializer.rs +0 -804
- kinfer/rust/src/serializer.rs +0 -221
- kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
- kinfer/serialize/__init__.py +0 -60
- kinfer/serialize/base.py +0 -536
- kinfer/serialize/json.py +0 -399
- kinfer/serialize/numpy.py +0 -426
- kinfer/serialize/pytorch.py +0 -402
- kinfer/serialize/schema.py +0 -125
- kinfer/serialize/types.py +0 -17
- kinfer/serialize/utils.py +0 -177
- kinfer-0.3.3.dist-info/METADATA +0 -57
- kinfer-0.3.3.dist-info/RECORD +0 -40
- {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info/licenses}/LICENSE +0 -0
- {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info}/top_level.txt +0 -0
kinfer/rust/src/model.rs
CHANGED
@@ -1,153 +1,303 @@
|
|
1
|
-
use
|
2
|
-
use
|
1
|
+
use async_trait::async_trait;
|
2
|
+
use flate2::read::GzDecoder;
|
3
|
+
use futures_util::future;
|
4
|
+
use ndarray::{Array, IxDyn};
|
5
|
+
use ort::session::Session;
|
6
|
+
use ort::value::Value;
|
7
|
+
use serde::Deserialize;
|
8
|
+
use std::collections::HashMap;
|
9
|
+
use std::io::Read;
|
3
10
|
use std::path::Path;
|
11
|
+
use std::sync::Arc;
|
12
|
+
use tar::Archive;
|
13
|
+
use tokio::fs::File;
|
14
|
+
use tokio::io::AsyncReadExt;
|
4
15
|
|
5
|
-
|
6
|
-
|
7
|
-
|
16
|
+
#[derive(Debug, Deserialize)]
|
17
|
+
struct ModelMetadata {
|
18
|
+
joint_names: Vec<String>,
|
19
|
+
}
|
8
20
|
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
21
|
+
impl ModelMetadata {
|
22
|
+
fn model_validate_json(json: String) -> Result<Self, Box<dyn std::error::Error>> {
|
23
|
+
Ok(serde_json::from_str(&json)?)
|
24
|
+
}
|
25
|
+
}
|
14
26
|
|
15
|
-
|
27
|
+
#[derive(Debug, thiserror::Error)]
|
28
|
+
pub enum ModelError {
|
29
|
+
#[error("IO error: {0}")]
|
30
|
+
Io(#[from] std::io::Error),
|
31
|
+
#[error("Provider error: {0}")]
|
32
|
+
Provider(String),
|
16
33
|
}
|
17
34
|
|
18
|
-
|
35
|
+
#[async_trait]
|
36
|
+
pub trait ModelProvider: Send + Sync {
|
37
|
+
async fn get_joint_angles(
|
38
|
+
&self,
|
39
|
+
joint_names: &[String],
|
40
|
+
) -> Result<Array<f32, IxDyn>, ModelError>;
|
41
|
+
async fn get_joint_angular_velocities(
|
42
|
+
&self,
|
43
|
+
joint_names: &[String],
|
44
|
+
) -> Result<Array<f32, IxDyn>, ModelError>;
|
45
|
+
async fn get_projected_gravity(&self) -> Result<Array<f32, IxDyn>, ModelError>;
|
46
|
+
async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError>;
|
47
|
+
async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError>;
|
48
|
+
async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError>;
|
49
|
+
async fn take_action(
|
50
|
+
&self,
|
51
|
+
joint_names: Vec<String>,
|
52
|
+
action: Array<f32, IxDyn>,
|
53
|
+
) -> Result<(), ModelError>;
|
54
|
+
}
|
19
55
|
|
20
56
|
pub struct ModelRunner {
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
output_serializer: OnnxMultiSerializer,
|
57
|
+
init_session: Session,
|
58
|
+
step_session: Session,
|
59
|
+
metadata: ModelMetadata,
|
60
|
+
provider: Arc<dyn ModelProvider>,
|
26
61
|
}
|
27
62
|
|
28
63
|
impl ModelRunner {
|
29
|
-
pub fn new<P: AsRef<Path>>(
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
64
|
+
pub async fn new<P: AsRef<Path>>(
|
65
|
+
model_path: P,
|
66
|
+
input_provider: Arc<dyn ModelProvider>,
|
67
|
+
) -> Result<Self, Box<dyn std::error::Error>> {
|
68
|
+
let mut file = File::open(model_path).await?;
|
69
|
+
|
70
|
+
// Read entire file into memory
|
71
|
+
let mut buffer = Vec::new();
|
72
|
+
file.read_to_end(&mut buffer).await?;
|
73
|
+
|
74
|
+
// Decompress and read the tar archive from memory
|
75
|
+
let gz = GzDecoder::new(&buffer[..]);
|
76
|
+
let mut archive = Archive::new(gz);
|
77
|
+
|
78
|
+
// Extract and validate joint names
|
79
|
+
let mut metadata: Option<String> = None;
|
80
|
+
let mut init_fn: Option<Vec<u8>> = None;
|
81
|
+
let mut step_fn: Option<Vec<u8>> = None;
|
82
|
+
|
83
|
+
for entry in archive.entries()? {
|
84
|
+
let mut entry = entry?;
|
85
|
+
let path = entry.path()?;
|
86
|
+
let path_str = path.to_string_lossy();
|
87
|
+
|
88
|
+
match path_str.as_ref() {
|
89
|
+
"metadata.json" => {
|
90
|
+
let mut contents = String::new();
|
91
|
+
entry.read_to_string(&mut contents)?;
|
92
|
+
metadata = Some(contents);
|
50
93
|
}
|
94
|
+
"init_fn.onnx" => {
|
95
|
+
let size = entry.size() as usize;
|
96
|
+
let mut contents = vec![0u8; size];
|
97
|
+
entry.read_exact(&mut contents)?;
|
98
|
+
assert_eq!(contents.len(), entry.size() as usize);
|
99
|
+
init_fn = Some(contents);
|
100
|
+
}
|
101
|
+
"step_fn.onnx" => {
|
102
|
+
let size = entry.size() as usize;
|
103
|
+
let mut contents = vec![0u8; size];
|
104
|
+
entry.read_exact(&mut contents)?;
|
105
|
+
assert_eq!(contents.len(), entry.size() as usize);
|
106
|
+
step_fn = Some(contents);
|
107
|
+
}
|
108
|
+
_ => return Err("Unknown entry".into()),
|
51
109
|
}
|
52
110
|
}
|
53
111
|
|
54
|
-
|
112
|
+
// Reads the files.
|
113
|
+
let metadata = ModelMetadata::model_validate_json(
|
114
|
+
metadata.ok_or("metadata.json not found in archive")?,
|
115
|
+
)?;
|
116
|
+
let init_session = Session::builder()?
|
117
|
+
.commit_from_memory(&init_fn.ok_or("init_fn.onnx not found in archive")?)?;
|
118
|
+
let step_session = Session::builder()?
|
119
|
+
.commit_from_memory(&step_fn.ok_or("step_fn.onnx not found in archive")?)?;
|
55
120
|
|
56
|
-
//
|
57
|
-
|
58
|
-
.
|
59
|
-
|
60
|
-
|
61
|
-
.
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
.
|
121
|
+
// Validate init_fn has no inputs and one output
|
122
|
+
if !init_session.inputs.is_empty() {
|
123
|
+
return Err("init_fn should not have any inputs".into());
|
124
|
+
}
|
125
|
+
if init_session.outputs.len() != 1 {
|
126
|
+
return Err("init_fn should have exactly one output".into());
|
127
|
+
}
|
128
|
+
|
129
|
+
// Get carry shape from init_fn output
|
130
|
+
let carry_shape = init_session.outputs[0]
|
131
|
+
.output_type
|
132
|
+
.tensor_dimensions()
|
133
|
+
.ok_or("Missing tensor type")?
|
134
|
+
.to_vec();
|
67
135
|
|
68
|
-
//
|
69
|
-
|
70
|
-
let output_serializer = OnnxMultiSerializer::new(output_schema);
|
136
|
+
// Validate step_fn inputs and outputs
|
137
|
+
Self::validate_step_fn(&step_session, metadata.joint_names.len(), &carry_shape)?;
|
71
138
|
|
72
139
|
Ok(Self {
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
output_serializer,
|
140
|
+
init_session,
|
141
|
+
step_session,
|
142
|
+
metadata,
|
143
|
+
provider: input_provider,
|
78
144
|
})
|
79
145
|
}
|
80
146
|
|
81
|
-
|
82
|
-
|
83
|
-
|
147
|
+
fn validate_step_fn(
|
148
|
+
session: &Session,
|
149
|
+
num_joints: usize,
|
150
|
+
carry_shape: &[i64],
|
151
|
+
) -> Result<(), Box<dyn std::error::Error>> {
|
152
|
+
// Validate inputs
|
153
|
+
for input in &session.inputs {
|
154
|
+
let dims = input.input_type.tensor_dimensions().ok_or(format!(
|
155
|
+
"Input {} is not a tensor with known dimensions",
|
156
|
+
input.name
|
157
|
+
))?;
|
84
158
|
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
159
|
+
match input.name.as_str() {
|
160
|
+
"joint_angles" | "joint_angular_velocities" => {
|
161
|
+
if *dims != vec![num_joints as i64] {
|
162
|
+
return Err(format!(
|
163
|
+
"Expected shape [{num_joints}] for input `{}`, got {:?}",
|
164
|
+
input.name, dims
|
165
|
+
)
|
166
|
+
.into());
|
167
|
+
}
|
168
|
+
}
|
169
|
+
"projected_gravity" | "accelerometer" | "gyroscope" => {
|
170
|
+
if *dims != vec![3] {
|
171
|
+
return Err(format!(
|
172
|
+
"Expected shape [3] for input `{}`, got {:?}",
|
173
|
+
input.name, dims
|
174
|
+
)
|
175
|
+
.into());
|
176
|
+
}
|
177
|
+
}
|
178
|
+
"carry" => {
|
179
|
+
if dims != carry_shape {
|
180
|
+
return Err(format!(
|
181
|
+
"Expected shape {:?} for input `carry`, got {:?}",
|
182
|
+
carry_shape, dims
|
183
|
+
)
|
184
|
+
.into());
|
185
|
+
}
|
186
|
+
}
|
187
|
+
_ => return Err(format!("Unknown input name: {}", input.name).into()),
|
188
|
+
}
|
189
|
+
}
|
190
|
+
|
191
|
+
// Validate outputs
|
192
|
+
if session.outputs.len() != 2 {
|
193
|
+
return Err("Step function must have exactly 2 outputs".into());
|
194
|
+
}
|
195
|
+
|
196
|
+
let output_shape = session.outputs[0]
|
197
|
+
.output_type
|
198
|
+
.tensor_dimensions()
|
199
|
+
.ok_or("Missing tensor type")?;
|
200
|
+
if *output_shape != vec![num_joints as i64] {
|
201
|
+
return Err(format!(
|
202
|
+
"Expected output shape [{num_joints}], got {:?}",
|
203
|
+
output_shape
|
204
|
+
)
|
205
|
+
.into());
|
206
|
+
}
|
207
|
+
|
208
|
+
let infered_carry_shape = session.outputs[1]
|
209
|
+
.output_type
|
210
|
+
.tensor_dimensions()
|
211
|
+
.ok_or("Missing tensor type")?;
|
212
|
+
if *infered_carry_shape != *carry_shape {
|
213
|
+
return Err(format!(
|
214
|
+
"Expected carry shape {:?}, got {:?}",
|
215
|
+
carry_shape, infered_carry_shape
|
216
|
+
)
|
217
|
+
.into());
|
218
|
+
}
|
114
219
|
|
115
|
-
pub fn export_model<P: AsRef<Path>>(&self, model_path: P) -> Result<(), Box<dyn std::error::Error>> {
|
116
|
-
let model_bytes = self.session.model_as_bytes()?;
|
117
|
-
let mut model = ModelProto::decode(&mut model_bytes.as_slice())?;
|
118
|
-
model.set_metadata_props(self.schema.encode_to_vec())?;
|
119
|
-
std::fs::write(model_path, model.write_to_bytes()?)?;
|
120
220
|
Ok(())
|
121
221
|
}
|
122
222
|
|
123
|
-
pub fn
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
223
|
+
pub async fn init(&self) -> Result<Array<f32, IxDyn>, Box<dyn std::error::Error>> {
|
224
|
+
let input_values: Vec<(&str, Value)> = Vec::new();
|
225
|
+
let outputs = self.init_session.run(input_values)?;
|
226
|
+
let output_tensor = outputs[0].try_extract_tensor::<f32>()?;
|
227
|
+
Ok(output_tensor.view().to_owned())
|
228
|
+
}
|
229
|
+
|
230
|
+
pub async fn step(
|
231
|
+
&self,
|
232
|
+
carry: Array<f32, IxDyn>,
|
233
|
+
) -> Result<(Array<f32, IxDyn>, Array<f32, IxDyn>), Box<dyn std::error::Error>> {
|
234
|
+
// Gets the model input names.
|
235
|
+
let input_names: Vec<String> = self
|
236
|
+
.step_session
|
237
|
+
.inputs
|
238
|
+
.iter()
|
239
|
+
.map(|i| i.name.clone())
|
240
|
+
.collect();
|
241
|
+
|
242
|
+
// Calls the relevant getter methods in parallel.
|
243
|
+
let mut futures = Vec::new();
|
244
|
+
for name in &input_names {
|
245
|
+
match name.as_str() {
|
246
|
+
"joint_angles" => {
|
247
|
+
futures.push(self.provider.get_joint_angles(&self.metadata.joint_names))
|
248
|
+
}
|
249
|
+
"joint_angular_velocities" => futures.push(
|
250
|
+
self.provider
|
251
|
+
.get_joint_angular_velocities(&self.metadata.joint_names),
|
252
|
+
),
|
253
|
+
"projected_gravity" => futures.push(self.provider.get_projected_gravity()),
|
254
|
+
"accelerometer" => futures.push(self.provider.get_accelerometer()),
|
255
|
+
"gyroscope" => futures.push(self.provider.get_gyroscope()),
|
256
|
+
"carry" => futures.push(self.provider.get_carry(carry.clone())),
|
257
|
+
_ => return Err(format!("Unknown input name: {}", name).into()),
|
258
|
+
}
|
259
|
+
}
|
260
|
+
|
261
|
+
let results = future::try_join_all(futures).await?;
|
262
|
+
let mut inputs = HashMap::new();
|
263
|
+
for (name, value) in input_names.iter().zip(results) {
|
264
|
+
inputs.insert(name.clone(), value);
|
265
|
+
}
|
266
|
+
|
267
|
+
// Convert inputs to ONNX values
|
268
|
+
let mut input_values: Vec<(&str, Value)> = Vec::new();
|
269
|
+
for input in &self.step_session.inputs {
|
270
|
+
let input_data = inputs
|
271
|
+
.get(&input.name)
|
272
|
+
.ok_or_else(|| format!("Missing input: {}", input.name))?;
|
273
|
+
let input_value = Value::from_array(input_data.view())?.into_dyn();
|
274
|
+
input_values.push((input.name.as_str(), input_value));
|
275
|
+
}
|
276
|
+
|
277
|
+
// Run the model
|
278
|
+
let outputs = self.step_session.run(input_values)?;
|
279
|
+
let output_tensor = outputs[0].try_extract_tensor::<f32>()?;
|
280
|
+
let carry_tensor = outputs[1].try_extract_tensor::<f32>()?;
|
281
|
+
|
282
|
+
Ok((
|
283
|
+
output_tensor.view().to_owned(),
|
284
|
+
carry_tensor.view().to_owned(),
|
285
|
+
))
|
134
286
|
}
|
135
287
|
|
136
|
-
pub fn
|
137
|
-
self
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
)) as Box<dyn std::error::Error>
|
145
|
-
})
|
146
|
-
.map(|schema| schema.clone())
|
288
|
+
pub async fn take_action(
|
289
|
+
&self,
|
290
|
+
action: Array<f32, IxDyn>,
|
291
|
+
) -> Result<(), Box<dyn std::error::Error>> {
|
292
|
+
self.provider
|
293
|
+
.take_action(self.metadata.joint_names.clone(), action)
|
294
|
+
.await?;
|
295
|
+
Ok(())
|
147
296
|
}
|
148
|
-
}
|
149
297
|
|
150
|
-
fn
|
151
|
-
|
152
|
-
|
298
|
+
pub async fn get_joint_angles(&self) -> Result<Array<f32, IxDyn>, Box<dyn std::error::Error>> {
|
299
|
+
let joint_names = &self.metadata.joint_names;
|
300
|
+
let joint_angles = self.provider.get_joint_angles(joint_names).await?;
|
301
|
+
Ok(joint_angles)
|
302
|
+
}
|
153
303
|
}
|
@@ -0,0 +1,104 @@
|
|
1
|
+
use std::sync::atomic::{AtomicBool, Ordering};
|
2
|
+
use std::sync::Arc;
|
3
|
+
use tokio::runtime::Runtime;
|
4
|
+
|
5
|
+
use crate::model::{ModelError, ModelRunner};
|
6
|
+
use std::time::{Duration, Instant};
|
7
|
+
use tokio::time::sleep;
|
8
|
+
|
9
|
+
pub struct ModelRuntime {
|
10
|
+
model_runner: Arc<ModelRunner>,
|
11
|
+
dt: Duration,
|
12
|
+
slowdown_factor: i32,
|
13
|
+
magnitude_factor: f32,
|
14
|
+
running: Arc<AtomicBool>,
|
15
|
+
runtime: Option<Runtime>,
|
16
|
+
}
|
17
|
+
|
18
|
+
impl ModelRuntime {
|
19
|
+
pub fn new(model_runner: Arc<ModelRunner>, dt: u64) -> Self {
|
20
|
+
Self {
|
21
|
+
model_runner,
|
22
|
+
dt: Duration::from_millis(dt),
|
23
|
+
slowdown_factor: 1,
|
24
|
+
magnitude_factor: 1.0,
|
25
|
+
running: Arc::new(AtomicBool::new(false)),
|
26
|
+
runtime: None,
|
27
|
+
}
|
28
|
+
}
|
29
|
+
|
30
|
+
pub fn set_slowdown_factor(&mut self, slowdown_factor: i32) {
|
31
|
+
assert!(slowdown_factor >= 1);
|
32
|
+
self.slowdown_factor = slowdown_factor;
|
33
|
+
}
|
34
|
+
|
35
|
+
pub fn set_magnitude_factor(&mut self, magnitude_factor: f32) {
|
36
|
+
assert!(magnitude_factor >= 0.0);
|
37
|
+
assert!(magnitude_factor <= 1.0);
|
38
|
+
self.magnitude_factor = magnitude_factor;
|
39
|
+
}
|
40
|
+
|
41
|
+
pub fn start(&mut self) -> Result<(), ModelError> {
|
42
|
+
if self.running.load(Ordering::Relaxed) {
|
43
|
+
return Ok(());
|
44
|
+
}
|
45
|
+
|
46
|
+
let running = self.running.clone();
|
47
|
+
let model_runner = self.model_runner.clone();
|
48
|
+
let dt = self.dt;
|
49
|
+
let slowdown_factor = self.slowdown_factor;
|
50
|
+
let magnitude_factor = self.magnitude_factor;
|
51
|
+
|
52
|
+
let runtime = Runtime::new()?;
|
53
|
+
running.store(true, Ordering::Relaxed);
|
54
|
+
|
55
|
+
runtime.spawn(async move {
|
56
|
+
let mut carry = model_runner
|
57
|
+
.init()
|
58
|
+
.await
|
59
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
60
|
+
let mut joint_positions = model_runner
|
61
|
+
.get_joint_angles()
|
62
|
+
.await
|
63
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
64
|
+
let mut last_time = Instant::now();
|
65
|
+
|
66
|
+
while running.load(Ordering::Relaxed) {
|
67
|
+
let (output, next_carry) = model_runner
|
68
|
+
.step(carry)
|
69
|
+
.await
|
70
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
71
|
+
carry = next_carry;
|
72
|
+
|
73
|
+
for i in 1..(slowdown_factor + 1) {
|
74
|
+
if !running.load(Ordering::Relaxed) {
|
75
|
+
break;
|
76
|
+
}
|
77
|
+
let t = i as f32 / slowdown_factor as f32;
|
78
|
+
let interp_joint_positions = &joint_positions * (1.0 - t) + &output * t;
|
79
|
+
model_runner
|
80
|
+
.take_action(interp_joint_positions * magnitude_factor)
|
81
|
+
.await
|
82
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
83
|
+
last_time = last_time + dt;
|
84
|
+
if let Some(sleep_duration) = last_time.checked_duration_since(Instant::now()) {
|
85
|
+
sleep(sleep_duration).await;
|
86
|
+
}
|
87
|
+
}
|
88
|
+
|
89
|
+
joint_positions = output;
|
90
|
+
}
|
91
|
+
Ok::<(), ModelError>(())
|
92
|
+
});
|
93
|
+
|
94
|
+
self.runtime = Some(runtime);
|
95
|
+
Ok(())
|
96
|
+
}
|
97
|
+
|
98
|
+
pub fn stop(&mut self) {
|
99
|
+
self.running.store(false, Ordering::Relaxed);
|
100
|
+
if let Some(runtime) = self.runtime.take() {
|
101
|
+
runtime.shutdown_background();
|
102
|
+
}
|
103
|
+
}
|
104
|
+
}
|
kinfer/rust_bindings/Cargo.toml
CHANGED
@@ -15,5 +15,12 @@ name = "rust_bindings"
|
|
15
15
|
crate-type = ["cdylib", "rlib"]
|
16
16
|
|
17
17
|
[dependencies]
|
18
|
-
|
18
|
+
|
19
|
+
async-trait = "0.1"
|
20
|
+
kinfer = { path = "../rust" }
|
21
|
+
ndarray = "0.16.1"
|
22
|
+
numpy = ">= 0.24.0"
|
23
|
+
pyo3 = { version = ">= 0.24.0", features = ["extension-module"] }
|
24
|
+
pyo3-async-runtimes = { version = ">= 0.24.0", features = ["attributes", "tokio-runtime"] }
|
19
25
|
pyo3-stub-gen = ">= 0.6.0"
|
26
|
+
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
|
@@ -0,0 +1,35 @@
|
|
1
|
+
# This file is automatically generated by pyo3_stub_gen
|
2
|
+
# ruff: noqa: E501, F401
|
3
|
+
|
4
|
+
import builtins
|
5
|
+
import numpy
|
6
|
+
import numpy.typing
|
7
|
+
import typing
|
8
|
+
|
9
|
+
class ModelProviderABC:
|
10
|
+
def __new__(cls) -> ModelProviderABC: ...
|
11
|
+
def get_joint_angles(self, joint_names:typing.Sequence[builtins.str]) -> numpy.typing.NDArray[numpy.float32]: ...
|
12
|
+
def get_joint_angular_velocities(self, joint_names:typing.Sequence[builtins.str]) -> numpy.typing.NDArray[numpy.float32]: ...
|
13
|
+
def get_projected_gravity(self) -> numpy.typing.NDArray[numpy.float32]: ...
|
14
|
+
def get_accelerometer(self) -> numpy.typing.NDArray[numpy.float32]: ...
|
15
|
+
def get_gyroscope(self) -> numpy.typing.NDArray[numpy.float32]: ...
|
16
|
+
def take_action(self, joint_names:typing.Sequence[builtins.str], action:numpy.typing.NDArray[numpy.float32]) -> None: ...
|
17
|
+
|
18
|
+
class PyModelProvider:
|
19
|
+
...
|
20
|
+
|
21
|
+
class PyModelRunner:
|
22
|
+
def __new__(cls, model_path:builtins.str, provider:ModelProviderABC) -> PyModelRunner: ...
|
23
|
+
def init(self) -> numpy.typing.NDArray[numpy.float32]: ...
|
24
|
+
def step(self, carry:numpy.typing.NDArray[numpy.float32]) -> tuple[numpy.typing.NDArray[numpy.float32], numpy.typing.NDArray[numpy.float32]]: ...
|
25
|
+
def take_action(self, action:numpy.typing.NDArray[numpy.float32]) -> None: ...
|
26
|
+
|
27
|
+
class PyModelRuntime:
|
28
|
+
def __new__(cls, model_runner:PyModelRunner, dt:builtins.int) -> PyModelRuntime: ...
|
29
|
+
def set_slowdown_factor(self, slowdown_factor:builtins.int) -> None: ...
|
30
|
+
def set_magnitude_factor(self, magnitude_factor:builtins.float) -> None: ...
|
31
|
+
def start(self) -> None: ...
|
32
|
+
def stop(self) -> None: ...
|
33
|
+
|
34
|
+
def get_version() -> builtins.str: ...
|
35
|
+
|