kinfer-0.3.3-cp312-cp312-macosx_11_0_arm64.whl → kinfer-0.4.1-cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. kinfer/__init__.py +0 -5
  2. kinfer/common/__init__.py +0 -0
  3. kinfer/common/types.py +12 -0
  4. kinfer/export/common.py +41 -0
  5. kinfer/export/jax.py +53 -0
  6. kinfer/export/pytorch.py +45 -110
  7. kinfer/export/serialize.py +93 -0
  8. kinfer/requirements.txt +3 -4
  9. kinfer/rust/Cargo.toml +20 -8
  10. kinfer/rust/src/lib.rs +2 -11
  11. kinfer/rust/src/model.rs +286 -121
  12. kinfer/rust/src/runtime.rs +104 -0
  13. kinfer/rust_bindings/Cargo.toml +8 -1
  14. kinfer/rust_bindings/rust_bindings.pyi +36 -0
  15. kinfer/rust_bindings/src/lib.rs +326 -1
  16. kinfer/rust_bindings.cpython-312-darwin.so +0 -0
  17. kinfer/rust_bindings.pyi +30 -1
  18. kinfer-0.4.1.dist-info/METADATA +55 -0
  19. kinfer-0.4.1.dist-info/RECORD +26 -0
  20. {kinfer-0.3.3.dist-info → kinfer-0.4.1.dist-info}/WHEEL +2 -1
  21. kinfer/inference/__init__.py +0 -2
  22. kinfer/inference/base.py +0 -64
  23. kinfer/inference/python.py +0 -66
  24. kinfer/proto/__init__.py +0 -40
  25. kinfer/proto/kinfer_pb2.py +0 -103
  26. kinfer/proto/kinfer_pb2.pyi +0 -1097
  27. kinfer/requirements-dev.txt +0 -8
  28. kinfer/rust/build.rs +0 -16
  29. kinfer/rust/src/kinfer_proto.rs +0 -14
  30. kinfer/rust/src/main.rs +0 -6
  31. kinfer/rust/src/onnx_serializer.rs +0 -804
  32. kinfer/rust/src/serializer.rs +0 -221
  33. kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
  34. kinfer/serialize/__init__.py +0 -60
  35. kinfer/serialize/base.py +0 -536
  36. kinfer/serialize/json.py +0 -399
  37. kinfer/serialize/numpy.py +0 -426
  38. kinfer/serialize/pytorch.py +0 -402
  39. kinfer/serialize/schema.py +0 -125
  40. kinfer/serialize/types.py +0 -17
  41. kinfer/serialize/utils.py +0 -177
  42. kinfer-0.3.3.dist-info/METADATA +0 -57
  43. kinfer-0.3.3.dist-info/RECORD +0 -40
  44. {kinfer-0.3.3.dist-info → kinfer-0.4.1.dist-info/licenses}/LICENSE +0 -0
  45. {kinfer-0.3.3.dist-info → kinfer-0.4.1.dist-info}/top_level.txt +0 -0
kinfer/rust/src/model.rs CHANGED
@@ -1,153 +1,318 @@
1
- use crate::kinfer_proto::{ModelSchema, ProtoIO, ProtoIOSchema};
2
- use crate::onnx_serializer::OnnxMultiSerializer;
1
+ use async_trait::async_trait;
2
+ use flate2::read::GzDecoder;
3
+ use futures_util::future;
4
+ use ndarray::{Array, IxDyn};
5
+ use ort::session::Session;
6
+ use ort::value::Value;
7
+ use serde::Deserialize;
8
+ use std::collections::HashMap;
9
+ use std::io::Read;
3
10
  use std::path::Path;
11
+ use std::sync::Arc;
12
+ use tar::Archive;
13
+ use tokio::fs::File;
14
+ use tokio::io::AsyncReadExt;
4
15
 
5
- use ort::session::builder::GraphOptimizationLevel;
6
- use prost::Message;
7
- use ort::{session::Session, Error as OrtError};
16
+ #[derive(Debug, Deserialize)]
17
+ struct ModelMetadata {
18
+ joint_names: Vec<String>,
19
+ num_commands: Option<usize>,
20
+ }
8
21
 
9
- pub fn load_onnx_model<P: AsRef<Path>>(model_path: P) -> Result<Session, OrtError> {
10
- let model = Session::builder()?
11
- .with_optimization_level(GraphOptimizationLevel::Level3)?
12
- .with_intra_threads(4)?
13
- .commit_from_file(model_path)?;
22
+ impl ModelMetadata {
23
+ fn model_validate_json(json: String) -> Result<Self, Box<dyn std::error::Error>> {
24
+ Ok(serde_json::from_str(&json)?)
25
+ }
26
+ }
14
27
 
15
- Ok(model)
28
+ #[derive(Debug, thiserror::Error)]
29
+ pub enum ModelError {
30
+ #[error("IO error: {0}")]
31
+ Io(#[from] std::io::Error),
32
+ #[error("Provider error: {0}")]
33
+ Provider(String),
16
34
  }
17
35
 
18
- const KINFER_METADATA_KEY: &str = "kinfer_metadata";
36
+ #[async_trait]
37
+ pub trait ModelProvider: Send + Sync {
38
+ async fn get_joint_angles(
39
+ &self,
40
+ joint_names: &[String],
41
+ ) -> Result<Array<f32, IxDyn>, ModelError>;
42
+ async fn get_joint_angular_velocities(
43
+ &self,
44
+ joint_names: &[String],
45
+ ) -> Result<Array<f32, IxDyn>, ModelError>;
46
+ async fn get_projected_gravity(&self) -> Result<Array<f32, IxDyn>, ModelError>;
47
+ async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError>;
48
+ async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError>;
49
+ async fn get_command(&self) -> Result<Array<f32, IxDyn>, ModelError>;
50
+ async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError>;
51
+ async fn take_action(
52
+ &self,
53
+ joint_names: Vec<String>,
54
+ action: Array<f32, IxDyn>,
55
+ ) -> Result<(), ModelError>;
56
+ }
19
57
 
20
58
  pub struct ModelRunner {
21
- session: Session,
22
- attached_metadata: std::collections::HashMap<String, String>,
23
- schema: ModelSchema,
24
- input_serializer: OnnxMultiSerializer,
25
- output_serializer: OnnxMultiSerializer,
59
+ init_session: Session,
60
+ step_session: Session,
61
+ metadata: ModelMetadata,
62
+ provider: Arc<dyn ModelProvider>,
26
63
  }
27
64
 
28
65
  impl ModelRunner {
29
- pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self, Box<dyn std::error::Error>> {
30
- let session = load_onnx_model(model_path)?;
31
- let mut attached_metadata = std::collections::HashMap::new();
32
-
33
- // Extract metadata and attempt to parse schema
34
- let mut schema = None;
35
- {
36
- let metadata = session.metadata()?;
37
- for prop in metadata.custom_keys()? {
38
- if prop == KINFER_METADATA_KEY {
39
- let schema_bytes = metadata.custom(prop.as_str())?;
40
- if let Some(bytes) = schema_bytes {
41
- schema = Some(ModelSchema::decode(&mut bytes.as_bytes())?);
42
- }
43
- } else {
44
- attached_metadata.insert(
45
- prop.to_string(),
46
- metadata
47
- .custom(prop.as_str())?
48
- .map_or_else(String::new, |s| s.to_string()),
49
- );
66
+ pub async fn new<P: AsRef<Path>>(
67
+ model_path: P,
68
+ input_provider: Arc<dyn ModelProvider>,
69
+ ) -> Result<Self, Box<dyn std::error::Error>> {
70
+ let mut file = File::open(model_path).await?;
71
+
72
+ // Read entire file into memory
73
+ let mut buffer = Vec::new();
74
+ file.read_to_end(&mut buffer).await?;
75
+
76
+ // Decompress and read the tar archive from memory
77
+ let gz = GzDecoder::new(&buffer[..]);
78
+ let mut archive = Archive::new(gz);
79
+
80
+ // Extract and validate joint names
81
+ let mut metadata: Option<String> = None;
82
+ let mut init_fn: Option<Vec<u8>> = None;
83
+ let mut step_fn: Option<Vec<u8>> = None;
84
+
85
+ for entry in archive.entries()? {
86
+ let mut entry = entry?;
87
+ let path = entry.path()?;
88
+ let path_str = path.to_string_lossy();
89
+
90
+ match path_str.as_ref() {
91
+ "metadata.json" => {
92
+ let mut contents = String::new();
93
+ entry.read_to_string(&mut contents)?;
94
+ metadata = Some(contents);
50
95
  }
96
+ "init_fn.onnx" => {
97
+ let size = entry.size() as usize;
98
+ let mut contents = vec![0u8; size];
99
+ entry.read_exact(&mut contents)?;
100
+ assert_eq!(contents.len(), entry.size() as usize);
101
+ init_fn = Some(contents);
102
+ }
103
+ "step_fn.onnx" => {
104
+ let size = entry.size() as usize;
105
+ let mut contents = vec![0u8; size];
106
+ entry.read_exact(&mut contents)?;
107
+ assert_eq!(contents.len(), entry.size() as usize);
108
+ step_fn = Some(contents);
109
+ }
110
+ _ => return Err("Unknown entry".into()),
51
111
  }
52
112
  }
53
113
 
54
- let schema: ModelSchema = schema.ok_or_else(|| "kinfer_metadata not found in model metadata")?;
114
+ // Reads the files.
115
+ let metadata = ModelMetadata::model_validate_json(
116
+ metadata.ok_or("metadata.json not found in archive")?,
117
+ )?;
118
+ let init_session = Session::builder()?
119
+ .commit_from_memory(&init_fn.ok_or("init_fn.onnx not found in archive")?)?;
120
+ let step_session = Session::builder()?
121
+ .commit_from_memory(&step_fn.ok_or("step_fn.onnx not found in archive")?)?;
122
+
123
+ // Validate init_fn has no inputs and one output
124
+ if !init_session.inputs.is_empty() {
125
+ return Err("init_fn should not have any inputs".into());
126
+ }
127
+ if init_session.outputs.len() != 1 {
128
+ return Err("init_fn should have exactly one output".into());
129
+ }
55
130
 
56
- // Use as_ref() to borrow the Option contents and clone after ok_or
57
- let input_schema = schema
58
- .input_schema
59
- .as_ref()
60
- .ok_or("Missing input schema")?
61
- .clone();
62
- let output_schema = schema
63
- .output_schema
64
- .as_ref()
65
- .ok_or("Missing output schema")?
66
- .clone();
131
+ // Get carry shape from init_fn output
132
+ let carry_shape = init_session.outputs[0]
133
+ .output_type
134
+ .tensor_dimensions()
135
+ .ok_or("Missing tensor type")?
136
+ .to_vec();
67
137
 
68
- // Create serializers for input and output
69
- let input_serializer = OnnxMultiSerializer::new(input_schema);
70
- let output_serializer = OnnxMultiSerializer::new(output_schema);
138
+ // Validate step_fn inputs and outputs
139
+ Self::validate_step_fn(&step_session, &metadata, &carry_shape)?;
71
140
 
72
141
  Ok(Self {
73
- session,
74
- attached_metadata,
75
- schema,
76
- input_serializer,
77
- output_serializer,
142
+ init_session,
143
+ step_session,
144
+ metadata,
145
+ provider: input_provider,
78
146
  })
79
147
  }
80
148
 
81
- pub fn run(&self, inputs: ProtoIO) -> Result<ProtoIO, Box<dyn std::error::Error>> {
82
- // Serialize inputs to ONNX format
83
- let inputs = self.input_serializer.serialize_io(inputs)?;
149
+ fn validate_step_fn(
150
+ session: &Session,
151
+ metadata: &ModelMetadata,
152
+ carry_shape: &[i64],
153
+ ) -> Result<(), Box<dyn std::error::Error>> {
154
+ // Validate inputs
155
+ for input in &session.inputs {
156
+ let dims = input.input_type.tensor_dimensions().ok_or(format!(
157
+ "Input {} is not a tensor with known dimensions",
158
+ input.name
159
+ ))?;
84
160
 
85
- // Get input names from the session
86
- let input_names = self
87
- .session
88
- .inputs
89
- .iter()
90
- .map(|input| input.name.as_str())
91
- .collect::<Vec<_>>();
92
-
93
- // Create input name-value pairs
94
- let input_values = vec![(input_names[0], inputs)];
95
-
96
- let outputs = self.session.run(input_values)?;
97
-
98
- let output_values = outputs
99
- .values()
100
- .map(|v: ort::value::ValueRef<'_>| {
101
- v.try_upgrade().map_err(|e| {
102
- Box::new(std::io::Error::new(
103
- std::io::ErrorKind::Other,
104
- format!("Failed to upgrade value"),
105
- )) as Box<dyn std::error::Error>
106
- })
107
- })
108
- .collect::<Result<Vec<_>, _>>()?;
109
- // Deserialize outputs from ONNX format
110
- let outputs = self.output_serializer.deserialize_io(output_values)?;
111
-
112
- Ok(outputs)
113
- }
161
+ match input.name.as_str() {
162
+ "joint_angles" | "joint_angular_velocities" => {
163
+ let num_joints = metadata.joint_names.len();
164
+ if *dims != vec![num_joints as i64] {
165
+ return Err(format!(
166
+ "Expected shape [{num_joints}] for input `{}`, got {:?}",
167
+ input.name, dims
168
+ )
169
+ .into());
170
+ }
171
+ }
172
+ "projected_gravity" | "accelerometer" | "gyroscope" => {
173
+ if *dims != vec![3] {
174
+ return Err(format!(
175
+ "Expected shape [3] for input `{}`, got {:?}",
176
+ input.name, dims
177
+ )
178
+ .into());
179
+ }
180
+ }
181
+ "command" => {
182
+ let num_commands = metadata.num_commands.ok_or("num_commands is not set")?;
183
+ if *dims != vec![num_commands as i64] {
184
+ return Err(format!(
185
+ "Expected shape [{num_commands}] for input `{}`, got {:?}",
186
+ input.name, dims
187
+ )
188
+ .into());
189
+ }
190
+ }
191
+ "carry" => {
192
+ if dims != carry_shape {
193
+ return Err(format!(
194
+ "Expected shape {:?} for input `carry`, got {:?}",
195
+ carry_shape, dims
196
+ )
197
+ .into());
198
+ }
199
+ }
200
+ _ => return Err(format!("Unknown input name: {}", input.name).into()),
201
+ }
202
+ }
203
+
204
+ // Validate outputs
205
+ if session.outputs.len() != 2 {
206
+ return Err("Step function must have exactly 2 outputs".into());
207
+ }
208
+
209
+ let output_shape = session.outputs[0]
210
+ .output_type
211
+ .tensor_dimensions()
212
+ .ok_or("Missing tensor type")?;
213
+ let num_joints = metadata.joint_names.len();
214
+ if *output_shape != vec![num_joints as i64] {
215
+ return Err(format!(
216
+ "Expected output shape [{num_joints}], got {:?}",
217
+ output_shape
218
+ )
219
+ .into());
220
+ }
221
+
222
+ let infered_carry_shape = session.outputs[1]
223
+ .output_type
224
+ .tensor_dimensions()
225
+ .ok_or("Missing tensor type")?;
226
+ if *infered_carry_shape != *carry_shape {
227
+ return Err(format!(
228
+ "Expected carry shape {:?}, got {:?}",
229
+ carry_shape, infered_carry_shape
230
+ )
231
+ .into());
232
+ }
114
233
 
115
- pub fn export_model<P: AsRef<Path>>(&self, model_path: P) -> Result<(), Box<dyn std::error::Error>> {
116
- let model_bytes = self.session.model_as_bytes()?;
117
- let mut model = ModelProto::decode(&mut model_bytes.as_slice())?;
118
- model.set_metadata_props(self.schema.encode_to_vec())?;
119
- std::fs::write(model_path, model.write_to_bytes()?)?;
120
234
  Ok(())
121
235
  }
122
236
 
123
- pub fn input_schema(&self) -> Result<ProtoIOSchema, Box<dyn std::error::Error>> {
124
- self.schema
125
- .input_schema
126
- .as_ref()
127
- .ok_or_else(|| {
128
- Box::new(std::io::Error::new(
129
- std::io::ErrorKind::NotFound,
130
- "Missing input schema",
131
- )) as Box<dyn std::error::Error>
132
- })
133
- .map(|schema| schema.clone())
237
+ pub async fn init(&self) -> Result<Array<f32, IxDyn>, Box<dyn std::error::Error>> {
238
+ let input_values: Vec<(&str, Value)> = Vec::new();
239
+ let outputs = self.init_session.run(input_values)?;
240
+ let output_tensor = outputs[0].try_extract_tensor::<f32>()?;
241
+ Ok(output_tensor.view().to_owned())
134
242
  }
135
243
 
136
- pub fn output_schema(&self) -> Result<ProtoIOSchema, Box<dyn std::error::Error>> {
137
- self.schema
138
- .output_schema
139
- .as_ref()
140
- .ok_or_else(|| {
141
- Box::new(std::io::Error::new(
142
- std::io::ErrorKind::NotFound,
143
- "Missing output schema",
144
- )) as Box<dyn std::error::Error>
145
- })
146
- .map(|schema| schema.clone())
244
+ pub async fn step(
245
+ &self,
246
+ carry: Array<f32, IxDyn>,
247
+ ) -> Result<(Array<f32, IxDyn>, Array<f32, IxDyn>), Box<dyn std::error::Error>> {
248
+ // Gets the model input names.
249
+ let input_names: Vec<String> = self
250
+ .step_session
251
+ .inputs
252
+ .iter()
253
+ .map(|i| i.name.clone())
254
+ .collect();
255
+
256
+ // Calls the relevant getter methods in parallel.
257
+ let mut futures = Vec::new();
258
+ for name in &input_names {
259
+ match name.as_str() {
260
+ "joint_angles" => {
261
+ futures.push(self.provider.get_joint_angles(&self.metadata.joint_names))
262
+ }
263
+ "joint_angular_velocities" => futures.push(
264
+ self.provider
265
+ .get_joint_angular_velocities(&self.metadata.joint_names),
266
+ ),
267
+ "projected_gravity" => futures.push(self.provider.get_projected_gravity()),
268
+ "accelerometer" => futures.push(self.provider.get_accelerometer()),
269
+ "gyroscope" => futures.push(self.provider.get_gyroscope()),
270
+ "command" => futures.push(self.provider.get_command()),
271
+ "carry" => futures.push(self.provider.get_carry(carry.clone())),
272
+ _ => return Err(format!("Unknown input name: {}", name).into()),
273
+ }
274
+ }
275
+
276
+ let results = future::try_join_all(futures).await?;
277
+ let mut inputs = HashMap::new();
278
+ for (name, value) in input_names.iter().zip(results) {
279
+ inputs.insert(name.clone(), value);
280
+ }
281
+
282
+ // Convert inputs to ONNX values
283
+ let mut input_values: Vec<(&str, Value)> = Vec::new();
284
+ for input in &self.step_session.inputs {
285
+ let input_data = inputs
286
+ .get(&input.name)
287
+ .ok_or_else(|| format!("Missing input: {}", input.name))?;
288
+ let input_value = Value::from_array(input_data.view())?.into_dyn();
289
+ input_values.push((input.name.as_str(), input_value));
290
+ }
291
+
292
+ // Run the model
293
+ let outputs = self.step_session.run(input_values)?;
294
+ let output_tensor = outputs[0].try_extract_tensor::<f32>()?;
295
+ let carry_tensor = outputs[1].try_extract_tensor::<f32>()?;
296
+
297
+ Ok((
298
+ output_tensor.view().to_owned(),
299
+ carry_tensor.view().to_owned(),
300
+ ))
301
+ }
302
+
303
+ pub async fn take_action(
304
+ &self,
305
+ action: Array<f32, IxDyn>,
306
+ ) -> Result<(), Box<dyn std::error::Error>> {
307
+ self.provider
308
+ .take_action(self.metadata.joint_names.clone(), action)
309
+ .await?;
310
+ Ok(())
147
311
  }
148
- }
149
312
 
150
- fn main() -> Result<(), Box<dyn std::error::Error>> {
151
- println!("Hello, world!");
152
- Ok(())
313
+ pub async fn get_joint_angles(&self) -> Result<Array<f32, IxDyn>, Box<dyn std::error::Error>> {
314
+ let joint_names = &self.metadata.joint_names;
315
+ let joint_angles = self.provider.get_joint_angles(joint_names).await?;
316
+ Ok(joint_angles)
317
+ }
153
318
  }
kinfer/rust/src/runtime.rs ADDED
@@ -0,0 +1,104 @@
1
+ use std::sync::atomic::{AtomicBool, Ordering};
2
+ use std::sync::Arc;
3
+ use tokio::runtime::Runtime;
4
+
5
+ use crate::model::{ModelError, ModelRunner};
6
+ use std::time::{Duration, Instant};
7
+ use tokio::time::sleep;
8
+
9
+ pub struct ModelRuntime {
10
+ model_runner: Arc<ModelRunner>,
11
+ dt: Duration,
12
+ slowdown_factor: i32,
13
+ magnitude_factor: f32,
14
+ running: Arc<AtomicBool>,
15
+ runtime: Option<Runtime>,
16
+ }
17
+
18
+ impl ModelRuntime {
19
+ pub fn new(model_runner: Arc<ModelRunner>, dt: u64) -> Self {
20
+ Self {
21
+ model_runner,
22
+ dt: Duration::from_millis(dt),
23
+ slowdown_factor: 1,
24
+ magnitude_factor: 1.0,
25
+ running: Arc::new(AtomicBool::new(false)),
26
+ runtime: None,
27
+ }
28
+ }
29
+
30
+ pub fn set_slowdown_factor(&mut self, slowdown_factor: i32) {
31
+ assert!(slowdown_factor >= 1);
32
+ self.slowdown_factor = slowdown_factor;
33
+ }
34
+
35
+ pub fn set_magnitude_factor(&mut self, magnitude_factor: f32) {
36
+ assert!(magnitude_factor >= 0.0);
37
+ assert!(magnitude_factor <= 1.0);
38
+ self.magnitude_factor = magnitude_factor;
39
+ }
40
+
41
+ pub fn start(&mut self) -> Result<(), ModelError> {
42
+ if self.running.load(Ordering::Relaxed) {
43
+ return Ok(());
44
+ }
45
+
46
+ let running = self.running.clone();
47
+ let model_runner = self.model_runner.clone();
48
+ let dt = self.dt;
49
+ let slowdown_factor = self.slowdown_factor;
50
+ let magnitude_factor = self.magnitude_factor;
51
+
52
+ let runtime = Runtime::new()?;
53
+ running.store(true, Ordering::Relaxed);
54
+
55
+ runtime.spawn(async move {
56
+ let mut carry = model_runner
57
+ .init()
58
+ .await
59
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
60
+ let mut joint_positions = model_runner
61
+ .get_joint_angles()
62
+ .await
63
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
64
+ let mut last_time = Instant::now();
65
+
66
+ while running.load(Ordering::Relaxed) {
67
+ let (output, next_carry) = model_runner
68
+ .step(carry)
69
+ .await
70
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
71
+ carry = next_carry;
72
+
73
+ for i in 1..(slowdown_factor + 1) {
74
+ if !running.load(Ordering::Relaxed) {
75
+ break;
76
+ }
77
+ let t = i as f32 / slowdown_factor as f32;
78
+ let interp_joint_positions = &joint_positions * (1.0 - t) + &output * t;
79
+ model_runner
80
+ .take_action(interp_joint_positions * magnitude_factor)
81
+ .await
82
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
83
+ last_time = last_time + dt;
84
+ if let Some(sleep_duration) = last_time.checked_duration_since(Instant::now()) {
85
+ sleep(sleep_duration).await;
86
+ }
87
+ }
88
+
89
+ joint_positions = output;
90
+ }
91
+ Ok::<(), ModelError>(())
92
+ });
93
+
94
+ self.runtime = Some(runtime);
95
+ Ok(())
96
+ }
97
+
98
+ pub fn stop(&mut self) {
99
+ self.running.store(false, Ordering::Relaxed);
100
+ if let Some(runtime) = self.runtime.take() {
101
+ runtime.shutdown_background();
102
+ }
103
+ }
104
+ }
kinfer/rust_bindings/Cargo.toml CHANGED
@@ -15,5 +15,12 @@ name = "rust_bindings"
15
15
  crate-type = ["cdylib", "rlib"]
16
16
 
17
17
  [dependencies]
18
- pyo3 = { version = ">= 0.21.0", features = ["extension-module"] }
18
+
19
+ async-trait = "0.1"
20
+ kinfer = { path = "../rust" }
21
+ ndarray = "0.16.1"
22
+ numpy = ">= 0.24.0"
23
+ pyo3 = { version = ">= 0.24.0", features = ["extension-module"] }
24
+ pyo3-async-runtimes = { version = ">= 0.24.0", features = ["attributes", "tokio-runtime"] }
19
25
  pyo3-stub-gen = ">= 0.6.0"
26
+ tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
kinfer/rust_bindings/rust_bindings.pyi ADDED
@@ -0,0 +1,36 @@
# This file is automatically generated by pyo3_stub_gen
# ruff: noqa: E501, F401
# NOTE(review): pyo3_stub_gen regenerates this file, so hand-written comments here will
# presumably be overwritten; the durable home for these docs is likely the binding source
# (kinfer/rust_bindings/src/lib.rs) — confirm.

import builtins
import numpy
import numpy.typing
import typing

# Python-side provider of model inputs and receiver of actions. Its methods mirror the
# `ModelProvider` trait in kinfer/rust/src/model.rs; presumably it is meant to be
# subclassed with every method overridden — confirm against the binding source.
# NOTE(review): the Rust trait also declares `get_carry`, which has no counterpart here.
class ModelProviderABC:
    def __new__(cls) -> ModelProviderABC: ...
    # Angles for the given joints (the Rust runner passes the model metadata's joint names).
    def get_joint_angles(self, joint_names:typing.Sequence[builtins.str]) -> numpy.typing.NDArray[numpy.float32]: ...
    # Angular velocities for the given joints.
    def get_joint_angular_velocities(self, joint_names:typing.Sequence[builtins.str]) -> numpy.typing.NDArray[numpy.float32]: ...
    # Value for the model's `projected_gravity` input (the Rust loader requires shape [3]).
    def get_projected_gravity(self) -> numpy.typing.NDArray[numpy.float32]: ...
    # Value for the `accelerometer` input (shape [3]).
    def get_accelerometer(self) -> numpy.typing.NDArray[numpy.float32]: ...
    # Value for the `gyroscope` input (shape [3]).
    def get_gyroscope(self) -> numpy.typing.NDArray[numpy.float32]: ...
    # Value for the `command` input (length `num_commands` from the model metadata).
    def get_command(self) -> numpy.typing.NDArray[numpy.float32]: ...
    # Receives an action for the given joints.
    def take_action(self, joint_names:typing.Sequence[builtins.str], action:numpy.typing.NDArray[numpy.float32]) -> None: ...

# Presumably the Rust-side adapter that forwards `ModelProvider` calls to a
# ModelProviderABC instance — confirm against the binding source.
class PyModelProvider:
    ...

# Python handle to the Rust `ModelRunner` (presumably a thin wrapper — confirm).
class PyModelRunner:
    # Loads the kinfer archive at `model_path`; inputs are requested from `provider`.
    def __new__(cls, model_path:builtins.str, provider:ModelProviderABC) -> PyModelRunner: ...
    # Runs the model's init function and returns the initial carry.
    def init(self) -> numpy.typing.NDArray[numpy.float32]: ...
    # Runs one step; returns (action, next carry).
    def step(self, carry:numpy.typing.NDArray[numpy.float32]) -> tuple[numpy.typing.NDArray[numpy.float32], numpy.typing.NDArray[numpy.float32]]: ...
    # Sends `action` to the provider for the model's joints.
    def take_action(self, action:numpy.typing.NDArray[numpy.float32]) -> None: ...

# Python handle to the Rust `ModelRuntime` control loop (presumably a thin wrapper — confirm).
class PyModelRuntime:
    # `dt` is the action period; the Rust ModelRuntime interprets it as milliseconds.
    def __new__(cls, model_runner:PyModelRunner, dt:builtins.int) -> PyModelRuntime: ...
    # Interpolated actions per model step; the Rust side asserts the value is >= 1.
    def set_slowdown_factor(self, slowdown_factor:builtins.int) -> None: ...
    # Scale applied to every action; the Rust side asserts 0.0 <= value <= 1.0.
    def set_magnitude_factor(self, magnitude_factor:builtins.float) -> None: ...
    # Starts the background control loop (the Rust side does nothing if it is already running).
    def start(self) -> None: ...
    # Stops the background control loop.
    def stop(self) -> None: ...

# Presumably returns the package version string — confirm.
def get_version() -> builtins.str: ...