kinfer 0.4.3__cp311-cp311-macosx_11_0_arm64.whl → 0.5.2__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,86 @@
1
+ use serde::Deserialize;
2
+ use serde::Serialize;
3
+
4
+ #[derive(Debug, Deserialize, Serialize, Clone)]
5
+ pub struct ModelMetadata {
6
+ pub joint_names: Vec<String>,
7
+ pub num_commands: Option<usize>,
8
+ pub carry_size: Vec<usize>,
9
+ }
10
+
11
+ impl ModelMetadata {
12
+ pub fn model_validate_json(json: String) -> Result<Self, Box<dyn std::error::Error>> {
13
+ Ok(serde_json::from_str(&json)?)
14
+ }
15
+
16
+ pub fn to_json(&self) -> Result<String, Box<dyn std::error::Error>> {
17
+ Ok(serde_json::to_string(self)?)
18
+ }
19
+ }
20
+
21
+ #[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
22
+ pub enum InputType {
23
+ JointAngles,
24
+ JointAngularVelocities,
25
+ ProjectedGravity,
26
+ Accelerometer,
27
+ Gyroscope,
28
+ Command,
29
+ Time,
30
+ Carry,
31
+ }
32
+
33
+ impl InputType {
34
+ pub fn get_name(&self) -> &str {
35
+ match self {
36
+ InputType::JointAngles => "joint_angles",
37
+ InputType::JointAngularVelocities => "joint_angular_velocities",
38
+ InputType::ProjectedGravity => "projected_gravity",
39
+ InputType::Accelerometer => "accelerometer",
40
+ InputType::Gyroscope => "gyroscope",
41
+ InputType::Command => "command",
42
+ InputType::Time => "time",
43
+ InputType::Carry => "carry",
44
+ }
45
+ }
46
+
47
+ pub fn get_shape(&self, metadata: &ModelMetadata) -> Vec<usize> {
48
+ match self {
49
+ InputType::JointAngles => vec![metadata.joint_names.len()],
50
+ InputType::JointAngularVelocities => vec![metadata.joint_names.len()],
51
+ InputType::ProjectedGravity => vec![3],
52
+ InputType::Accelerometer => vec![3],
53
+ InputType::Gyroscope => vec![3],
54
+ InputType::Command => vec![metadata.num_commands.unwrap_or(0)],
55
+ InputType::Time => vec![1],
56
+ InputType::Carry => metadata.carry_size.clone(),
57
+ }
58
+ }
59
+
60
+ pub fn from_name(name: &str) -> Result<Self, Box<dyn std::error::Error>> {
61
+ match name {
62
+ "joint_angles" => Ok(InputType::JointAngles),
63
+ "joint_angular_velocities" => Ok(InputType::JointAngularVelocities),
64
+ "projected_gravity" => Ok(InputType::ProjectedGravity),
65
+ "accelerometer" => Ok(InputType::Accelerometer),
66
+ "gyroscope" => Ok(InputType::Gyroscope),
67
+ "command" => Ok(InputType::Command),
68
+ "time" => Ok(InputType::Time),
69
+ "carry" => Ok(InputType::Carry),
70
+ _ => Err(format!("Unknown input type: {}", name).into()),
71
+ }
72
+ }
73
+
74
+ pub fn get_names() -> Vec<&'static str> {
75
+ vec![
76
+ "joint_angles",
77
+ "joint_angular_velocities",
78
+ "projected_gravity",
79
+ "accelerometer",
80
+ "gyroscope",
81
+ "command",
82
+ "time",
83
+ "carry",
84
+ ]
85
+ }
86
+ }
@@ -8,14 +8,21 @@ import typing
8
8
 
9
9
  class ModelProviderABC:  # Base class for Python providers; the Rust defaults of get_inputs/take_action raise NotImplementedError.
10
10
  def __new__(cls) -> ModelProviderABC: ...
11
- def get_joint_angles(self, joint_names:typing.Sequence[builtins.str]) -> numpy.typing.NDArray[numpy.float32]: ...
12
- def get_joint_angular_velocities(self, joint_names:typing.Sequence[builtins.str]) -> numpy.typing.NDArray[numpy.float32]: ...
13
- def get_projected_gravity(self) -> numpy.typing.NDArray[numpy.float32]: ...
14
- def get_accelerometer(self) -> numpy.typing.NDArray[numpy.float32]: ...
15
- def get_gyroscope(self) -> numpy.typing.NDArray[numpy.float32]: ...
16
- def get_command(self) -> numpy.typing.NDArray[numpy.float32]: ...
17
- def get_time(self) -> numpy.typing.NDArray[numpy.float32]: ...
18
- def take_action(self, joint_names:typing.Sequence[builtins.str], action:numpy.typing.NDArray[numpy.float32]) -> None: ...
11
+ def get_inputs(self, input_types:typing.Sequence[builtins.str], metadata:PyModelMetadata) -> builtins.dict[builtins.str, numpy.typing.NDArray[numpy.float32]]: ...  # Must return one float32 array per requested input name; the Rust caller reports a missing name as a KeyError.
12
+ def take_action(self, action:numpy.typing.NDArray[numpy.float32], metadata:PyModelMetadata) -> None: ...  # Receives the action array together with the model metadata.
13
+
14
+ class PyInputType:  # Python wrapper around the Rust InputType enum.
15
+ def __new__(cls, input_type:builtins.str) -> PyInputType: ...  # Builds from an input-type name; the Rust side raises ValueError for an unknown name.
16
+ def get_name(self) -> builtins.str: ...  # Name of the wrapped input type, e.g. "joint_angles".
17
+ def get_shape(self, metadata:PyModelMetadata) -> builtins.list[builtins.int]: ...  # Shape of this input for the given metadata.
18
+ def __repr__(self) -> builtins.str: ...  # Formatted as "InputType(<name>)".
19
+ def __eq__(self, other:typing.Any) -> builtins.bool: ...  # False when `other` is not a PyInputType.
20
+
21
class PyModelMetadata:
    # Python view of the model metadata (joint names, command count, carry shape).
    # The three attributes below are exposed as readable and writable properties
    # by the Rust binding (`#[pyo3(get, set)]`).
    joint_names: builtins.list[builtins.str]
    num_commands: typing.Optional[builtins.int]
    carry_size: builtins.list[builtins.int]
    # `__new__` receives the class, so its first parameter is named `cls` like the other stubs.
    def __new__(cls, joint_names:typing.Sequence[builtins.str], num_commands:typing.Optional[builtins.int], carry_size:typing.Sequence[builtins.int]) -> PyModelMetadata: ...
    # Serializes the metadata to a JSON string.
    def to_json(self) -> builtins.str: ...
    # Formatted as "ModelMetadata(<quoted json>)".
    def __repr__(self) -> builtins.str: ...
    # False when `other` is not a PyModelMetadata.
    def __eq__(self, other:typing.Any) -> builtins.bool: ...
19
26
 
20
27
  class PyModelProvider:
21
28
  ...
@@ -35,3 +42,5 @@ class PyModelRuntime:
35
42
 
36
43
  def get_version() -> builtins.str: ...  # Version string of the compiled crate (CARGO_PKG_VERSION at build time).
37
44
 
45
+ def metadata_from_json(json:builtins.str) -> PyModelMetadata: ...  # Parses metadata from a JSON string; the Rust side raises ValueError on invalid input.
46
+
@@ -1,89 +1,210 @@
1
1
  use async_trait::async_trait;
2
2
  use kinfer::model::{ModelError, ModelProvider, ModelRunner};
3
3
  use kinfer::runtime::ModelRuntime;
4
+ use kinfer::types::{InputType, ModelMetadata};
4
5
  use ndarray::{Array, Ix1, IxDyn};
5
6
  use numpy::{PyArray1, PyArrayDyn, PyArrayMethods};
6
7
  use pyo3::exceptions::PyNotImplementedError;
7
8
  use pyo3::prelude::*;
9
+ use pyo3::types::{PyAny, PyAnyMethods};
8
10
  use pyo3::{pymodule, types::PyModule, Bound, PyResult, Python};
9
11
  use pyo3_stub_gen::define_stub_info_gatherer;
10
12
  use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
13
+ use std::collections::HashMap;
14
+ use std::hash::Hash;
11
15
  use std::sync::Arc;
12
16
  use std::sync::Mutex;
13
- use std::time::Instant;
17
+
18
+ type StepResult = (Py<PyArrayDyn<f32>>, Py<PyArrayDyn<f32>>); // Pair of dynamic-rank f32 NumPy arrays returned by PyModelRunner::step.
19
+
14
20
  #[pyfunction]
15
21
  #[gen_stub_pyfunction]
16
22
  fn get_version() -> String {
17
23
  env!("CARGO_PKG_VERSION").to_string()
18
24
  }
19
25
 
20
- #[pyclass(subclass)]
26
+ #[pyclass]
21
27
  #[gen_stub_pyclass]
22
- pub struct ModelProviderABC;
28
+ #[derive(Debug, Clone, PartialEq, Eq, Hash)]
29
+ struct PyInputType { // Python-visible wrapper around a single kinfer InputType value.
30
+ pub input_type: InputType, // The wrapped core enum value.
31
+ }
32
+
33
+ impl From<InputType> for PyInputType {
34
+ fn from(input_type: InputType) -> Self {
35
+ Self { input_type }
36
+ }
37
+ }
38
+
39
+ impl From<PyInputType> for InputType {
40
+ fn from(input_type: PyInputType) -> Self {
41
+ input_type.input_type
42
+ }
43
+ }
23
44
 
24
45
  #[gen_stub_pymethods]
25
46
  #[pymethods]
26
- impl ModelProviderABC {
47
+ impl PyInputType {
27
48
  #[new]
28
- fn new() -> Self {
29
- ModelProviderABC
49
+ fn __new__(input_type: &str) -> PyResult<Self> {
50
+ let input_type = InputType::from_name(input_type).map_or_else(
51
+ |_| {
52
+ Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
53
+ "Invalid input type: {} (must be one of {})",
54
+ input_type,
55
+ InputType::get_names().join(", "),
56
+ )))
57
+ },
58
+ Ok,
59
+ )?;
60
+ Ok(Self { input_type })
30
61
  }
31
62
 
32
- fn get_joint_angles<'py>(
33
- &self,
34
- joint_names: Vec<String>,
35
- ) -> PyResult<Bound<'py, PyArray1<f32>>> {
36
- let n = joint_names.len();
37
- Err(PyNotImplementedError::new_err(format!(
38
- "Must override get_joint_angles with {} joint names",
39
- n
40
- )))
63
+ fn get_name(&self) -> String {
64
+ self.input_type.get_name().to_string()
41
65
  }
42
66
 
43
- fn get_joint_angular_velocities<'py>(
44
- &self,
67
+ fn get_shape(&self, metadata: PyModelMetadata) -> Vec<usize> {
68
+ self.input_type.get_shape(&metadata.into())
69
+ }
70
+
71
+ fn __repr__(&self) -> String {
72
+ format!("InputType({})", self.get_name())
73
+ }
74
+
75
+ fn __eq__(&self, other: Bound<'_, PyAny>) -> PyResult<bool> {
76
+ if let Ok(other) = other.extract::<PyInputType>() {
77
+ Ok(self == &other)
78
+ } else {
79
+ Ok(false)
80
+ }
81
+ }
82
+ }
83
+
84
+ #[pyclass]
85
+ #[gen_stub_pyclass]
86
+ #[derive(Debug, Clone, PartialEq, Eq, Hash)]
87
+ struct PyModelMetadata { // Python-visible mirror of kinfer's ModelMetadata; converted in both directions by the From impls in this file.
88
+ #[pyo3(get, set)]
89
+ pub joint_names: Vec<String>, // Readable and writable from Python.
90
+ #[pyo3(get, set)]
91
+ pub num_commands: Option<usize>, // Optional; None is accepted.
92
+ #[pyo3(get, set)]
93
+ pub carry_size: Vec<usize>, // Used verbatim as the shape of the carry input.
94
+ }
95
+
96
+ #[pymethods]
97
+ #[gen_stub_pymethods]
98
+ impl PyModelMetadata {
99
+ #[new]
100
+ fn __new__(
45
101
  joint_names: Vec<String>,
46
- ) -> PyResult<Bound<'py, PyArray1<f32>>> {
47
- let n = joint_names.len();
48
- Err(PyNotImplementedError::new_err(format!(
49
- "Must override get_joint_angular_velocities with {} joint names",
50
- n
51
- )))
102
+ num_commands: Option<usize>,
103
+ carry_size: Vec<usize>,
104
+ ) -> Self {
105
+ Self {
106
+ joint_names,
107
+ num_commands,
108
+ carry_size,
109
+ }
52
110
  }
53
111
 
54
- fn get_projected_gravity<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
55
- Err(PyNotImplementedError::new_err(
56
- "Must override get_projected_gravity",
57
- ))
112
+ fn to_json(&self) -> PyResult<String> {
113
+ let metadata = ModelMetadata {
114
+ joint_names: self.joint_names.clone(),
115
+ num_commands: self.num_commands,
116
+ carry_size: self.carry_size.clone(),
117
+ }
118
+ .to_json()
119
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
120
+ Ok(metadata)
58
121
  }
59
122
 
60
- fn get_accelerometer<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
61
- Err(PyNotImplementedError::new_err(
62
- "Must override get_accelerometer",
63
- ))
123
+ fn __repr__(&self) -> PyResult<String> {
124
+ let json = self.to_json()?;
125
+ Ok(format!("ModelMetadata({:?})", json))
126
+ }
127
+
128
+ fn __eq__(&self, other: Bound<'_, PyAny>) -> PyResult<bool> {
129
+ if let Ok(other) = other.extract::<PyModelMetadata>() {
130
+ Ok(self == &other)
131
+ } else {
132
+ Ok(false)
133
+ }
64
134
  }
135
+ }
65
136
 
66
- fn get_gyroscope<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
67
- Err(PyNotImplementedError::new_err(
68
- "Must override get_gyroscope",
69
- ))
137
+ #[pyfunction]
138
+ #[gen_stub_pyfunction]
139
+ fn metadata_from_json(json: &str) -> PyResult<PyModelMetadata> {
140
+ let metadata = ModelMetadata::model_validate_json(json.to_string()).map_err(|e| {
141
+ PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Invalid model metadata: {}", e))
142
+ })?;
143
+ Ok(PyModelMetadata::from(metadata))
144
+ }
145
+
146
+ impl From<ModelMetadata> for PyModelMetadata {
147
+ fn from(metadata: ModelMetadata) -> Self {
148
+ Self {
149
+ joint_names: metadata.joint_names,
150
+ num_commands: metadata.num_commands,
151
+ carry_size: metadata.carry_size,
152
+ }
70
153
  }
154
+ }
71
155
 
72
- fn get_command<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
73
- Err(PyNotImplementedError::new_err("Must override get_command"))
156
+ impl From<&ModelMetadata> for PyModelMetadata {
157
+ fn from(metadata: &ModelMetadata) -> Self {
158
+ Self {
159
+ joint_names: metadata.joint_names.clone(),
160
+ num_commands: metadata.num_commands,
161
+ carry_size: metadata.carry_size.clone(),
162
+ }
74
163
  }
164
+ }
75
165
 
76
- fn get_time<'py>(&self) -> PyResult<Bound<'py, PyArray1<f32>>> {
77
- Err(PyNotImplementedError::new_err("Must override get_time"))
166
+ impl From<PyModelMetadata> for ModelMetadata {
167
+ fn from(metadata: PyModelMetadata) -> Self {
168
+ Self {
169
+ joint_names: metadata.joint_names,
170
+ num_commands: metadata.num_commands,
171
+ carry_size: metadata.carry_size,
172
+ }
78
173
  }
174
+ }
79
175
 
80
- fn take_action<'py>(
176
+ #[pyclass(subclass)]
177
+ #[gen_stub_pyclass]
178
+ struct ModelProviderABC; // Field-less base class; `subclass` lets Python code derive from it and override its methods.
179
+
180
+ #[gen_stub_pymethods]
181
+ #[pymethods]
182
+ impl ModelProviderABC {
183
+ #[new]
184
+ fn __new__() -> Self {
185
+ ModelProviderABC
186
+ }
187
+
188
+ fn get_inputs<'py>(
81
189
  &self,
82
- joint_names: Vec<String>,
83
- action: Bound<'py, PyArray1<f32>>,
190
+ input_types: Vec<String>,
191
+ metadata: PyModelMetadata,
192
+ ) -> PyResult<HashMap<String, Bound<'py, PyArrayDyn<f32>>>> {
193
+ Err(PyNotImplementedError::new_err(format!(
194
+ "Must override get_inputs with {} input types {:?} and metadata {:?}",
195
+ input_types.len(),
196
+ input_types,
197
+ metadata
198
+ )))
199
+ }
200
+
201
+ fn take_action(
202
+ &self,
203
+ action: Bound<'_, PyArray1<f32>>,
204
+ metadata: PyModelMetadata,
84
205
  ) -> PyResult<()> {
85
206
  let n = action.len()?;
86
- assert_eq!(joint_names.len(), n);
207
+ assert_eq!(metadata.joint_names.len(), n); // TODO: this is wrong
87
208
  Err(PyNotImplementedError::new_err(format!(
88
209
  "Must override take_action with {} action",
89
210
  n
@@ -96,127 +217,62 @@ impl ModelProviderABC {
96
217
  #[derive(Clone)]
97
218
  struct PyModelProvider { // Adapter that forwards ModelProvider calls to a Python object.
98
219
  obj: Arc<Py<ModelProviderABC>>, // Shared handle to the Python provider instance.
99
- start_time: Instant,
100
220
  }
101
221
 
102
222
  #[pymethods]
103
223
  impl PyModelProvider {
104
224
  #[new]
105
- fn new(obj: Py<ModelProviderABC>) -> Self {
106
- Self {
107
- obj: Arc::new(obj),
108
- start_time: Instant::now(),
109
- }
225
+ fn __new__(obj: Py<ModelProviderABC>) -> Self {
226
+ Self { obj: Arc::new(obj) }
110
227
  }
111
228
  }
112
229
 
113
230
  #[async_trait]
114
231
  impl ModelProvider for PyModelProvider {
115
- async fn get_joint_angles(
116
- &self,
117
- joint_names: &[String],
118
- ) -> Result<Array<f32, IxDyn>, ModelError> {
119
- let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
120
- let obj = self.obj.clone();
121
- let args = (joint_names,);
122
- let result = obj.call_method(py, "get_joint_angles", args, None)?;
123
- let array = result.extract::<Vec<f32>>(py)?;
124
- Ok(Array::from_vec(array).into_dyn())
125
- })
126
- .map_err(|e| ModelError::Provider(e.to_string()))?;
127
- Ok(args)
128
- }
129
-
130
- async fn get_joint_angular_velocities(
232
+ async fn get_inputs(
131
233
  &self,
132
- joint_names: &[String],
133
- ) -> Result<Array<f32, IxDyn>, ModelError> {
134
- let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
135
- let obj = self.obj.clone();
136
- let args = (joint_names,);
137
- let result = obj.call_method(py, "get_joint_angular_velocities", args, None)?;
138
- let array = result.extract::<Vec<f32>>(py)?;
139
- Ok(Array::from_vec(array).into_dyn())
140
- })
141
- .map_err(|e| ModelError::Provider(e.to_string()))?;
142
- Ok(args)
143
- }
144
-
145
- async fn get_projected_gravity(&self) -> Result<Array<f32, IxDyn>, ModelError> {
146
- let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
147
- let obj = self.obj.clone();
148
- let args = ();
149
- let result = obj.call_method(py, "get_projected_gravity", args, None)?;
150
- let array = result.extract::<Vec<f32>>(py)?;
151
- Ok(Array::from_vec(array).into_dyn())
152
- })
153
- .map_err(|e| ModelError::Provider(e.to_string()))?;
154
- Ok(args)
155
- }
156
-
157
- async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError> {
158
- let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
159
- let obj = self.obj.clone();
160
- let args = ();
161
- let result = obj.call_method(py, "get_accelerometer", args, None)?;
162
- let array = result.extract::<Vec<f32>>(py)?;
163
- Ok(Array::from_vec(array).into_dyn())
164
- })
165
- .map_err(|e| ModelError::Provider(e.to_string()))?;
166
- Ok(args)
167
- }
168
-
169
- async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError> {
170
- let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
171
- let obj = self.obj.clone();
172
- let args = ();
173
- let result = obj.call_method(py, "get_gyroscope", args, None)?;
174
- let array = result.extract::<Vec<f32>>(py)?;
175
- Ok(Array::from_vec(array).into_dyn())
176
- })
177
- .map_err(|e| ModelError::Provider(e.to_string()))?;
178
- Ok(args)
179
- }
180
-
181
- async fn get_command(&self) -> Result<Array<f32, IxDyn>, ModelError> {
182
- let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
183
- let obj = self.obj.clone();
184
- let args = ();
185
- let result = obj.call_method(py, "get_command", args, None)?;
186
- let array = result.extract::<Vec<f32>>(py)?;
187
- Ok(Array::from_vec(array).into_dyn())
188
- })
189
- .map_err(|e| ModelError::Provider(e.to_string()))?;
190
- Ok(args)
191
- }
192
-
193
- async fn get_time(&self) -> Result<Array<f32, IxDyn>, ModelError> {
194
- let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
234
+ input_types: &[InputType],
235
+ metadata: &ModelMetadata,
236
+ ) -> Result<HashMap<InputType, Array<f32, IxDyn>>, ModelError> {
237
+ let input_names: Vec<String> = input_types
238
+ .iter()
239
+ .map(|t| t.get_name().to_string())
240
+ .collect();
241
+ let result = Python::with_gil(|py| -> PyResult<HashMap<InputType, Array<f32, IxDyn>>> {
195
242
  let obj = self.obj.clone();
196
- let args = ();
197
- let result = obj.call_method(py, "get_time", args, None)?;
198
- let array = result.extract::<Vec<f32>>(py)?;
199
- Ok(Array::from_vec(array).into_dyn())
243
+ let args = (input_names.clone(), PyModelMetadata::from(metadata.clone()));
244
+ let result = obj.call_method(py, "get_inputs", args, None)?;
245
+ let dict: HashMap<String, Vec<f32>> = result.extract(py)?;
246
+ let mut arrays = HashMap::new();
247
+ for (i, name) in input_names.iter().enumerate() {
248
+ let array = dict.get(name).ok_or_else(|| {
249
+ PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
250
+ "Missing input: {}",
251
+ name
252
+ ))
253
+ })?;
254
+ arrays.insert(input_types[i], Array::from_vec(array.clone()).into_dyn());
255
+ }
256
+ Ok(arrays)
200
257
  })
201
258
  .map_err(|e| ModelError::Provider(e.to_string()))?;
202
- Ok(args)
203
- }
204
-
205
- async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError> {
206
- Ok(carry)
259
+ Ok(result)
207
260
  }
208
261
 
209
262
  async fn take_action(
210
263
  &self,
211
- joint_names: Vec<String>,
212
264
  action: Array<f32, IxDyn>,
265
+ metadata: &ModelMetadata,
213
266
  ) -> Result<(), ModelError> {
214
267
  Python::with_gil(|py| -> PyResult<()> {
215
268
  let obj = self.obj.clone();
216
269
  let action_1d = action
217
270
  .into_dimensionality::<Ix1>()
218
271
  .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
219
- let args = (joint_names, PyArray1::from_array(py, &action_1d));
272
+ let args = (
273
+ PyArray1::from_array(py, &action_1d),
274
+ PyModelMetadata::from(metadata.clone()),
275
+ );
220
276
  obj.call_method(py, "take_action", args, None)?;
221
277
  Ok(())
222
278
  })
@@ -236,10 +292,10 @@ struct PyModelRunner {
236
292
  #[pymethods]
237
293
  impl PyModelRunner {
238
294
  #[new]
239
- fn new(model_path: String, provider: Py<ModelProviderABC>) -> PyResult<Self> {
240
- let input_provider = Arc::new(PyModelProvider::new(provider));
295
+ fn __new__(model_path: String, provider: Py<ModelProviderABC>) -> PyResult<Self> {
296
+ let input_provider = Arc::new(PyModelProvider::__new__(provider));
241
297
 
242
- let runner = tokio::runtime::Runtime::new().unwrap().block_on(async {
298
+ let runner = tokio::runtime::Runtime::new()?.block_on(async {
243
299
  ModelRunner::new(model_path, input_provider)
244
300
  .await
245
301
  .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
@@ -252,7 +308,7 @@ impl PyModelRunner {
252
308
 
253
309
  fn init(&self) -> PyResult<Py<PyArrayDyn<f32>>> {
254
310
  let runner = self.runner.clone();
255
- let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
311
+ let result = tokio::runtime::Runtime::new()?.block_on(async {
256
312
  runner
257
313
  .init()
258
314
  .await
@@ -265,17 +321,14 @@ impl PyModelRunner {
265
321
  })
266
322
  }
267
323
 
268
- fn step(
269
- &self,
270
- carry: Py<PyArrayDyn<f32>>,
271
- ) -> PyResult<(Py<PyArrayDyn<f32>>, Py<PyArrayDyn<f32>>)> {
324
+ fn step(&self, carry: Py<PyArrayDyn<f32>>) -> PyResult<StepResult> {
272
325
  let runner = self.runner.clone();
273
326
  let carry_array = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
274
327
  let carry_array = carry.bind(py);
275
328
  Ok(carry_array.to_owned_array())
276
329
  })?;
277
330
 
278
- let result = tokio::runtime::Runtime::new().unwrap().block_on(async {
331
+ let result = tokio::runtime::Runtime::new()?.block_on(async {
279
332
  runner
280
333
  .step(carry_array)
281
334
  .await
@@ -297,7 +350,7 @@ impl PyModelRunner {
297
350
  Ok(action_array.to_owned_array())
298
351
  })?;
299
352
 
300
- tokio::runtime::Runtime::new().unwrap().block_on(async {
353
+ tokio::runtime::Runtime::new()?.block_on(async {
301
354
  runner
302
355
  .take_action(action_array)
303
356
  .await
@@ -319,38 +372,56 @@ struct PyModelRuntime {
319
372
  #[pymethods]
320
373
  impl PyModelRuntime {
321
374
  #[new]
322
- fn new(model_runner: PyModelRunner, dt: u64) -> PyResult<Self> {
375
+ fn __new__(model_runner: PyModelRunner, dt: u64) -> PyResult<Self> {
323
376
  Ok(Self {
324
377
  runtime: Arc::new(Mutex::new(ModelRuntime::new(model_runner.runner, dt))),
325
378
  })
326
379
  }
327
380
 
328
- fn set_slowdown_factor(&self, slowdown_factor: i32) {
329
- let mut runtime = self.runtime.lock().unwrap();
381
+ fn set_slowdown_factor(&self, slowdown_factor: i32) -> PyResult<()> {
382
+ let mut runtime = self
383
+ .runtime
384
+ .lock()
385
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
330
386
  runtime.set_slowdown_factor(slowdown_factor);
387
+ Ok(())
331
388
  }
332
389
 
333
- fn set_magnitude_factor(&self, magnitude_factor: f32) {
334
- let mut runtime = self.runtime.lock().unwrap();
390
+ fn set_magnitude_factor(&self, magnitude_factor: f32) -> PyResult<()> {
391
+ let mut runtime = self
392
+ .runtime
393
+ .lock()
394
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
335
395
  runtime.set_magnitude_factor(magnitude_factor);
396
+ Ok(())
336
397
  }
337
398
 
338
399
  fn start(&self) -> PyResult<()> {
339
- let mut runtime = self.runtime.lock().unwrap();
400
+ let mut runtime = self
401
+ .runtime
402
+ .lock()
403
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
340
404
  runtime
341
405
  .start()
342
406
  .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
343
407
  }
344
408
 
345
- fn stop(&self) {
346
- let mut runtime = self.runtime.lock().unwrap();
409
+ fn stop(&self) -> PyResult<()> {
410
+ let mut runtime = self
411
+ .runtime
412
+ .lock()
413
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
347
414
  runtime.stop();
415
+ Ok(())
348
416
  }
349
417
  }
350
418
 
351
419
  #[pymodule]
352
420
  fn rust_bindings(m: &Bound<PyModule>) -> PyResult<()> {
353
421
  m.add_function(wrap_pyfunction!(get_version, m)?)?;
422
+ m.add_class::<PyInputType>()?;
423
+ m.add_class::<PyModelMetadata>()?;
424
+ m.add_function(wrap_pyfunction!(metadata_from_json, m)?)?;
354
425
  m.add_class::<ModelProviderABC>()?;
355
426
  m.add_class::<PyModelRunner>()?;
356
427
  m.add_class::<PyModelRuntime>()?;
Binary file