kinfer 0.4.2__cp311-cp311-macosx_11_0_arm64.whl → 0.5.1__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,8 +3,9 @@ use std::sync::Arc;
3
3
  use tokio::runtime::Runtime;
4
4
 
5
5
  use crate::model::{ModelError, ModelRunner};
6
- use std::time::{Duration, Instant};
7
- use tokio::time::sleep;
6
+ use crate::types::InputType;
7
+ use std::time::Duration;
8
+ use tokio::time::interval;
8
9
 
9
10
  pub struct ModelRuntime {
10
11
  model_runner: Arc<ModelRunner>,
@@ -57,11 +58,16 @@ impl ModelRuntime {
57
58
  .init()
58
59
  .await
59
60
  .map_err(|e| ModelError::Provider(e.to_string()))?;
60
- let mut joint_positions = model_runner
61
- .get_joint_angles()
61
+
62
+ let model_inputs = model_runner
63
+ .get_inputs(&[InputType::JointAngles])
62
64
  .await
63
65
  .map_err(|e| ModelError::Provider(e.to_string()))?;
64
- let mut last_time = Instant::now();
66
+ let mut joint_positions = model_inputs[&InputType::JointAngles].clone();
67
+
68
+ // Wait for the first tick, since it happens immediately.
69
+ let mut interval = interval(dt);
70
+ interval.tick().await;
65
71
 
66
72
  while running.load(Ordering::Relaxed) {
67
73
  let (output, next_carry) = model_runner
@@ -80,10 +86,7 @@ impl ModelRuntime {
80
86
  .take_action(interp_joint_positions * magnitude_factor)
81
87
  .await
82
88
  .map_err(|e| ModelError::Provider(e.to_string()))?;
83
- last_time = last_time + dt;
84
- if let Some(sleep_duration) = last_time.checked_duration_since(Instant::now()) {
85
- sleep(sleep_duration).await;
86
- }
89
+ interval.tick().await;
87
90
  }
88
91
 
89
92
  joint_positions = output;
@@ -0,0 +1,86 @@
1
+ use serde::Deserialize;
2
+ use serde::Serialize;
3
+
4
+ #[derive(Debug, Deserialize, Serialize, Clone)]
5
+ pub struct ModelMetadata {
6
+ pub joint_names: Vec<String>,
7
+ pub num_commands: Option<usize>,
8
+ pub carry_size: Vec<usize>,
9
+ }
10
+
11
+ impl ModelMetadata {
12
+ pub fn model_validate_json(json: String) -> Result<Self, Box<dyn std::error::Error>> {
13
+ Ok(serde_json::from_str(&json)?)
14
+ }
15
+
16
+ pub fn to_json(&self) -> Result<String, Box<dyn std::error::Error>> {
17
+ Ok(serde_json::to_string(self)?)
18
+ }
19
+ }
20
+
21
/// The kinds of input a model can request from its provider.
///
/// Each variant has a stable string name (see `InputType::get_name`) that is
/// used when crossing the Python boundary.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InputType {
    /// One value per entry in `ModelMetadata::joint_names`.
    JointAngles,
    /// One value per joint.
    JointAngularVelocities,
    /// Three-component vector.
    ProjectedGravity,
    /// Three-component vector.
    Accelerometer,
    /// Three-component vector.
    Gyroscope,
    /// `ModelMetadata::num_commands` values (zero when unset).
    Command,
    /// A single value.
    Time,
    /// Shaped by `ModelMetadata::carry_size`.
    Carry,
}
32
+
33
+ impl InputType {
34
+ pub fn get_name(&self) -> &str {
35
+ match self {
36
+ InputType::JointAngles => "joint_angles",
37
+ InputType::JointAngularVelocities => "joint_angular_velocities",
38
+ InputType::ProjectedGravity => "projected_gravity",
39
+ InputType::Accelerometer => "accelerometer",
40
+ InputType::Gyroscope => "gyroscope",
41
+ InputType::Command => "command",
42
+ InputType::Time => "time",
43
+ InputType::Carry => "carry",
44
+ }
45
+ }
46
+
47
+ pub fn get_shape(&self, metadata: &ModelMetadata) -> Vec<usize> {
48
+ match self {
49
+ InputType::JointAngles => vec![metadata.joint_names.len()],
50
+ InputType::JointAngularVelocities => vec![metadata.joint_names.len()],
51
+ InputType::ProjectedGravity => vec![3],
52
+ InputType::Accelerometer => vec![3],
53
+ InputType::Gyroscope => vec![3],
54
+ InputType::Command => vec![metadata.num_commands.unwrap_or(0)],
55
+ InputType::Time => vec![1],
56
+ InputType::Carry => metadata.carry_size.clone(),
57
+ }
58
+ }
59
+
60
+ pub fn from_name(name: &str) -> Result<Self, Box<dyn std::error::Error>> {
61
+ match name {
62
+ "joint_angles" => Ok(InputType::JointAngles),
63
+ "joint_angular_velocities" => Ok(InputType::JointAngularVelocities),
64
+ "projected_gravity" => Ok(InputType::ProjectedGravity),
65
+ "accelerometer" => Ok(InputType::Accelerometer),
66
+ "gyroscope" => Ok(InputType::Gyroscope),
67
+ "command" => Ok(InputType::Command),
68
+ "time" => Ok(InputType::Time),
69
+ "carry" => Ok(InputType::Carry),
70
+ _ => Err(format!("Unknown input type: {}", name).into()),
71
+ }
72
+ }
73
+
74
+ pub fn get_names() -> Vec<&'static str> {
75
+ vec![
76
+ "joint_angles",
77
+ "joint_angular_velocities",
78
+ "projected_gravity",
79
+ "accelerometer",
80
+ "gyroscope",
81
+ "command",
82
+ "time",
83
+ "carry",
84
+ ]
85
+ }
86
+ }
@@ -8,13 +8,21 @@ import typing
8
8
 
9
9
class ModelProviderABC:
    """Base class for Python model providers; subclasses override the methods below."""

    def __new__(cls) -> ModelProviderABC: ...

    def get_inputs(
        self,
        input_types: typing.Sequence[builtins.str],
        metadata: PyModelMetadata,
    ) -> builtins.dict[builtins.str, numpy.typing.NDArray[numpy.float32]]:
        """Return one float32 array per requested input name (must be overridden)."""

    def take_action(
        self,
        action: numpy.typing.NDArray[numpy.float32],
        metadata: PyModelMetadata,
    ) -> None:
        """Apply an action array (must be overridden)."""
13
+
14
class PyInputType:
    """Python view of a kinfer input type, constructed from its string name."""

    def __new__(cls, input_type: builtins.str) -> PyInputType: ...

    def get_name(self) -> builtins.str: ...

    def get_shape(self, metadata: PyModelMetadata) -> builtins.list[builtins.int]: ...

    def __repr__(self) -> builtins.str: ...

    def __eq__(self, other: typing.Any) -> builtins.bool: ...
20
+
21
class PyModelMetadata:
    """Python view of kinfer model metadata.

    Exposes the three fields the Rust class marks `#[pyo3(get, set)]` as
    read/write properties. `__new__` takes `cls`, like every other constructor
    in this stub.
    """

    @property
    def joint_names(self) -> builtins.list[builtins.str]: ...
    @joint_names.setter
    def joint_names(self, value: typing.Sequence[builtins.str]) -> None: ...

    @property
    def num_commands(self) -> typing.Optional[builtins.int]: ...
    @num_commands.setter
    def num_commands(self, value: typing.Optional[builtins.int]) -> None: ...

    @property
    def carry_size(self) -> builtins.list[builtins.int]: ...
    @carry_size.setter
    def carry_size(self, value: typing.Sequence[builtins.int]) -> None: ...

    def __new__(
        cls,
        joint_names: typing.Sequence[builtins.str],
        num_commands: typing.Optional[builtins.int],
        carry_size: typing.Sequence[builtins.int],
    ) -> PyModelMetadata: ...

    def to_json(self) -> builtins.str: ...

    def __repr__(self) -> builtins.str: ...

    def __eq__(self, other: typing.Any) -> builtins.bool: ...
18
26
 
19
27
  class PyModelProvider:
20
28
  ...
@@ -34,3 +42,5 @@ class PyModelRuntime:
34
42
 
35
43
def get_version() -> builtins.str:
    """Return the Cargo package version the extension was built from."""
36
44
 
45
def metadata_from_json(json: builtins.str) -> PyModelMetadata:
    """Parse a JSON string into a `PyModelMetadata`; raises `ValueError` if it is invalid."""
46
+
@@ -1,85 +1,210 @@
1
1
  use async_trait::async_trait;
2
2
  use kinfer::model::{ModelError, ModelProvider, ModelRunner};
3
3
  use kinfer::runtime::ModelRuntime;
4
+ use kinfer::types::{InputType, ModelMetadata};
4
5
  use ndarray::{Array, Ix1, IxDyn};
5
6
  use numpy::{PyArray1, PyArrayDyn, PyArrayMethods};
6
7
  use pyo3::exceptions::PyNotImplementedError;
7
8
  use pyo3::prelude::*;
9
+ use pyo3::types::{PyAny, PyAnyMethods};
8
10
  use pyo3::{pymodule, types::PyModule, Bound, PyResult, Python};
9
11
  use pyo3_stub_gen::define_stub_info_gatherer;
10
12
  use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
13
+ use std::collections::HashMap;
14
+ use std::hash::Hash;
11
15
  use std::sync::Arc;
12
16
  use std::sync::Mutex;
13
17
 
18
+ type StepResult = (Py<PyArrayDyn<f32>>, Py<PyArrayDyn<f32>>);
19
+
14
20
  #[pyfunction]
15
21
  #[gen_stub_pyfunction]
16
22
  fn get_version() -> String {
17
23
  env!("CARGO_PKG_VERSION").to_string()
18
24
  }
19
25
 
20
- #[pyclass(subclass)]
26
+ #[pyclass]
21
27
  #[gen_stub_pyclass]
22
- pub struct ModelProviderABC;
28
+ #[derive(Debug, Clone, PartialEq, Eq, Hash)]
29
+ struct PyInputType {
30
+ pub input_type: InputType,
31
+ }
32
+
33
+ impl From<InputType> for PyInputType {
34
+ fn from(input_type: InputType) -> Self {
35
+ Self { input_type }
36
+ }
37
+ }
38
+
39
+ impl From<PyInputType> for InputType {
40
+ fn from(input_type: PyInputType) -> Self {
41
+ input_type.input_type
42
+ }
43
+ }
23
44
 
24
45
  #[gen_stub_pymethods]
25
46
  #[pymethods]
26
- impl ModelProviderABC {
47
+ impl PyInputType {
27
48
  #[new]
28
- fn new() -> Self {
29
- ModelProviderABC
49
+ fn __new__(input_type: &str) -> PyResult<Self> {
50
+ let input_type = InputType::from_name(input_type).map_or_else(
51
+ |_| {
52
+ Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
53
+ "Invalid input type: {} (must be one of {})",
54
+ input_type,
55
+ InputType::get_names().join(", "),
56
+ )))
57
+ },
58
+ Ok,
59
+ )?;
60
+ Ok(Self { input_type })
30
61
  }
31
62
 
32
- fn get_joint_angles<'py>(
33
- &self,
34
- joint_names: Vec<String>,
35
- ) -> PyResult<Bound<'py, PyArray1<f32>>> {
36
- let n = joint_names.len();
37
- Err(PyNotImplementedError::new_err(format!(
38
- "Must override get_joint_angles with {} joint names",
39
- n
40
- )))
63
+ fn get_name(&self) -> String {
64
+ self.input_type.get_name().to_string()
41
65
  }
42
66
 
43
- fn get_joint_angular_velocities<'py>(
44
- &self,
67
+ fn get_shape(&self, metadata: PyModelMetadata) -> Vec<usize> {
68
+ self.input_type.get_shape(&metadata.into())
69
+ }
70
+
71
+ fn __repr__(&self) -> String {
72
+ format!("InputType({})", self.get_name())
73
+ }
74
+
75
+ fn __eq__(&self, other: Bound<'_, PyAny>) -> PyResult<bool> {
76
+ if let Ok(other) = other.extract::<PyInputType>() {
77
+ Ok(self == &other)
78
+ } else {
79
+ Ok(false)
80
+ }
81
+ }
82
+ }
83
+
84
+ #[pyclass]
85
+ #[gen_stub_pyclass]
86
+ #[derive(Debug, Clone, PartialEq, Eq, Hash)]
87
+ struct PyModelMetadata {
88
+ #[pyo3(get, set)]
89
+ pub joint_names: Vec<String>,
90
+ #[pyo3(get, set)]
91
+ pub num_commands: Option<usize>,
92
+ #[pyo3(get, set)]
93
+ pub carry_size: Vec<usize>,
94
+ }
95
+
96
+ #[pymethods]
97
+ #[gen_stub_pymethods]
98
+ impl PyModelMetadata {
99
+ #[new]
100
+ fn __new__(
45
101
  joint_names: Vec<String>,
46
- ) -> PyResult<Bound<'py, PyArray1<f32>>> {
47
- let n = joint_names.len();
48
- Err(PyNotImplementedError::new_err(format!(
49
- "Must override get_joint_angular_velocities with {} joint names",
50
- n
51
- )))
102
+ num_commands: Option<usize>,
103
+ carry_size: Vec<usize>,
104
+ ) -> Self {
105
+ Self {
106
+ joint_names,
107
+ num_commands,
108
+ carry_size,
109
+ }
110
+ }
111
+
112
+ fn to_json(&self) -> PyResult<String> {
113
+ let metadata = ModelMetadata {
114
+ joint_names: self.joint_names.clone(),
115
+ num_commands: self.num_commands,
116
+ carry_size: self.carry_size.clone(),
117
+ }
118
+ .to_json()
119
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
120
+ Ok(metadata)
121
+ }
122
+
123
+ fn __repr__(&self) -> PyResult<String> {
124
+ let json = self.to_json()?;
125
+ Ok(format!("ModelMetadata({:?})", json))
126
+ }
127
+
128
+ fn __eq__(&self, other: Bound<'_, PyAny>) -> PyResult<bool> {
129
+ if let Ok(other) = other.extract::<PyModelMetadata>() {
130
+ Ok(self == &other)
131
+ } else {
132
+ Ok(false)
133
+ }
134
+ }
135
+ }
136
+
137
+ #[pyfunction]
138
+ #[gen_stub_pyfunction]
139
+ fn metadata_from_json(json: &str) -> PyResult<PyModelMetadata> {
140
+ let metadata = ModelMetadata::model_validate_json(json.to_string()).map_err(|e| {
141
+ PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Invalid model metadata: {}", e))
142
+ })?;
143
+ Ok(PyModelMetadata::from(metadata))
144
+ }
145
+
146
+ impl From<ModelMetadata> for PyModelMetadata {
147
+ fn from(metadata: ModelMetadata) -> Self {
148
+ Self {
149
+ joint_names: metadata.joint_names,
150
+ num_commands: metadata.num_commands,
151
+ carry_size: metadata.carry_size,
152
+ }
52
153
  }
154
+ }
53
155
 
54
- fn get_projected_gravity<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
55
- Err(PyNotImplementedError::new_err(
56
- "Must override get_projected_gravity",
57
- ))
156
+ impl From<&ModelMetadata> for PyModelMetadata {
157
+ fn from(metadata: &ModelMetadata) -> Self {
158
+ Self {
159
+ joint_names: metadata.joint_names.clone(),
160
+ num_commands: metadata.num_commands,
161
+ carry_size: metadata.carry_size.clone(),
162
+ }
58
163
  }
164
+ }
59
165
 
60
- fn get_accelerometer<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
61
- Err(PyNotImplementedError::new_err(
62
- "Must override get_accelerometer",
63
- ))
166
+ impl From<PyModelMetadata> for ModelMetadata {
167
+ fn from(metadata: PyModelMetadata) -> Self {
168
+ Self {
169
+ joint_names: metadata.joint_names,
170
+ num_commands: metadata.num_commands,
171
+ carry_size: metadata.carry_size,
172
+ }
64
173
  }
174
+ }
175
+
176
/// Base class for Python-side model providers.
///
/// Declared with `subclass` so Python code can inherit from it. Its
/// `get_inputs` and `take_action` methods raise `NotImplementedError` and are
/// meant to be overridden by the subclass.
#[pyclass(subclass)]
#[gen_stub_pyclass]
struct ModelProviderABC;
65
179
 
66
- fn get_gyroscope<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
67
- Err(PyNotImplementedError::new_err(
68
- "Must override get_gyroscope",
69
- ))
180
+ #[gen_stub_pymethods]
181
+ #[pymethods]
182
+ impl ModelProviderABC {
183
+ #[new]
184
+ fn __new__() -> Self {
185
+ ModelProviderABC
70
186
  }
71
187
 
72
- fn get_command<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
73
- Err(PyNotImplementedError::new_err("Must override get_command"))
188
+ fn get_inputs<'py>(
189
+ &self,
190
+ input_types: Vec<String>,
191
+ metadata: PyModelMetadata,
192
+ ) -> PyResult<HashMap<String, Bound<'py, PyArrayDyn<f32>>>> {
193
+ Err(PyNotImplementedError::new_err(format!(
194
+ "Must override get_inputs with {} input types {:?} and metadata {:?}",
195
+ input_types.len(),
196
+ input_types,
197
+ metadata
198
+ )))
74
199
  }
75
200
 
76
- fn take_action<'py>(
201
+ fn take_action(
77
202
  &self,
78
- joint_names: Vec<String>,
79
- action: Bound<'py, PyArray1<f32>>,
203
+ action: Bound<'_, PyArray1<f32>>,
204
+ metadata: PyModelMetadata,
80
205
  ) -> PyResult<()> {
81
206
  let n = action.len()?;
82
- assert_eq!(joint_names.len(), n);
207
+ assert_eq!(metadata.joint_names.len(), n); // TODO: this is wrong
83
208
  Err(PyNotImplementedError::new_err(format!(
84
209
  "Must override take_action with {} action",
85
210
  n
@@ -97,106 +222,57 @@ struct PyModelProvider {
97
222
  #[pymethods]
98
223
  impl PyModelProvider {
99
224
  #[new]
100
- fn new(obj: Py<ModelProviderABC>) -> Self {
225
+ fn __new__(obj: Py<ModelProviderABC>) -> Self {
101
226
  Self { obj: Arc::new(obj) }
102
227
  }
103
228
  }
104
229
 
105
230
  #[async_trait]
106
231
  impl ModelProvider for PyModelProvider {
107
- async fn get_joint_angles(
108
- &self,
109
- joint_names: &[String],
110
- ) -> Result<Array<f32, IxDyn>, ModelError> {
111
- let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
112
- let obj = self.obj.clone();
113
- let args = (joint_names,);
114
- let result = obj.call_method(py, "get_joint_angles", args, None)?;
115
- let array = result.extract::<Vec<f32>>(py)?;
116
- Ok(Array::from_vec(array).into_dyn())
117
- })
118
- .map_err(|e| ModelError::Provider(e.to_string()))?;
119
- Ok(args)
120
- }
121
-
122
- async fn get_joint_angular_velocities(
232
+ async fn get_inputs(
123
233
  &self,
124
- joint_names: &[String],
125
- ) -> Result<Array<f32, IxDyn>, ModelError> {
126
- let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
127
- let obj = self.obj.clone();
128
- let args = (joint_names,);
129
- let result = obj.call_method(py, "get_joint_angular_velocities", args, None)?;
130
- let array = result.extract::<Vec<f32>>(py)?;
131
- Ok(Array::from_vec(array).into_dyn())
132
- })
133
- .map_err(|e| ModelError::Provider(e.to_string()))?;
134
- Ok(args)
135
- }
136
-
137
- async fn get_projected_gravity(&self) -> Result<Array<f32, IxDyn>, ModelError> {
138
- let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
139
- let obj = self.obj.clone();
140
- let args = ();
141
- let result = obj.call_method(py, "get_projected_gravity", args, None)?;
142
- let array = result.extract::<Vec<f32>>(py)?;
143
- Ok(Array::from_vec(array).into_dyn())
144
- })
145
- .map_err(|e| ModelError::Provider(e.to_string()))?;
146
- Ok(args)
147
- }
148
-
149
- async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError> {
150
- let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
151
- let obj = self.obj.clone();
152
- let args = ();
153
- let result = obj.call_method(py, "get_accelerometer", args, None)?;
154
- let array = result.extract::<Vec<f32>>(py)?;
155
- Ok(Array::from_vec(array).into_dyn())
156
- })
157
- .map_err(|e| ModelError::Provider(e.to_string()))?;
158
- Ok(args)
159
- }
160
-
161
- async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError> {
162
- let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
234
+ input_types: &[InputType],
235
+ metadata: &ModelMetadata,
236
+ ) -> Result<HashMap<InputType, Array<f32, IxDyn>>, ModelError> {
237
+ let input_names: Vec<String> = input_types
238
+ .iter()
239
+ .map(|t| t.get_name().to_string())
240
+ .collect();
241
+ let result = Python::with_gil(|py| -> PyResult<HashMap<InputType, Array<f32, IxDyn>>> {
163
242
  let obj = self.obj.clone();
164
- let args = ();
165
- let result = obj.call_method(py, "get_gyroscope", args, None)?;
166
- let array = result.extract::<Vec<f32>>(py)?;
167
- Ok(Array::from_vec(array).into_dyn())
243
+ let args = (input_names.clone(), PyModelMetadata::from(metadata.clone()));
244
+ let result = obj.call_method(py, "get_inputs", args, None)?;
245
+ let dict: HashMap<String, Vec<f32>> = result.extract(py)?;
246
+ let mut arrays = HashMap::new();
247
+ for (i, name) in input_names.iter().enumerate() {
248
+ let array = dict.get(name).ok_or_else(|| {
249
+ PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
250
+ "Missing input: {}",
251
+ name
252
+ ))
253
+ })?;
254
+ arrays.insert(input_types[i], Array::from_vec(array.clone()).into_dyn());
255
+ }
256
+ Ok(arrays)
168
257
  })
169
258
  .map_err(|e| ModelError::Provider(e.to_string()))?;
170
- Ok(args)
171
- }
172
-
173
- async fn get_command(&self) -> Result<Array<f32, IxDyn>, ModelError> {
174
- let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
175
- let obj = self.obj.clone();
176
- let args = ();
177
- let result = obj.call_method(py, "get_command", args, None)?;
178
- let array = result.extract::<Vec<f32>>(py)?;
179
- Ok(Array::from_vec(array).into_dyn())
180
- })
181
- .map_err(|e| ModelError::Provider(e.to_string()))?;
182
- Ok(args)
183
- }
184
-
185
- async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError> {
186
- Ok(carry)
259
+ Ok(result)
187
260
  }
188
261
 
189
262
  async fn take_action(
190
263
  &self,
191
- joint_names: Vec<String>,
192
264
  action: Array<f32, IxDyn>,
265
+ metadata: &ModelMetadata,
193
266
  ) -> Result<(), ModelError> {
194
267
  Python::with_gil(|py| -> PyResult<()> {
195
268
  let obj = self.obj.clone();
196
269
  let action_1d = action
197
270
  .into_dimensionality::<Ix1>()
198
271
  .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
199
- let args = (joint_names, PyArray1::from_array(py, &action_1d));
272
+ let args = (
273
+ PyArray1::from_array(py, &action_1d),
274
+ PyModelMetadata::from(metadata.clone()),
275
+ );
200
276
  obj.call_method(py, "take_action", args, None)?;
201
277
  Ok(())
202
278
  })
@@ -216,12 +292,10 @@ struct PyModelRunner {
216
292
  #[pymethods]
217
293
  impl PyModelRunner {
218
294
  #[new]
219
- fn new(model_path: String, provider: Py<ModelProviderABC>) -> PyResult<Self> {
220
- let input_provider = Arc::new(PyModelProvider {
221
- obj: Arc::new(provider),
222
- });
295
+ fn __new__(model_path: String, provider: Py<ModelProviderABC>) -> PyResult<Self> {
296
+ let input_provider = Arc::new(PyModelProvider::__new__(provider));
223
297
 
224
- let runner = tokio::runtime::Runtime::new().unwrap().block_on(async {
298
+ let runner = tokio::runtime::Runtime::new()?.block_on(async {
225
299
  ModelRunner::new(model_path, input_provider)
226
300
  .await
227
301
  .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
@@ -234,7 +308,7 @@ impl PyModelRunner {
234
308
 
235
309
  fn init(&self) -> PyResult<Py<PyArrayDyn<f32>>> {
236
310
  let runner = self.runner.clone();
237
- let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
311
+ let result = tokio::runtime::Runtime::new()?.block_on(async {
238
312
  runner
239
313
  .init()
240
314
  .await
@@ -247,17 +321,14 @@ impl PyModelRunner {
247
321
  })
248
322
  }
249
323
 
250
- fn step(
251
- &self,
252
- carry: Py<PyArrayDyn<f32>>,
253
- ) -> PyResult<(Py<PyArrayDyn<f32>>, Py<PyArrayDyn<f32>>)> {
324
+ fn step(&self, carry: Py<PyArrayDyn<f32>>) -> PyResult<StepResult> {
254
325
  let runner = self.runner.clone();
255
326
  let carry_array = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
256
327
  let carry_array = carry.bind(py);
257
328
  Ok(carry_array.to_owned_array())
258
329
  })?;
259
330
 
260
- let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
331
+ let result = tokio::runtime::Runtime::new()?.block_on(async {
261
332
  runner
262
333
  .step(carry_array)
263
334
  .await
@@ -279,7 +350,7 @@ impl PyModelRunner {
279
350
  Ok(action_array.to_owned_array())
280
351
  })?;
281
352
 
282
- tokio::runtime::Runtime::new().unwrap().block_on(async {
353
+ tokio::runtime::Runtime::new()?.block_on(async {
283
354
  runner
284
355
  .take_action(action_array)
285
356
  .await
@@ -301,38 +372,56 @@ struct PyModelRuntime {
301
372
  #[pymethods]
302
373
  impl PyModelRuntime {
303
374
  #[new]
304
- fn new(model_runner: PyModelRunner, dt: u64) -> PyResult<Self> {
375
+ fn __new__(model_runner: PyModelRunner, dt: u64) -> PyResult<Self> {
305
376
  Ok(Self {
306
377
  runtime: Arc::new(Mutex::new(ModelRuntime::new(model_runner.runner, dt))),
307
378
  })
308
379
  }
309
380
 
310
- fn set_slowdown_factor(&self, slowdown_factor: i32) {
311
- let mut runtime = self.runtime.lock().unwrap();
381
+ fn set_slowdown_factor(&self, slowdown_factor: i32) -> PyResult<()> {
382
+ let mut runtime = self
383
+ .runtime
384
+ .lock()
385
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
312
386
  runtime.set_slowdown_factor(slowdown_factor);
387
+ Ok(())
313
388
  }
314
389
 
315
- fn set_magnitude_factor(&self, magnitude_factor: f32) {
316
- let mut runtime = self.runtime.lock().unwrap();
390
+ fn set_magnitude_factor(&self, magnitude_factor: f32) -> PyResult<()> {
391
+ let mut runtime = self
392
+ .runtime
393
+ .lock()
394
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
317
395
  runtime.set_magnitude_factor(magnitude_factor);
396
+ Ok(())
318
397
  }
319
398
 
320
399
  fn start(&self) -> PyResult<()> {
321
- let mut runtime = self.runtime.lock().unwrap();
400
+ let mut runtime = self
401
+ .runtime
402
+ .lock()
403
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
322
404
  runtime
323
405
  .start()
324
406
  .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
325
407
  }
326
408
 
327
- fn stop(&self) {
328
- let mut runtime = self.runtime.lock().unwrap();
409
+ fn stop(&self) -> PyResult<()> {
410
+ let mut runtime = self
411
+ .runtime
412
+ .lock()
413
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
329
414
  runtime.stop();
415
+ Ok(())
330
416
  }
331
417
  }
332
418
 
333
419
  #[pymodule]
334
420
  fn rust_bindings(m: &Bound<PyModule>) -> PyResult<()> {
335
421
  m.add_function(wrap_pyfunction!(get_version, m)?)?;
422
+ m.add_class::<PyInputType>()?;
423
+ m.add_class::<PyModelMetadata>()?;
424
+ m.add_function(wrap_pyfunction!(metadata_from_json, m)?)?;
336
425
  m.add_class::<ModelProviderABC>()?;
337
426
  m.add_class::<PyModelRunner>()?;
338
427
  m.add_class::<PyModelRuntime>()?;
Binary file