kinfer 0.3.3__cp312-cp312-macosx_11_0_arm64.whl → 0.4.0__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. kinfer/__init__.py +0 -5
  2. kinfer/common/__init__.py +0 -0
  3. kinfer/common/types.py +11 -0
  4. kinfer/export/common.py +35 -0
  5. kinfer/export/jax.py +51 -0
  6. kinfer/export/pytorch.py +42 -110
  7. kinfer/export/serialize.py +86 -0
  8. kinfer/requirements.txt +3 -4
  9. kinfer/rust/Cargo.toml +8 -6
  10. kinfer/rust/src/lib.rs +2 -11
  11. kinfer/rust/src/model.rs +271 -121
  12. kinfer/rust/src/runtime.rs +104 -0
  13. kinfer/rust_bindings/Cargo.toml +8 -1
  14. kinfer/rust_bindings/rust_bindings.pyi +35 -0
  15. kinfer/rust_bindings/src/lib.rs +310 -1
  16. kinfer/rust_bindings.cpython-312-darwin.so +0 -0
  17. kinfer/rust_bindings.pyi +29 -1
  18. kinfer-0.4.0.dist-info/METADATA +55 -0
  19. kinfer-0.4.0.dist-info/RECORD +26 -0
  20. {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info}/WHEEL +2 -1
  21. kinfer/inference/__init__.py +0 -2
  22. kinfer/inference/base.py +0 -64
  23. kinfer/inference/python.py +0 -66
  24. kinfer/proto/__init__.py +0 -40
  25. kinfer/proto/kinfer_pb2.py +0 -103
  26. kinfer/proto/kinfer_pb2.pyi +0 -1097
  27. kinfer/requirements-dev.txt +0 -8
  28. kinfer/rust/build.rs +0 -16
  29. kinfer/rust/src/kinfer_proto.rs +0 -14
  30. kinfer/rust/src/main.rs +0 -6
  31. kinfer/rust/src/onnx_serializer.rs +0 -804
  32. kinfer/rust/src/serializer.rs +0 -221
  33. kinfer/rust/src/tests/onnx_serializer_tests.rs +0 -212
  34. kinfer/serialize/__init__.py +0 -60
  35. kinfer/serialize/base.py +0 -536
  36. kinfer/serialize/json.py +0 -399
  37. kinfer/serialize/numpy.py +0 -426
  38. kinfer/serialize/pytorch.py +0 -402
  39. kinfer/serialize/schema.py +0 -125
  40. kinfer/serialize/types.py +0 -17
  41. kinfer/serialize/utils.py +0 -177
  42. kinfer-0.3.3.dist-info/METADATA +0 -57
  43. kinfer-0.3.3.dist-info/RECORD +0 -40
  44. {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info/licenses}/LICENSE +0 -0
  45. {kinfer-0.3.3.dist-info → kinfer-0.4.0.dist-info}/top_level.txt +0 -0
kinfer/rust/src/model.rs CHANGED
@@ -1,153 +1,303 @@
1
- use crate::kinfer_proto::{ModelSchema, ProtoIO, ProtoIOSchema};
2
- use crate::onnx_serializer::OnnxMultiSerializer;
1
+ use async_trait::async_trait;
2
+ use flate2::read::GzDecoder;
3
+ use futures_util::future;
4
+ use ndarray::{Array, IxDyn};
5
+ use ort::session::Session;
6
+ use ort::value::Value;
7
+ use serde::Deserialize;
8
+ use std::collections::HashMap;
9
+ use std::io::Read;
3
10
  use std::path::Path;
11
+ use std::sync::Arc;
12
+ use tar::Archive;
13
+ use tokio::fs::File;
14
+ use tokio::io::AsyncReadExt;
4
15
 
5
- use ort::session::builder::GraphOptimizationLevel;
6
- use prost::Message;
7
- use ort::{session::Session, Error as OrtError};
16
+ #[derive(Debug, Deserialize)]
17
+ struct ModelMetadata {
18
+ joint_names: Vec<String>,
19
+ }
8
20
 
9
- pub fn load_onnx_model<P: AsRef<Path>>(model_path: P) -> Result<Session, OrtError> {
10
- let model = Session::builder()?
11
- .with_optimization_level(GraphOptimizationLevel::Level3)?
12
- .with_intra_threads(4)?
13
- .commit_from_file(model_path)?;
21
+ impl ModelMetadata {
22
+ fn model_validate_json(json: String) -> Result<Self, Box<dyn std::error::Error>> {
23
+ Ok(serde_json::from_str(&json)?)
24
+ }
25
+ }
14
26
 
15
- Ok(model)
27
+ #[derive(Debug, thiserror::Error)]
28
+ pub enum ModelError {
29
+ #[error("IO error: {0}")]
30
+ Io(#[from] std::io::Error),
31
+ #[error("Provider error: {0}")]
32
+ Provider(String),
16
33
  }
17
34
 
18
- const KINFER_METADATA_KEY: &str = "kinfer_metadata";
35
+ #[async_trait]
36
+ pub trait ModelProvider: Send + Sync {
37
+ async fn get_joint_angles(
38
+ &self,
39
+ joint_names: &[String],
40
+ ) -> Result<Array<f32, IxDyn>, ModelError>;
41
+ async fn get_joint_angular_velocities(
42
+ &self,
43
+ joint_names: &[String],
44
+ ) -> Result<Array<f32, IxDyn>, ModelError>;
45
+ async fn get_projected_gravity(&self) -> Result<Array<f32, IxDyn>, ModelError>;
46
+ async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError>;
47
+ async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError>;
48
+ async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError>;
49
+ async fn take_action(
50
+ &self,
51
+ joint_names: Vec<String>,
52
+ action: Array<f32, IxDyn>,
53
+ ) -> Result<(), ModelError>;
54
+ }
19
55
 
20
56
  pub struct ModelRunner {
21
- session: Session,
22
- attached_metadata: std::collections::HashMap<String, String>,
23
- schema: ModelSchema,
24
- input_serializer: OnnxMultiSerializer,
25
- output_serializer: OnnxMultiSerializer,
57
+ init_session: Session,
58
+ step_session: Session,
59
+ metadata: ModelMetadata,
60
+ provider: Arc<dyn ModelProvider>,
26
61
  }
27
62
 
28
63
  impl ModelRunner {
29
- pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self, Box<dyn std::error::Error>> {
30
- let session = load_onnx_model(model_path)?;
31
- let mut attached_metadata = std::collections::HashMap::new();
32
-
33
- // Extract metadata and attempt to parse schema
34
- let mut schema = None;
35
- {
36
- let metadata = session.metadata()?;
37
- for prop in metadata.custom_keys()? {
38
- if prop == KINFER_METADATA_KEY {
39
- let schema_bytes = metadata.custom(prop.as_str())?;
40
- if let Some(bytes) = schema_bytes {
41
- schema = Some(ModelSchema::decode(&mut bytes.as_bytes())?);
42
- }
43
- } else {
44
- attached_metadata.insert(
45
- prop.to_string(),
46
- metadata
47
- .custom(prop.as_str())?
48
- .map_or_else(String::new, |s| s.to_string()),
49
- );
64
+ pub async fn new<P: AsRef<Path>>(
65
+ model_path: P,
66
+ input_provider: Arc<dyn ModelProvider>,
67
+ ) -> Result<Self, Box<dyn std::error::Error>> {
68
+ let mut file = File::open(model_path).await?;
69
+
70
+ // Read entire file into memory
71
+ let mut buffer = Vec::new();
72
+ file.read_to_end(&mut buffer).await?;
73
+
74
+ // Decompress and read the tar archive from memory
75
+ let gz = GzDecoder::new(&buffer[..]);
76
+ let mut archive = Archive::new(gz);
77
+
78
+ // Extract and validate joint names
79
+ let mut metadata: Option<String> = None;
80
+ let mut init_fn: Option<Vec<u8>> = None;
81
+ let mut step_fn: Option<Vec<u8>> = None;
82
+
83
+ for entry in archive.entries()? {
84
+ let mut entry = entry?;
85
+ let path = entry.path()?;
86
+ let path_str = path.to_string_lossy();
87
+
88
+ match path_str.as_ref() {
89
+ "metadata.json" => {
90
+ let mut contents = String::new();
91
+ entry.read_to_string(&mut contents)?;
92
+ metadata = Some(contents);
50
93
  }
94
+ "init_fn.onnx" => {
95
+ let size = entry.size() as usize;
96
+ let mut contents = vec![0u8; size];
97
+ entry.read_exact(&mut contents)?;
98
+ assert_eq!(contents.len(), entry.size() as usize);
99
+ init_fn = Some(contents);
100
+ }
101
+ "step_fn.onnx" => {
102
+ let size = entry.size() as usize;
103
+ let mut contents = vec![0u8; size];
104
+ entry.read_exact(&mut contents)?;
105
+ assert_eq!(contents.len(), entry.size() as usize);
106
+ step_fn = Some(contents);
107
+ }
108
+ _ => return Err("Unknown entry".into()),
51
109
  }
52
110
  }
53
111
 
54
- let schema: ModelSchema = schema.ok_or_else(|| "kinfer_metadata not found in model metadata")?;
112
+ // Reads the files.
113
+ let metadata = ModelMetadata::model_validate_json(
114
+ metadata.ok_or("metadata.json not found in archive")?,
115
+ )?;
116
+ let init_session = Session::builder()?
117
+ .commit_from_memory(&init_fn.ok_or("init_fn.onnx not found in archive")?)?;
118
+ let step_session = Session::builder()?
119
+ .commit_from_memory(&step_fn.ok_or("step_fn.onnx not found in archive")?)?;
55
120
 
56
- // Use as_ref() to borrow the Option contents and clone after ok_or
57
- let input_schema = schema
58
- .input_schema
59
- .as_ref()
60
- .ok_or("Missing input schema")?
61
- .clone();
62
- let output_schema = schema
63
- .output_schema
64
- .as_ref()
65
- .ok_or("Missing output schema")?
66
- .clone();
121
+ // Validate init_fn has no inputs and one output
122
+ if !init_session.inputs.is_empty() {
123
+ return Err("init_fn should not have any inputs".into());
124
+ }
125
+ if init_session.outputs.len() != 1 {
126
+ return Err("init_fn should have exactly one output".into());
127
+ }
128
+
129
+ // Get carry shape from init_fn output
130
+ let carry_shape = init_session.outputs[0]
131
+ .output_type
132
+ .tensor_dimensions()
133
+ .ok_or("Missing tensor type")?
134
+ .to_vec();
67
135
 
68
- // Create serializers for input and output
69
- let input_serializer = OnnxMultiSerializer::new(input_schema);
70
- let output_serializer = OnnxMultiSerializer::new(output_schema);
136
+ // Validate step_fn inputs and outputs
137
+ Self::validate_step_fn(&step_session, metadata.joint_names.len(), &carry_shape)?;
71
138
 
72
139
  Ok(Self {
73
- session,
74
- attached_metadata,
75
- schema,
76
- input_serializer,
77
- output_serializer,
140
+ init_session,
141
+ step_session,
142
+ metadata,
143
+ provider: input_provider,
78
144
  })
79
145
  }
80
146
 
81
- pub fn run(&self, inputs: ProtoIO) -> Result<ProtoIO, Box<dyn std::error::Error>> {
82
- // Serialize inputs to ONNX format
83
- let inputs = self.input_serializer.serialize_io(inputs)?;
147
+ fn validate_step_fn(
148
+ session: &Session,
149
+ num_joints: usize,
150
+ carry_shape: &[i64],
151
+ ) -> Result<(), Box<dyn std::error::Error>> {
152
+ // Validate inputs
153
+ for input in &session.inputs {
154
+ let dims = input.input_type.tensor_dimensions().ok_or(format!(
155
+ "Input {} is not a tensor with known dimensions",
156
+ input.name
157
+ ))?;
84
158
 
85
- // Get input names from the session
86
- let input_names = self
87
- .session
88
- .inputs
89
- .iter()
90
- .map(|input| input.name.as_str())
91
- .collect::<Vec<_>>();
92
-
93
- // Create input name-value pairs
94
- let input_values = vec![(input_names[0], inputs)];
95
-
96
- let outputs = self.session.run(input_values)?;
97
-
98
- let output_values = outputs
99
- .values()
100
- .map(|v: ort::value::ValueRef<'_>| {
101
- v.try_upgrade().map_err(|e| {
102
- Box::new(std::io::Error::new(
103
- std::io::ErrorKind::Other,
104
- format!("Failed to upgrade value"),
105
- )) as Box<dyn std::error::Error>
106
- })
107
- })
108
- .collect::<Result<Vec<_>, _>>()?;
109
- // Deserialize outputs from ONNX format
110
- let outputs = self.output_serializer.deserialize_io(output_values)?;
111
-
112
- Ok(outputs)
113
- }
159
+ match input.name.as_str() {
160
+ "joint_angles" | "joint_angular_velocities" => {
161
+ if *dims != vec![num_joints as i64] {
162
+ return Err(format!(
163
+ "Expected shape [{num_joints}] for input `{}`, got {:?}",
164
+ input.name, dims
165
+ )
166
+ .into());
167
+ }
168
+ }
169
+ "projected_gravity" | "accelerometer" | "gyroscope" => {
170
+ if *dims != vec![3] {
171
+ return Err(format!(
172
+ "Expected shape [3] for input `{}`, got {:?}",
173
+ input.name, dims
174
+ )
175
+ .into());
176
+ }
177
+ }
178
+ "carry" => {
179
+ if dims != carry_shape {
180
+ return Err(format!(
181
+ "Expected shape {:?} for input `carry`, got {:?}",
182
+ carry_shape, dims
183
+ )
184
+ .into());
185
+ }
186
+ }
187
+ _ => return Err(format!("Unknown input name: {}", input.name).into()),
188
+ }
189
+ }
190
+
191
+ // Validate outputs
192
+ if session.outputs.len() != 2 {
193
+ return Err("Step function must have exactly 2 outputs".into());
194
+ }
195
+
196
+ let output_shape = session.outputs[0]
197
+ .output_type
198
+ .tensor_dimensions()
199
+ .ok_or("Missing tensor type")?;
200
+ if *output_shape != vec![num_joints as i64] {
201
+ return Err(format!(
202
+ "Expected output shape [{num_joints}], got {:?}",
203
+ output_shape
204
+ )
205
+ .into());
206
+ }
207
+
208
+ let infered_carry_shape = session.outputs[1]
209
+ .output_type
210
+ .tensor_dimensions()
211
+ .ok_or("Missing tensor type")?;
212
+ if *infered_carry_shape != *carry_shape {
213
+ return Err(format!(
214
+ "Expected carry shape {:?}, got {:?}",
215
+ carry_shape, infered_carry_shape
216
+ )
217
+ .into());
218
+ }
114
219
 
115
- pub fn export_model<P: AsRef<Path>>(&self, model_path: P) -> Result<(), Box<dyn std::error::Error>> {
116
- let model_bytes = self.session.model_as_bytes()?;
117
- let mut model = ModelProto::decode(&mut model_bytes.as_slice())?;
118
- model.set_metadata_props(self.schema.encode_to_vec())?;
119
- std::fs::write(model_path, model.write_to_bytes()?)?;
120
220
  Ok(())
121
221
  }
122
222
 
123
- pub fn input_schema(&self) -> Result<ProtoIOSchema, Box<dyn std::error::Error>> {
124
- self.schema
125
- .input_schema
126
- .as_ref()
127
- .ok_or_else(|| {
128
- Box::new(std::io::Error::new(
129
- std::io::ErrorKind::NotFound,
130
- "Missing input schema",
131
- )) as Box<dyn std::error::Error>
132
- })
133
- .map(|schema| schema.clone())
223
+ pub async fn init(&self) -> Result<Array<f32, IxDyn>, Box<dyn std::error::Error>> {
224
+ let input_values: Vec<(&str, Value)> = Vec::new();
225
+ let outputs = self.init_session.run(input_values)?;
226
+ let output_tensor = outputs[0].try_extract_tensor::<f32>()?;
227
+ Ok(output_tensor.view().to_owned())
228
+ }
229
+
230
+ pub async fn step(
231
+ &self,
232
+ carry: Array<f32, IxDyn>,
233
+ ) -> Result<(Array<f32, IxDyn>, Array<f32, IxDyn>), Box<dyn std::error::Error>> {
234
+ // Gets the model input names.
235
+ let input_names: Vec<String> = self
236
+ .step_session
237
+ .inputs
238
+ .iter()
239
+ .map(|i| i.name.clone())
240
+ .collect();
241
+
242
+ // Calls the relevant getter methods in parallel.
243
+ let mut futures = Vec::new();
244
+ for name in &input_names {
245
+ match name.as_str() {
246
+ "joint_angles" => {
247
+ futures.push(self.provider.get_joint_angles(&self.metadata.joint_names))
248
+ }
249
+ "joint_angular_velocities" => futures.push(
250
+ self.provider
251
+ .get_joint_angular_velocities(&self.metadata.joint_names),
252
+ ),
253
+ "projected_gravity" => futures.push(self.provider.get_projected_gravity()),
254
+ "accelerometer" => futures.push(self.provider.get_accelerometer()),
255
+ "gyroscope" => futures.push(self.provider.get_gyroscope()),
256
+ "carry" => futures.push(self.provider.get_carry(carry.clone())),
257
+ _ => return Err(format!("Unknown input name: {}", name).into()),
258
+ }
259
+ }
260
+
261
+ let results = future::try_join_all(futures).await?;
262
+ let mut inputs = HashMap::new();
263
+ for (name, value) in input_names.iter().zip(results) {
264
+ inputs.insert(name.clone(), value);
265
+ }
266
+
267
+ // Convert inputs to ONNX values
268
+ let mut input_values: Vec<(&str, Value)> = Vec::new();
269
+ for input in &self.step_session.inputs {
270
+ let input_data = inputs
271
+ .get(&input.name)
272
+ .ok_or_else(|| format!("Missing input: {}", input.name))?;
273
+ let input_value = Value::from_array(input_data.view())?.into_dyn();
274
+ input_values.push((input.name.as_str(), input_value));
275
+ }
276
+
277
+ // Run the model
278
+ let outputs = self.step_session.run(input_values)?;
279
+ let output_tensor = outputs[0].try_extract_tensor::<f32>()?;
280
+ let carry_tensor = outputs[1].try_extract_tensor::<f32>()?;
281
+
282
+ Ok((
283
+ output_tensor.view().to_owned(),
284
+ carry_tensor.view().to_owned(),
285
+ ))
134
286
  }
135
287
 
136
- pub fn output_schema(&self) -> Result<ProtoIOSchema, Box<dyn std::error::Error>> {
137
- self.schema
138
- .output_schema
139
- .as_ref()
140
- .ok_or_else(|| {
141
- Box::new(std::io::Error::new(
142
- std::io::ErrorKind::NotFound,
143
- "Missing output schema",
144
- )) as Box<dyn std::error::Error>
145
- })
146
- .map(|schema| schema.clone())
288
+ pub async fn take_action(
289
+ &self,
290
+ action: Array<f32, IxDyn>,
291
+ ) -> Result<(), Box<dyn std::error::Error>> {
292
+ self.provider
293
+ .take_action(self.metadata.joint_names.clone(), action)
294
+ .await?;
295
+ Ok(())
147
296
  }
148
- }
149
297
 
150
- fn main() -> Result<(), Box<dyn std::error::Error>> {
151
- println!("Hello, world!");
152
- Ok(())
298
+ pub async fn get_joint_angles(&self) -> Result<Array<f32, IxDyn>, Box<dyn std::error::Error>> {
299
+ let joint_names = &self.metadata.joint_names;
300
+ let joint_angles = self.provider.get_joint_angles(joint_names).await?;
301
+ Ok(joint_angles)
302
+ }
153
303
  }
kinfer/rust/src/runtime.rs ADDED
@@ -0,0 +1,104 @@
1
+ use std::sync::atomic::{AtomicBool, Ordering};
2
+ use std::sync::Arc;
3
+ use tokio::runtime::Runtime;
4
+
5
+ use crate::model::{ModelError, ModelRunner};
6
+ use std::time::{Duration, Instant};
7
+ use tokio::time::sleep;
8
+
9
+ pub struct ModelRuntime {
10
+ model_runner: Arc<ModelRunner>,
11
+ dt: Duration,
12
+ slowdown_factor: i32,
13
+ magnitude_factor: f32,
14
+ running: Arc<AtomicBool>,
15
+ runtime: Option<Runtime>,
16
+ }
17
+
18
+ impl ModelRuntime {
19
+ pub fn new(model_runner: Arc<ModelRunner>, dt: u64) -> Self {
20
+ Self {
21
+ model_runner,
22
+ dt: Duration::from_millis(dt),
23
+ slowdown_factor: 1,
24
+ magnitude_factor: 1.0,
25
+ running: Arc::new(AtomicBool::new(false)),
26
+ runtime: None,
27
+ }
28
+ }
29
+
30
+ pub fn set_slowdown_factor(&mut self, slowdown_factor: i32) {
31
+ assert!(slowdown_factor >= 1);
32
+ self.slowdown_factor = slowdown_factor;
33
+ }
34
+
35
+ pub fn set_magnitude_factor(&mut self, magnitude_factor: f32) {
36
+ assert!(magnitude_factor >= 0.0);
37
+ assert!(magnitude_factor <= 1.0);
38
+ self.magnitude_factor = magnitude_factor;
39
+ }
40
+
41
+ pub fn start(&mut self) -> Result<(), ModelError> {
42
+ if self.running.load(Ordering::Relaxed) {
43
+ return Ok(());
44
+ }
45
+
46
+ let running = self.running.clone();
47
+ let model_runner = self.model_runner.clone();
48
+ let dt = self.dt;
49
+ let slowdown_factor = self.slowdown_factor;
50
+ let magnitude_factor = self.magnitude_factor;
51
+
52
+ let runtime = Runtime::new()?;
53
+ running.store(true, Ordering::Relaxed);
54
+
55
+ runtime.spawn(async move {
56
+ let mut carry = model_runner
57
+ .init()
58
+ .await
59
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
60
+ let mut joint_positions = model_runner
61
+ .get_joint_angles()
62
+ .await
63
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
64
+ let mut last_time = Instant::now();
65
+
66
+ while running.load(Ordering::Relaxed) {
67
+ let (output, next_carry) = model_runner
68
+ .step(carry)
69
+ .await
70
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
71
+ carry = next_carry;
72
+
73
+ for i in 1..(slowdown_factor + 1) {
74
+ if !running.load(Ordering::Relaxed) {
75
+ break;
76
+ }
77
+ let t = i as f32 / slowdown_factor as f32;
78
+ let interp_joint_positions = &joint_positions * (1.0 - t) + &output * t;
79
+ model_runner
80
+ .take_action(interp_joint_positions * magnitude_factor)
81
+ .await
82
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
83
+ last_time = last_time + dt;
84
+ if let Some(sleep_duration) = last_time.checked_duration_since(Instant::now()) {
85
+ sleep(sleep_duration).await;
86
+ }
87
+ }
88
+
89
+ joint_positions = output;
90
+ }
91
+ Ok::<(), ModelError>(())
92
+ });
93
+
94
+ self.runtime = Some(runtime);
95
+ Ok(())
96
+ }
97
+
98
+ pub fn stop(&mut self) {
99
+ self.running.store(false, Ordering::Relaxed);
100
+ if let Some(runtime) = self.runtime.take() {
101
+ runtime.shutdown_background();
102
+ }
103
+ }
104
+ }
kinfer/rust_bindings/Cargo.toml CHANGED
@@ -15,5 +15,12 @@ name = "rust_bindings"
15
15
  crate-type = ["cdylib", "rlib"]
16
16
 
17
17
  [dependencies]
18
- pyo3 = { version = ">= 0.21.0", features = ["extension-module"] }
18
+
19
+ async-trait = "0.1"
20
+ kinfer = { path = "../rust" }
21
+ ndarray = "0.16.1"
22
+ numpy = ">= 0.24.0"
23
+ pyo3 = { version = ">= 0.24.0", features = ["extension-module"] }
24
+ pyo3-async-runtimes = { version = ">= 0.24.0", features = ["attributes", "tokio-runtime"] }
19
25
  pyo3-stub-gen = ">= 0.6.0"
26
+ tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
kinfer/rust_bindings/rust_bindings.pyi ADDED
@@ -0,0 +1,35 @@
1
+ # This file is automatically generated by pyo3_stub_gen
2
+ # ruff: noqa: E501, F401
3
+
4
+ import builtins
5
+ import numpy
6
+ import numpy.typing
7
+ import typing
8
+
9
+ class ModelProviderABC:
10
+ def __new__(cls) -> ModelProviderABC: ...
11
+ def get_joint_angles(self, joint_names:typing.Sequence[builtins.str]) -> numpy.typing.NDArray[numpy.float32]: ...
12
+ def get_joint_angular_velocities(self, joint_names:typing.Sequence[builtins.str]) -> numpy.typing.NDArray[numpy.float32]: ...
13
+ def get_projected_gravity(self) -> numpy.typing.NDArray[numpy.float32]: ...
14
+ def get_accelerometer(self) -> numpy.typing.NDArray[numpy.float32]: ...
15
+ def get_gyroscope(self) -> numpy.typing.NDArray[numpy.float32]: ...
16
+ def take_action(self, joint_names:typing.Sequence[builtins.str], action:numpy.typing.NDArray[numpy.float32]) -> None: ...
17
+
18
+ class PyModelProvider:
19
+ ...
20
+
21
+ class PyModelRunner:
22
+ def __new__(cls, model_path:builtins.str, provider:ModelProviderABC) -> PyModelRunner: ...
23
+ def init(self) -> numpy.typing.NDArray[numpy.float32]: ...
24
+ def step(self, carry:numpy.typing.NDArray[numpy.float32]) -> tuple[numpy.typing.NDArray[numpy.float32], numpy.typing.NDArray[numpy.float32]]: ...
25
+ def take_action(self, action:numpy.typing.NDArray[numpy.float32]) -> None: ...
26
+
27
+ class PyModelRuntime:
28
+ def __new__(cls, model_runner:PyModelRunner, dt:builtins.int) -> PyModelRuntime: ...
29
+ def set_slowdown_factor(self, slowdown_factor:builtins.int) -> None: ...
30
+ def set_magnitude_factor(self, magnitude_factor:builtins.float) -> None: ...
31
+ def start(self) -> None: ...
32
+ def stop(self) -> None: ...
33
+
34
+ def get_version() -> builtins.str: ...
35
+