kinfer 0.5.4__cp311-cp311-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,354 @@
1
+ use crate::logger::StepLogger;
2
+ use crate::types::{InputType, ModelMetadata};
3
+ use async_trait::async_trait;
4
+ use chrono;
5
+ use flate2::read::GzDecoder;
6
+ use ndarray::{Array, IxDyn};
7
+ use ort::session::Session;
8
+ use ort::value::Value;
9
+ use std::collections::HashMap;
10
+ use std::io::Read;
11
+ use std::path::Path;
12
+ use std::sync::Arc;
13
+ use tar::Archive;
14
+ use tokio::fs::File;
15
+ use tokio::io::AsyncReadExt;
16
+
17
+ #[derive(Debug, thiserror::Error)]
18
+ pub enum ModelError {
19
+ #[error("IO error: {0}")]
20
+ Io(#[from] std::io::Error),
21
+ #[error("Provider error: {0}")]
22
+ Provider(String),
23
+ }
24
+
25
+ #[async_trait]
26
+ pub trait ModelProvider: Send + Sync {
27
+ async fn get_inputs(
28
+ &self,
29
+ input_types: &[InputType],
30
+ metadata: &ModelMetadata,
31
+ ) -> Result<HashMap<InputType, Array<f32, IxDyn>>, ModelError>;
32
+
33
+ async fn take_action(
34
+ &self,
35
+ action: Array<f32, IxDyn>,
36
+ metadata: &ModelMetadata,
37
+ ) -> Result<(), ModelError>;
38
+ }
39
+
40
+ pub struct ModelRunner {
41
+ init_session: Session,
42
+ step_session: Session,
43
+ metadata: ModelMetadata,
44
+ provider: Arc<dyn ModelProvider>,
45
+ logger: Option<Arc<StepLogger>>,
46
+ }
47
+
48
+ impl ModelRunner {
49
+ pub async fn new<P: AsRef<Path>>(
50
+ model_path: P,
51
+ input_provider: Arc<dyn ModelProvider>,
52
+ ) -> Result<Self, Box<dyn std::error::Error>> {
53
+ let mut file = File::open(model_path).await?;
54
+
55
+ // Read entire file into memory
56
+ let mut buffer = Vec::new();
57
+ file.read_to_end(&mut buffer).await?;
58
+
59
+ // Decompress and read the tar archive from memory
60
+ let gz = GzDecoder::new(&buffer[..]);
61
+ let mut archive = Archive::new(gz);
62
+
63
+ // Extract and validate joint names
64
+ let mut metadata: Option<String> = None;
65
+ let mut init_fn: Option<Vec<u8>> = None;
66
+ let mut step_fn: Option<Vec<u8>> = None;
67
+
68
+ for entry in archive.entries()? {
69
+ let mut entry = entry?;
70
+ let path = entry.path()?;
71
+ let path_str = path.to_string_lossy();
72
+
73
+ match path_str.as_ref() {
74
+ "metadata.json" => {
75
+ let mut contents = String::new();
76
+ entry.read_to_string(&mut contents)?;
77
+ metadata = Some(contents);
78
+ }
79
+ "init_fn.onnx" => {
80
+ let size = entry.size() as usize;
81
+ let mut contents = vec![0u8; size];
82
+ entry.read_exact(&mut contents)?;
83
+ assert_eq!(contents.len(), entry.size() as usize);
84
+ init_fn = Some(contents);
85
+ }
86
+ "step_fn.onnx" => {
87
+ let size = entry.size() as usize;
88
+ let mut contents = vec![0u8; size];
89
+ entry.read_exact(&mut contents)?;
90
+ assert_eq!(contents.len(), entry.size() as usize);
91
+ step_fn = Some(contents);
92
+ }
93
+ _ => return Err("Unknown entry".into()),
94
+ }
95
+ }
96
+
97
+ // Reads the files.
98
+ let metadata = ModelMetadata::model_validate_json(
99
+ metadata.ok_or("metadata.json not found in archive")?,
100
+ )?;
101
+ let init_session = Session::builder()?
102
+ .commit_from_memory(&init_fn.ok_or("init_fn.onnx not found in archive")?)?;
103
+ let step_session = Session::builder()?
104
+ .commit_from_memory(&step_fn.ok_or("step_fn.onnx not found in archive")?)?;
105
+
106
+ // Validate init_fn has no inputs and one output
107
+ if !init_session.inputs.is_empty() {
108
+ return Err("init_fn should not have any inputs".into());
109
+ }
110
+ if init_session.outputs.len() != 1 {
111
+ return Err("init_fn should have exactly one output".into());
112
+ }
113
+
114
+ // Get carry shape from init_fn output
115
+ let carry_shape = init_session.outputs[0]
116
+ .output_type
117
+ .tensor_dimensions()
118
+ .ok_or("Missing tensor type")?
119
+ .to_vec();
120
+
121
+ // Validate step_fn inputs and outputs
122
+ Self::validate_step_fn(&step_session, &metadata, &carry_shape)?;
123
+
124
+ let logger = if let Ok(log_dir) = std::env::var("KINFER_LOG_PATH") {
125
+ let log_dir_path = std::path::Path::new(&log_dir);
126
+
127
+ // Create the directory if it doesn't exist
128
+ if !log_dir_path.exists() {
129
+ std::fs::create_dir_all(log_dir_path)?;
130
+ }
131
+
132
+ // Use uuid if found, otherwise timestamp
133
+ let log_name = std::env::var("KINFER_LOG_UUID").unwrap_or_else(|_| {
134
+ chrono::Utc::now().format("%Y-%m-%d_%H-%M-%S").to_string()
135
+ });
136
+
137
+ let log_file_path = log_dir_path.join(format!("{}.ndjson", log_name));
138
+
139
+ Some(StepLogger::new(log_file_path).map(Arc::new)?)
140
+ } else {
141
+ None
142
+ };
143
+
144
+ Ok(Self {
145
+ init_session,
146
+ step_session,
147
+ metadata,
148
+ provider: input_provider,
149
+ logger,
150
+ })
151
+ }
152
+
153
+ fn validate_step_fn(
154
+ session: &Session,
155
+ metadata: &ModelMetadata,
156
+ carry_shape: &[i64],
157
+ ) -> Result<(), Box<dyn std::error::Error>> {
158
+ // Validate inputs
159
+ for input in &session.inputs {
160
+ let dims = input.input_type.tensor_dimensions().ok_or(format!(
161
+ "Input {} is not a tensor with known dimensions",
162
+ input.name
163
+ ))?;
164
+
165
+ let input_type = InputType::from_name(&input.name)?;
166
+ let expected_shape = input_type.get_shape(metadata);
167
+ let expected_shape_i64: Vec<i64> = expected_shape.iter().map(|&x| x as i64).collect();
168
+ if *dims != expected_shape_i64 {
169
+ return Err(format!(
170
+ "Expected input shape {:?}, got {:?}",
171
+ expected_shape_i64, dims
172
+ )
173
+ .into());
174
+ }
175
+ }
176
+
177
+ // Validate outputs
178
+ if session.outputs.len() != 2 {
179
+ return Err("Step function must have exactly 2 outputs".into());
180
+ }
181
+
182
+ let output_shape = session.outputs[0]
183
+ .output_type
184
+ .tensor_dimensions()
185
+ .ok_or("Missing tensor type")?;
186
+ let num_joints = metadata.joint_names.len();
187
+ if *output_shape != vec![num_joints as i64] {
188
+ return Err(format!(
189
+ "Expected output shape [{num_joints}], got {:?}",
190
+ output_shape
191
+ )
192
+ .into());
193
+ }
194
+
195
+ let infered_carry_shape = session.outputs[1]
196
+ .output_type
197
+ .tensor_dimensions()
198
+ .ok_or("Missing tensor type")?;
199
+ if *infered_carry_shape != *carry_shape {
200
+ return Err(format!(
201
+ "Expected carry shape {:?}, got {:?}",
202
+ carry_shape, infered_carry_shape
203
+ )
204
+ .into());
205
+ }
206
+
207
+ Ok(())
208
+ }
209
+
210
+ pub async fn get_inputs(
211
+ &self,
212
+ input_types: &[InputType],
213
+ ) -> Result<HashMap<InputType, Array<f32, IxDyn>>, ModelError> {
214
+ self.provider.get_inputs(input_types, &self.metadata).await
215
+ }
216
+
217
+ pub async fn init(&self) -> Result<Array<f32, IxDyn>, Box<dyn std::error::Error>> {
218
+ let input_values: Vec<(&str, Value)> = Vec::new();
219
+ let outputs = self.init_session.run(input_values)?;
220
+ let output_tensor = outputs[0].try_extract_tensor::<f32>()?;
221
+ Ok(output_tensor.view().to_owned())
222
+ }
223
+
224
+ pub async fn step(
225
+ &self,
226
+ carry: Array<f32, IxDyn>,
227
+ ) -> Result<(Array<f32, IxDyn>, Array<f32, IxDyn>), Box<dyn std::error::Error>> {
228
+ // Gets the model input names.
229
+ let input_names: Vec<String> = self
230
+ .step_session
231
+ .inputs
232
+ .iter()
233
+ .map(|i| i.name.clone())
234
+ .collect();
235
+
236
+ // Calls the relevant getter methods in parallel.
237
+ let mut input_types = Vec::new();
238
+ let mut inputs = HashMap::new();
239
+ for name in &input_names {
240
+ match name.as_str() {
241
+ "joint_angles" => {
242
+ input_types.push(InputType::JointAngles);
243
+ }
244
+ "joint_angular_velocities" => {
245
+ input_types.push(InputType::JointAngularVelocities);
246
+ }
247
+ "initial_heading" => {
248
+ input_types.push(InputType::InitialHeading);
249
+ }
250
+ "quaternion" => {
251
+ input_types.push(InputType::Quaternion);
252
+ }
253
+ "projected_gravity" => {
254
+ input_types.push(InputType::ProjectedGravity);
255
+ }
256
+ "accelerometer" => {
257
+ input_types.push(InputType::Accelerometer);
258
+ }
259
+ "gyroscope" => {
260
+ input_types.push(InputType::Gyroscope);
261
+ }
262
+ "command" => {
263
+ input_types.push(InputType::Command);
264
+ }
265
+ "time" => {
266
+ input_types.push(InputType::Time);
267
+ }
268
+ "carry" => {
269
+ inputs.insert(InputType::Carry, carry.clone());
270
+ }
271
+ _ => return Err(format!("Unknown input name: {}", name).into()),
272
+ }
273
+ }
274
+
275
+ // Gets the input values.
276
+ let result = self
277
+ .provider
278
+ .get_inputs(&input_types, &self.metadata)
279
+ .await?;
280
+
281
+ // Adds the input values to the input map.
282
+ inputs.extend(result);
283
+
284
+ // Convert inputs to ONNX values
285
+ let mut input_values: Vec<(&str, Value)> = Vec::new();
286
+ for input in &self.step_session.inputs {
287
+ let input_type = InputType::from_name(&input.name)?;
288
+ let input_data = inputs
289
+ .get(&input_type)
290
+ .ok_or_else(|| format!("Missing input: {}", input.name))?;
291
+ let input_value = Value::from_array(input_data.view())?.into_dyn();
292
+ input_values.push((input.name.as_str(), input_value));
293
+ }
294
+
295
+ // Run the model
296
+ let outputs = self.step_session.run(input_values)?;
297
+ let output_tensor = outputs[0].try_extract_tensor::<f32>()?;
298
+ let carry_tensor = outputs[1].try_extract_tensor::<f32>()?;
299
+
300
+ // Log the step if needed
301
+ if let Some(lg) = &self.logger {
302
+ let joint_angles_opt = inputs
303
+ .get(&InputType::JointAngles)
304
+ .map(|a| a.as_slice().unwrap());
305
+ let joint_vels_opt = inputs
306
+ .get(&InputType::JointAngularVelocities)
307
+ .map(|a| a.as_slice().unwrap());
308
+ let initial_heading_opt = inputs
309
+ .get(&InputType::InitialHeading)
310
+ .map(|a| a.as_slice().unwrap());
311
+ let quaternion_opt = inputs
312
+ .get(&InputType::Quaternion)
313
+ .map(|a| a.as_slice().unwrap());
314
+ let projected_g_opt = inputs
315
+ .get(&InputType::ProjectedGravity)
316
+ .map(|a| a.as_slice().unwrap());
317
+ let accel_opt = inputs
318
+ .get(&InputType::Accelerometer)
319
+ .map(|a| a.as_slice().unwrap());
320
+ let gyro_opt = inputs
321
+ .get(&InputType::Gyroscope)
322
+ .map(|a| a.as_slice().unwrap());
323
+ let command_opt = inputs
324
+ .get(&InputType::Command)
325
+ .map(|a| a.as_slice().unwrap());
326
+ let output_opt = Some(output_tensor.as_slice().unwrap());
327
+
328
+ lg.log_step(
329
+ joint_angles_opt,
330
+ joint_vels_opt,
331
+ initial_heading_opt,
332
+ quaternion_opt,
333
+ projected_g_opt,
334
+ accel_opt,
335
+ gyro_opt,
336
+ command_opt,
337
+ output_opt,
338
+ );
339
+ }
340
+
341
+ Ok((
342
+ output_tensor.view().to_owned(),
343
+ carry_tensor.view().to_owned(),
344
+ ))
345
+ }
346
+
347
+ pub async fn take_action(
348
+ &self,
349
+ action: Array<f32, IxDyn>,
350
+ ) -> Result<(), Box<dyn std::error::Error>> {
351
+ self.provider.take_action(action, &self.metadata).await?;
352
+ Ok(())
353
+ }
354
+ }
@@ -0,0 +1,107 @@
1
+ use std::sync::atomic::{AtomicBool, Ordering};
2
+ use std::sync::Arc;
3
+ use tokio::runtime::Runtime;
4
+
5
+ use crate::model::{ModelError, ModelRunner};
6
+ use crate::types::InputType;
7
+ use std::time::Duration;
8
+ use tokio::time::interval;
9
+
10
+ pub struct ModelRuntime {
11
+ model_runner: Arc<ModelRunner>,
12
+ dt: Duration,
13
+ slowdown_factor: i32,
14
+ magnitude_factor: f32,
15
+ running: Arc<AtomicBool>,
16
+ runtime: Option<Runtime>,
17
+ }
18
+
19
+ impl ModelRuntime {
20
+ pub fn new(model_runner: Arc<ModelRunner>, dt: u64) -> Self {
21
+ Self {
22
+ model_runner,
23
+ dt: Duration::from_millis(dt),
24
+ slowdown_factor: 1,
25
+ magnitude_factor: 1.0,
26
+ running: Arc::new(AtomicBool::new(false)),
27
+ runtime: None,
28
+ }
29
+ }
30
+
31
+ pub fn set_slowdown_factor(&mut self, slowdown_factor: i32) {
32
+ assert!(slowdown_factor >= 1);
33
+ self.slowdown_factor = slowdown_factor;
34
+ }
35
+
36
+ pub fn set_magnitude_factor(&mut self, magnitude_factor: f32) {
37
+ assert!(magnitude_factor >= 0.0);
38
+ assert!(magnitude_factor <= 1.0);
39
+ self.magnitude_factor = magnitude_factor;
40
+ }
41
+
42
+ pub fn start(&mut self) -> Result<(), ModelError> {
43
+ if self.running.load(Ordering::Relaxed) {
44
+ return Ok(());
45
+ }
46
+
47
+ let running = self.running.clone();
48
+ let model_runner = self.model_runner.clone();
49
+ let dt = self.dt;
50
+ let slowdown_factor = self.slowdown_factor;
51
+ let magnitude_factor = self.magnitude_factor;
52
+
53
+ let runtime = Runtime::new()?;
54
+ running.store(true, Ordering::Relaxed);
55
+
56
+ runtime.spawn(async move {
57
+ let mut carry = model_runner
58
+ .init()
59
+ .await
60
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
61
+
62
+ let model_inputs = model_runner
63
+ .get_inputs(&[InputType::JointAngles])
64
+ .await
65
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
66
+ let mut joint_positions = model_inputs[&InputType::JointAngles].clone();
67
+
68
+ // Wait for the first tick, since it happens immediately.
69
+ let mut interval = interval(dt);
70
+ interval.tick().await;
71
+
72
+ while running.load(Ordering::Relaxed) {
73
+ let (output, next_carry) = model_runner
74
+ .step(carry)
75
+ .await
76
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
77
+ carry = next_carry;
78
+
79
+ for i in 1..(slowdown_factor + 1) {
80
+ if !running.load(Ordering::Relaxed) {
81
+ break;
82
+ }
83
+ let t = i as f32 / slowdown_factor as f32;
84
+ let interp_joint_positions = &joint_positions * (1.0 - t) + &output * t;
85
+ model_runner
86
+ .take_action(interp_joint_positions * magnitude_factor)
87
+ .await
88
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
89
+ interval.tick().await;
90
+ }
91
+
92
+ joint_positions = output;
93
+ }
94
+ Ok::<(), ModelError>(())
95
+ });
96
+
97
+ self.runtime = Some(runtime);
98
+ Ok(())
99
+ }
100
+
101
+ pub fn stop(&mut self) {
102
+ self.running.store(false, Ordering::Relaxed);
103
+ if let Some(runtime) = self.runtime.take() {
104
+ runtime.shutdown_background();
105
+ }
106
+ }
107
+ }
@@ -0,0 +1,96 @@
1
+ use serde::Deserialize;
2
+ use serde::Serialize;
3
+
4
+ #[derive(Debug, Deserialize, Serialize, Clone)]
5
+ pub struct ModelMetadata {
6
+ pub joint_names: Vec<String>,
7
+ pub num_commands: Option<usize>,
8
+ pub carry_size: Vec<usize>,
9
+ }
10
+
11
+ impl ModelMetadata {
12
+ pub fn model_validate_json(json: String) -> Result<Self, Box<dyn std::error::Error>> {
13
+ Ok(serde_json::from_str(&json)?)
14
+ }
15
+
16
+ pub fn to_json(&self) -> Result<String, Box<dyn std::error::Error>> {
17
+ Ok(serde_json::to_string(self)?)
18
+ }
19
+ }
20
+
21
/// A named input of the `step_fn` ONNX graph.
///
/// Each variant corresponds to one ONNX input name. Its expected shape is computed from the
/// model metadata.
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub enum InputType {
    /// `joint_angles`, shape `[num_joints]`.
    JointAngles,
    /// `joint_angular_velocities`, shape `[num_joints]`.
    JointAngularVelocities,
    /// `initial_heading`, shape `[1]`.
    InitialHeading,
    /// `quaternion`, shape `[4]`.
    Quaternion,
    /// `projected_gravity`, shape `[3]`.
    ProjectedGravity,
    /// `accelerometer`, shape `[3]`.
    Accelerometer,
    /// `gyroscope`, shape `[3]`.
    Gyroscope,
    /// `command`, shape `[num_commands]` (0 when unset).
    Command,
    /// `time`, shape `[1]`.
    Time,
    /// `carry`, the recurrent state; shape given by the metadata's `carry_size`.
    Carry,
}
34
+
35
+ impl InputType {
36
+ pub fn get_name(&self) -> &str {
37
+ match self {
38
+ InputType::JointAngles => "joint_angles",
39
+ InputType::JointAngularVelocities => "joint_angular_velocities",
40
+ InputType::InitialHeading => "initial_heading",
41
+ InputType::Quaternion => "quaternion",
42
+ InputType::ProjectedGravity => "projected_gravity",
43
+ InputType::Accelerometer => "accelerometer",
44
+ InputType::Gyroscope => "gyroscope",
45
+ InputType::Command => "command",
46
+ InputType::Time => "time",
47
+ InputType::Carry => "carry",
48
+ }
49
+ }
50
+
51
+ pub fn get_shape(&self, metadata: &ModelMetadata) -> Vec<usize> {
52
+ match self {
53
+ InputType::JointAngles => vec![metadata.joint_names.len()],
54
+ InputType::JointAngularVelocities => vec![metadata.joint_names.len()],
55
+ InputType::InitialHeading => vec![1],
56
+ InputType::Quaternion => vec![4],
57
+ InputType::ProjectedGravity => vec![3],
58
+ InputType::Accelerometer => vec![3],
59
+ InputType::Gyroscope => vec![3],
60
+ InputType::Command => vec![metadata.num_commands.unwrap_or(0)],
61
+ InputType::Time => vec![1],
62
+ InputType::Carry => metadata.carry_size.clone(),
63
+ }
64
+ }
65
+
66
+ pub fn from_name(name: &str) -> Result<Self, Box<dyn std::error::Error>> {
67
+ match name {
68
+ "joint_angles" => Ok(InputType::JointAngles),
69
+ "joint_angular_velocities" => Ok(InputType::JointAngularVelocities),
70
+ "initial_heading" => Ok(InputType::InitialHeading),
71
+ "quaternion" => Ok(InputType::Quaternion),
72
+ "projected_gravity" => Ok(InputType::ProjectedGravity),
73
+ "accelerometer" => Ok(InputType::Accelerometer),
74
+ "gyroscope" => Ok(InputType::Gyroscope),
75
+ "command" => Ok(InputType::Command),
76
+ "time" => Ok(InputType::Time),
77
+ "carry" => Ok(InputType::Carry),
78
+ _ => Err(format!("Unknown input type: {}", name).into()),
79
+ }
80
+ }
81
+
82
+ pub fn get_names() -> Vec<&'static str> {
83
+ vec![
84
+ "joint_angles",
85
+ "joint_angular_velocities",
86
+ "initial_heading",
87
+ "quaternion",
88
+ "projected_gravity",
89
+ "accelerometer",
90
+ "gyroscope",
91
+ "command",
92
+ "time",
93
+ "carry",
94
+ ]
95
+ }
96
+ }
@@ -0,0 +1,26 @@
1
[package]
name = "rust_bindings"
version.workspace = true
edition.workspace = true
description.workspace = true
authors.workspace = true
repository.workspace = true
license.workspace = true
readme.workspace = true

[lib]
name = "rust_bindings"
# `cdylib` builds the Python extension module; `rlib` lets the stub-generator binary
# (which calls `rust_bindings::stub_info`) link against this crate.
crate-type = ["cdylib", "rlib"]

[dependencies]
async-trait = "0.1"
kinfer = { path = "../rust" }
ndarray = "0.16.1"
# numpy, pyo3 and pyo3-async-runtimes are released in lockstep and make breaking changes
# between 0.x minors. Keep them on the same minor series instead of an open-ended `>=`.
numpy = "0.24"
pyo3 = { version = "0.24", features = ["extension-module"] }
pyo3-async-runtimes = { version = "0.24", features = ["attributes", "tokio-runtime"] }
# NOTE(review): still open-ended; must resolve to a release built against pyo3 0.24.
# Consider pinning it too.
pyo3-stub-gen = ">= 0.6.0"
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
@@ -0,0 +1,7 @@
1
# Built as a Rust extension module with maturin.
[build-system]
build-backend = "maturin"
requires = ["maturin>=1.1,<2.0"]

[project]
name = "rust_bindings"
requires-python = ">=3.11"
@@ -0,0 +1,46 @@
1
+ # This file is automatically generated by pyo3_stub_gen
2
+ # ruff: noqa: E501, F401
3
+
4
+ import builtins
5
+ import numpy
6
+ import numpy.typing
7
+ import typing
8
+
9
class ModelProviderABC:
    def __new__(cls) -> ModelProviderABC: ...
    def get_inputs(self, input_types:typing.Sequence[builtins.str], metadata:PyModelMetadata) -> builtins.dict[builtins.str, numpy.typing.NDArray[numpy.float32]]: ...
    def take_action(self, action:numpy.typing.NDArray[numpy.float32], metadata:PyModelMetadata) -> None: ...

class PyInputType:
    def __new__(cls, input_type:builtins.str) -> PyInputType: ...
    def get_name(self) -> builtins.str: ...
    def get_shape(self, metadata:PyModelMetadata) -> builtins.list[builtins.int]: ...
    def __repr__(self) -> builtins.str: ...
    def __eq__(self, other:typing.Any) -> builtins.bool: ...

class PyModelMetadata:
    # `__new__` receives the class, not an instance. NOTE(review): this file is generated, so
    # also rename the receiver in the Rust `#[new]` binding to keep the fix after regeneration.
    def __new__(cls, joint_names:typing.Sequence[builtins.str], num_commands:typing.Optional[builtins.int], carry_size:typing.Sequence[builtins.int]) -> PyModelMetadata: ...
    def to_json(self) -> builtins.str: ...
    def __repr__(self) -> builtins.str: ...
    def __eq__(self, other:typing.Any) -> builtins.bool: ...

class PyModelProvider:
    ...

class PyModelRunner:
    def __new__(cls, model_path:builtins.str, provider:ModelProviderABC) -> PyModelRunner: ...
    def init(self) -> numpy.typing.NDArray[numpy.float32]: ...
    def step(self, carry:numpy.typing.NDArray[numpy.float32]) -> tuple[numpy.typing.NDArray[numpy.float32], numpy.typing.NDArray[numpy.float32]]: ...
    def take_action(self, action:numpy.typing.NDArray[numpy.float32]) -> None: ...

class PyModelRuntime:
    def __new__(cls, model_runner:PyModelRunner, dt:builtins.int) -> PyModelRuntime: ...
    def set_slowdown_factor(self, slowdown_factor:builtins.int) -> None: ...
    def set_magnitude_factor(self, magnitude_factor:builtins.float) -> None: ...
    def start(self) -> None: ...
    def stop(self) -> None: ...

def get_version() -> builtins.str: ...

def metadata_from_json(json:builtins.str) -> PyModelMetadata: ...
46
+
@@ -0,0 +1,7 @@
1
+ use pyo3_stub_gen::Result;
2
+
3
+ fn main() -> Result<()> {
4
+ let stub = rust_bindings::stub_info()?;
5
+ stub.generate()?;
6
+ Ok(())
7
+ }