kinfer 0.4.2__cp311-cp311-macosx_11_0_arm64.whl → 0.5.1__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/export/__init__.py +3 -0
- kinfer/export/jax.py +13 -11
- kinfer/export/pytorch.py +4 -14
- kinfer/export/serialize.py +18 -27
- kinfer/requirements.txt +5 -0
- kinfer/rust/Cargo.toml +3 -0
- kinfer/rust/src/lib.rs +3 -0
- kinfer/rust/src/logger.rs +135 -0
- kinfer/rust/src/model.rs +112 -93
- kinfer/rust/src/runtime.rs +12 -9
- kinfer/rust/src/types.rs +86 -0
- kinfer/rust_bindings/rust_bindings.pyi +17 -7
- kinfer/rust_bindings/src/lib.rs +228 -139
- kinfer/rust_bindings.cpython-311-darwin.so +0 -0
- kinfer/rust_bindings.pyi +17 -7
- kinfer/scripts/plot_ndjson.py +177 -0
- {kinfer-0.4.2.dist-info → kinfer-0.5.1.dist-info}/METADATA +6 -3
- kinfer-0.5.1.dist-info/RECORD +26 -0
- {kinfer-0.4.2.dist-info → kinfer-0.5.1.dist-info}/WHEEL +1 -1
- kinfer/common/__init__.py +0 -0
- kinfer/common/types.py +0 -12
- kinfer/export/common.py +0 -41
- kinfer-0.4.2.dist-info/RECORD +0 -26
- {kinfer-0.4.2.dist-info → kinfer-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {kinfer-0.4.2.dist-info → kinfer-0.5.1.dist-info}/top_level.txt +0 -0
kinfer/rust/src/runtime.rs
CHANGED
@@ -3,8 +3,9 @@ use std::sync::Arc;
|
|
3
3
|
use tokio::runtime::Runtime;
|
4
4
|
|
5
5
|
use crate::model::{ModelError, ModelRunner};
|
6
|
-
use
|
7
|
-
use
|
6
|
+
use crate::types::InputType;
|
7
|
+
use std::time::Duration;
|
8
|
+
use tokio::time::interval;
|
8
9
|
|
9
10
|
pub struct ModelRuntime {
|
10
11
|
model_runner: Arc<ModelRunner>,
|
@@ -57,11 +58,16 @@ impl ModelRuntime {
|
|
57
58
|
.init()
|
58
59
|
.await
|
59
60
|
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
60
|
-
|
61
|
-
|
61
|
+
|
62
|
+
let model_inputs = model_runner
|
63
|
+
.get_inputs(&[InputType::JointAngles])
|
62
64
|
.await
|
63
65
|
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
64
|
-
let mut
|
66
|
+
let mut joint_positions = model_inputs[&InputType::JointAngles].clone();
|
67
|
+
|
68
|
+
// Wait for the first tick, since it happens immediately.
|
69
|
+
let mut interval = interval(dt);
|
70
|
+
interval.tick().await;
|
65
71
|
|
66
72
|
while running.load(Ordering::Relaxed) {
|
67
73
|
let (output, next_carry) = model_runner
|
@@ -80,10 +86,7 @@ impl ModelRuntime {
|
|
80
86
|
.take_action(interp_joint_positions * magnitude_factor)
|
81
87
|
.await
|
82
88
|
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
83
|
-
|
84
|
-
if let Some(sleep_duration) = last_time.checked_duration_since(Instant::now()) {
|
85
|
-
sleep(sleep_duration).await;
|
86
|
-
}
|
89
|
+
interval.tick().await;
|
87
90
|
}
|
88
91
|
|
89
92
|
joint_positions = output;
|
kinfer/rust/src/types.rs
ADDED
@@ -0,0 +1,86 @@
|
|
1
|
+
use serde::Deserialize;
|
2
|
+
use serde::Serialize;
|
3
|
+
|
4
|
+
#[derive(Debug, Deserialize, Serialize, Clone)]
|
5
|
+
pub struct ModelMetadata {
|
6
|
+
pub joint_names: Vec<String>,
|
7
|
+
pub num_commands: Option<usize>,
|
8
|
+
pub carry_size: Vec<usize>,
|
9
|
+
}
|
10
|
+
|
11
|
+
impl ModelMetadata {
|
12
|
+
pub fn model_validate_json(json: String) -> Result<Self, Box<dyn std::error::Error>> {
|
13
|
+
Ok(serde_json::from_str(&json)?)
|
14
|
+
}
|
15
|
+
|
16
|
+
pub fn to_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
17
|
+
Ok(serde_json::to_string(self)?)
|
18
|
+
}
|
19
|
+
}
|
20
|
+
|
21
|
+
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
|
22
|
+
pub enum InputType {
|
23
|
+
JointAngles,
|
24
|
+
JointAngularVelocities,
|
25
|
+
ProjectedGravity,
|
26
|
+
Accelerometer,
|
27
|
+
Gyroscope,
|
28
|
+
Command,
|
29
|
+
Time,
|
30
|
+
Carry,
|
31
|
+
}
|
32
|
+
|
33
|
+
impl InputType {
|
34
|
+
pub fn get_name(&self) -> &str {
|
35
|
+
match self {
|
36
|
+
InputType::JointAngles => "joint_angles",
|
37
|
+
InputType::JointAngularVelocities => "joint_angular_velocities",
|
38
|
+
InputType::ProjectedGravity => "projected_gravity",
|
39
|
+
InputType::Accelerometer => "accelerometer",
|
40
|
+
InputType::Gyroscope => "gyroscope",
|
41
|
+
InputType::Command => "command",
|
42
|
+
InputType::Time => "time",
|
43
|
+
InputType::Carry => "carry",
|
44
|
+
}
|
45
|
+
}
|
46
|
+
|
47
|
+
pub fn get_shape(&self, metadata: &ModelMetadata) -> Vec<usize> {
|
48
|
+
match self {
|
49
|
+
InputType::JointAngles => vec![metadata.joint_names.len()],
|
50
|
+
InputType::JointAngularVelocities => vec![metadata.joint_names.len()],
|
51
|
+
InputType::ProjectedGravity => vec![3],
|
52
|
+
InputType::Accelerometer => vec![3],
|
53
|
+
InputType::Gyroscope => vec![3],
|
54
|
+
InputType::Command => vec![metadata.num_commands.unwrap_or(0)],
|
55
|
+
InputType::Time => vec![1],
|
56
|
+
InputType::Carry => metadata.carry_size.clone(),
|
57
|
+
}
|
58
|
+
}
|
59
|
+
|
60
|
+
pub fn from_name(name: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
61
|
+
match name {
|
62
|
+
"joint_angles" => Ok(InputType::JointAngles),
|
63
|
+
"joint_angular_velocities" => Ok(InputType::JointAngularVelocities),
|
64
|
+
"projected_gravity" => Ok(InputType::ProjectedGravity),
|
65
|
+
"accelerometer" => Ok(InputType::Accelerometer),
|
66
|
+
"gyroscope" => Ok(InputType::Gyroscope),
|
67
|
+
"command" => Ok(InputType::Command),
|
68
|
+
"time" => Ok(InputType::Time),
|
69
|
+
"carry" => Ok(InputType::Carry),
|
70
|
+
_ => Err(format!("Unknown input type: {}", name).into()),
|
71
|
+
}
|
72
|
+
}
|
73
|
+
|
74
|
+
pub fn get_names() -> Vec<&'static str> {
|
75
|
+
vec![
|
76
|
+
"joint_angles",
|
77
|
+
"joint_angular_velocities",
|
78
|
+
"projected_gravity",
|
79
|
+
"accelerometer",
|
80
|
+
"gyroscope",
|
81
|
+
"command",
|
82
|
+
"time",
|
83
|
+
"carry",
|
84
|
+
]
|
85
|
+
}
|
86
|
+
}
|
@@ -8,13 +8,21 @@ import typing
|
|
8
8
|
|
9
9
|
class ModelProviderABC:
|
10
10
|
def __new__(cls) -> ModelProviderABC: ...
|
11
|
-
def
|
12
|
-
def
|
13
|
-
|
14
|
-
|
15
|
-
def
|
16
|
-
def
|
17
|
-
def
|
11
|
+
def get_inputs(self, input_types:typing.Sequence[builtins.str], metadata:PyModelMetadata) -> builtins.dict[builtins.str, numpy.typing.NDArray[numpy.float32]]: ...
|
12
|
+
def take_action(self, action:numpy.typing.NDArray[numpy.float32], metadata:PyModelMetadata) -> None: ...
|
13
|
+
|
14
|
+
class PyInputType:
|
15
|
+
def __new__(cls, input_type:builtins.str) -> PyInputType: ...
|
16
|
+
def get_name(self) -> builtins.str: ...
|
17
|
+
def get_shape(self, metadata:PyModelMetadata) -> builtins.list[builtins.int]: ...
|
18
|
+
def __repr__(self) -> builtins.str: ...
|
19
|
+
def __eq__(self, other:typing.Any) -> builtins.bool: ...
|
20
|
+
|
21
|
+
class PyModelMetadata:
|
22
|
+
def __new__(self, joint_names:typing.Sequence[builtins.str], num_commands:typing.Optional[builtins.int], carry_size:typing.Sequence[builtins.int]) -> PyModelMetadata: ...
|
23
|
+
def to_json(self) -> builtins.str: ...
|
24
|
+
def __repr__(self) -> builtins.str: ...
|
25
|
+
def __eq__(self, other:typing.Any) -> builtins.bool: ...
|
18
26
|
|
19
27
|
class PyModelProvider:
|
20
28
|
...
|
@@ -34,3 +42,5 @@ class PyModelRuntime:
|
|
34
42
|
|
35
43
|
def get_version() -> builtins.str: ...
|
36
44
|
|
45
|
+
def metadata_from_json(json:builtins.str) -> PyModelMetadata: ...
|
46
|
+
|
kinfer/rust_bindings/src/lib.rs
CHANGED
@@ -1,85 +1,210 @@
|
|
1
1
|
use async_trait::async_trait;
|
2
2
|
use kinfer::model::{ModelError, ModelProvider, ModelRunner};
|
3
3
|
use kinfer::runtime::ModelRuntime;
|
4
|
+
use kinfer::types::{InputType, ModelMetadata};
|
4
5
|
use ndarray::{Array, Ix1, IxDyn};
|
5
6
|
use numpy::{PyArray1, PyArrayDyn, PyArrayMethods};
|
6
7
|
use pyo3::exceptions::PyNotImplementedError;
|
7
8
|
use pyo3::prelude::*;
|
9
|
+
use pyo3::types::{PyAny, PyAnyMethods};
|
8
10
|
use pyo3::{pymodule, types::PyModule, Bound, PyResult, Python};
|
9
11
|
use pyo3_stub_gen::define_stub_info_gatherer;
|
10
12
|
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
|
13
|
+
use std::collections::HashMap;
|
14
|
+
use std::hash::Hash;
|
11
15
|
use std::sync::Arc;
|
12
16
|
use std::sync::Mutex;
|
13
17
|
|
18
|
+
type StepResult = (Py<PyArrayDyn<f32>>, Py<PyArrayDyn<f32>>);
|
19
|
+
|
14
20
|
#[pyfunction]
|
15
21
|
#[gen_stub_pyfunction]
|
16
22
|
fn get_version() -> String {
|
17
23
|
env!("CARGO_PKG_VERSION").to_string()
|
18
24
|
}
|
19
25
|
|
20
|
-
#[pyclass
|
26
|
+
#[pyclass]
|
21
27
|
#[gen_stub_pyclass]
|
22
|
-
|
28
|
+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
29
|
+
struct PyInputType {
|
30
|
+
pub input_type: InputType,
|
31
|
+
}
|
32
|
+
|
33
|
+
impl From<InputType> for PyInputType {
|
34
|
+
fn from(input_type: InputType) -> Self {
|
35
|
+
Self { input_type }
|
36
|
+
}
|
37
|
+
}
|
38
|
+
|
39
|
+
impl From<PyInputType> for InputType {
|
40
|
+
fn from(input_type: PyInputType) -> Self {
|
41
|
+
input_type.input_type
|
42
|
+
}
|
43
|
+
}
|
23
44
|
|
24
45
|
#[gen_stub_pymethods]
|
25
46
|
#[pymethods]
|
26
|
-
impl
|
47
|
+
impl PyInputType {
|
27
48
|
#[new]
|
28
|
-
fn
|
29
|
-
|
49
|
+
fn __new__(input_type: &str) -> PyResult<Self> {
|
50
|
+
let input_type = InputType::from_name(input_type).map_or_else(
|
51
|
+
|_| {
|
52
|
+
Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
|
53
|
+
"Invalid input type: {} (must be one of {})",
|
54
|
+
input_type,
|
55
|
+
InputType::get_names().join(", "),
|
56
|
+
)))
|
57
|
+
},
|
58
|
+
Ok,
|
59
|
+
)?;
|
60
|
+
Ok(Self { input_type })
|
30
61
|
}
|
31
62
|
|
32
|
-
fn
|
33
|
-
|
34
|
-
joint_names: Vec<String>,
|
35
|
-
) -> PyResult<Bound<'py, PyArray1<f32>>> {
|
36
|
-
let n = joint_names.len();
|
37
|
-
Err(PyNotImplementedError::new_err(format!(
|
38
|
-
"Must override get_joint_angles with {} joint names",
|
39
|
-
n
|
40
|
-
)))
|
63
|
+
fn get_name(&self) -> String {
|
64
|
+
self.input_type.get_name().to_string()
|
41
65
|
}
|
42
66
|
|
43
|
-
fn
|
44
|
-
&
|
67
|
+
fn get_shape(&self, metadata: PyModelMetadata) -> Vec<usize> {
|
68
|
+
self.input_type.get_shape(&metadata.into())
|
69
|
+
}
|
70
|
+
|
71
|
+
fn __repr__(&self) -> String {
|
72
|
+
format!("InputType({})", self.get_name())
|
73
|
+
}
|
74
|
+
|
75
|
+
fn __eq__(&self, other: Bound<'_, PyAny>) -> PyResult<bool> {
|
76
|
+
if let Ok(other) = other.extract::<PyInputType>() {
|
77
|
+
Ok(self == &other)
|
78
|
+
} else {
|
79
|
+
Ok(false)
|
80
|
+
}
|
81
|
+
}
|
82
|
+
}
|
83
|
+
|
84
|
+
#[pyclass]
|
85
|
+
#[gen_stub_pyclass]
|
86
|
+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
87
|
+
struct PyModelMetadata {
|
88
|
+
#[pyo3(get, set)]
|
89
|
+
pub joint_names: Vec<String>,
|
90
|
+
#[pyo3(get, set)]
|
91
|
+
pub num_commands: Option<usize>,
|
92
|
+
#[pyo3(get, set)]
|
93
|
+
pub carry_size: Vec<usize>,
|
94
|
+
}
|
95
|
+
|
96
|
+
#[pymethods]
|
97
|
+
#[gen_stub_pymethods]
|
98
|
+
impl PyModelMetadata {
|
99
|
+
#[new]
|
100
|
+
fn __new__(
|
45
101
|
joint_names: Vec<String>,
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
102
|
+
num_commands: Option<usize>,
|
103
|
+
carry_size: Vec<usize>,
|
104
|
+
) -> Self {
|
105
|
+
Self {
|
106
|
+
joint_names,
|
107
|
+
num_commands,
|
108
|
+
carry_size,
|
109
|
+
}
|
110
|
+
}
|
111
|
+
|
112
|
+
fn to_json(&self) -> PyResult<String> {
|
113
|
+
let metadata = ModelMetadata {
|
114
|
+
joint_names: self.joint_names.clone(),
|
115
|
+
num_commands: self.num_commands,
|
116
|
+
carry_size: self.carry_size.clone(),
|
117
|
+
}
|
118
|
+
.to_json()
|
119
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
120
|
+
Ok(metadata)
|
121
|
+
}
|
122
|
+
|
123
|
+
fn __repr__(&self) -> PyResult<String> {
|
124
|
+
let json = self.to_json()?;
|
125
|
+
Ok(format!("ModelMetadata({:?})", json))
|
126
|
+
}
|
127
|
+
|
128
|
+
fn __eq__(&self, other: Bound<'_, PyAny>) -> PyResult<bool> {
|
129
|
+
if let Ok(other) = other.extract::<PyModelMetadata>() {
|
130
|
+
Ok(self == &other)
|
131
|
+
} else {
|
132
|
+
Ok(false)
|
133
|
+
}
|
134
|
+
}
|
135
|
+
}
|
136
|
+
|
137
|
+
#[pyfunction]
|
138
|
+
#[gen_stub_pyfunction]
|
139
|
+
fn metadata_from_json(json: &str) -> PyResult<PyModelMetadata> {
|
140
|
+
let metadata = ModelMetadata::model_validate_json(json.to_string()).map_err(|e| {
|
141
|
+
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Invalid model metadata: {}", e))
|
142
|
+
})?;
|
143
|
+
Ok(PyModelMetadata::from(metadata))
|
144
|
+
}
|
145
|
+
|
146
|
+
impl From<ModelMetadata> for PyModelMetadata {
|
147
|
+
fn from(metadata: ModelMetadata) -> Self {
|
148
|
+
Self {
|
149
|
+
joint_names: metadata.joint_names,
|
150
|
+
num_commands: metadata.num_commands,
|
151
|
+
carry_size: metadata.carry_size,
|
152
|
+
}
|
52
153
|
}
|
154
|
+
}
|
53
155
|
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
156
|
+
impl From<&ModelMetadata> for PyModelMetadata {
|
157
|
+
fn from(metadata: &ModelMetadata) -> Self {
|
158
|
+
Self {
|
159
|
+
joint_names: metadata.joint_names.clone(),
|
160
|
+
num_commands: metadata.num_commands,
|
161
|
+
carry_size: metadata.carry_size.clone(),
|
162
|
+
}
|
58
163
|
}
|
164
|
+
}
|
59
165
|
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
166
|
+
impl From<PyModelMetadata> for ModelMetadata {
|
167
|
+
fn from(metadata: PyModelMetadata) -> Self {
|
168
|
+
Self {
|
169
|
+
joint_names: metadata.joint_names,
|
170
|
+
num_commands: metadata.num_commands,
|
171
|
+
carry_size: metadata.carry_size,
|
172
|
+
}
|
64
173
|
}
|
174
|
+
}
|
175
|
+
|
176
|
+
#[pyclass(subclass)]
|
177
|
+
#[gen_stub_pyclass]
|
178
|
+
struct ModelProviderABC;
|
65
179
|
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
180
|
+
#[gen_stub_pymethods]
|
181
|
+
#[pymethods]
|
182
|
+
impl ModelProviderABC {
|
183
|
+
#[new]
|
184
|
+
fn __new__() -> Self {
|
185
|
+
ModelProviderABC
|
70
186
|
}
|
71
187
|
|
72
|
-
fn
|
73
|
-
|
188
|
+
fn get_inputs<'py>(
|
189
|
+
&self,
|
190
|
+
input_types: Vec<String>,
|
191
|
+
metadata: PyModelMetadata,
|
192
|
+
) -> PyResult<HashMap<String, Bound<'py, PyArrayDyn<f32>>>> {
|
193
|
+
Err(PyNotImplementedError::new_err(format!(
|
194
|
+
"Must override get_inputs with {} input types {:?} and metadata {:?}",
|
195
|
+
input_types.len(),
|
196
|
+
input_types,
|
197
|
+
metadata
|
198
|
+
)))
|
74
199
|
}
|
75
200
|
|
76
|
-
fn take_action
|
201
|
+
fn take_action(
|
77
202
|
&self,
|
78
|
-
|
79
|
-
|
203
|
+
action: Bound<'_, PyArray1<f32>>,
|
204
|
+
metadata: PyModelMetadata,
|
80
205
|
) -> PyResult<()> {
|
81
206
|
let n = action.len()?;
|
82
|
-
assert_eq!(joint_names.len(), n);
|
207
|
+
assert_eq!(metadata.joint_names.len(), n); // TODO: this is wrong
|
83
208
|
Err(PyNotImplementedError::new_err(format!(
|
84
209
|
"Must override take_action with {} action",
|
85
210
|
n
|
@@ -97,106 +222,57 @@ struct PyModelProvider {
|
|
97
222
|
#[pymethods]
|
98
223
|
impl PyModelProvider {
|
99
224
|
#[new]
|
100
|
-
fn
|
225
|
+
fn __new__(obj: Py<ModelProviderABC>) -> Self {
|
101
226
|
Self { obj: Arc::new(obj) }
|
102
227
|
}
|
103
228
|
}
|
104
229
|
|
105
230
|
#[async_trait]
|
106
231
|
impl ModelProvider for PyModelProvider {
|
107
|
-
async fn
|
108
|
-
&self,
|
109
|
-
joint_names: &[String],
|
110
|
-
) -> Result<Array<f32, IxDyn>, ModelError> {
|
111
|
-
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
112
|
-
let obj = self.obj.clone();
|
113
|
-
let args = (joint_names,);
|
114
|
-
let result = obj.call_method(py, "get_joint_angles", args, None)?;
|
115
|
-
let array = result.extract::<Vec<f32>>(py)?;
|
116
|
-
Ok(Array::from_vec(array).into_dyn())
|
117
|
-
})
|
118
|
-
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
119
|
-
Ok(args)
|
120
|
-
}
|
121
|
-
|
122
|
-
async fn get_joint_angular_velocities(
|
232
|
+
async fn get_inputs(
|
123
233
|
&self,
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
})
|
133
|
-
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
134
|
-
Ok(args)
|
135
|
-
}
|
136
|
-
|
137
|
-
async fn get_projected_gravity(&self) -> Result<Array<f32, IxDyn>, ModelError> {
|
138
|
-
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
139
|
-
let obj = self.obj.clone();
|
140
|
-
let args = ();
|
141
|
-
let result = obj.call_method(py, "get_projected_gravity", args, None)?;
|
142
|
-
let array = result.extract::<Vec<f32>>(py)?;
|
143
|
-
Ok(Array::from_vec(array).into_dyn())
|
144
|
-
})
|
145
|
-
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
146
|
-
Ok(args)
|
147
|
-
}
|
148
|
-
|
149
|
-
async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError> {
|
150
|
-
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
151
|
-
let obj = self.obj.clone();
|
152
|
-
let args = ();
|
153
|
-
let result = obj.call_method(py, "get_accelerometer", args, None)?;
|
154
|
-
let array = result.extract::<Vec<f32>>(py)?;
|
155
|
-
Ok(Array::from_vec(array).into_dyn())
|
156
|
-
})
|
157
|
-
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
158
|
-
Ok(args)
|
159
|
-
}
|
160
|
-
|
161
|
-
async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError> {
|
162
|
-
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
234
|
+
input_types: &[InputType],
|
235
|
+
metadata: &ModelMetadata,
|
236
|
+
) -> Result<HashMap<InputType, Array<f32, IxDyn>>, ModelError> {
|
237
|
+
let input_names: Vec<String> = input_types
|
238
|
+
.iter()
|
239
|
+
.map(|t| t.get_name().to_string())
|
240
|
+
.collect();
|
241
|
+
let result = Python::with_gil(|py| -> PyResult<HashMap<InputType, Array<f32, IxDyn>>> {
|
163
242
|
let obj = self.obj.clone();
|
164
|
-
let args = ();
|
165
|
-
let result = obj.call_method(py, "
|
166
|
-
let
|
167
|
-
|
243
|
+
let args = (input_names.clone(), PyModelMetadata::from(metadata.clone()));
|
244
|
+
let result = obj.call_method(py, "get_inputs", args, None)?;
|
245
|
+
let dict: HashMap<String, Vec<f32>> = result.extract(py)?;
|
246
|
+
let mut arrays = HashMap::new();
|
247
|
+
for (i, name) in input_names.iter().enumerate() {
|
248
|
+
let array = dict.get(name).ok_or_else(|| {
|
249
|
+
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
250
|
+
"Missing input: {}",
|
251
|
+
name
|
252
|
+
))
|
253
|
+
})?;
|
254
|
+
arrays.insert(input_types[i], Array::from_vec(array.clone()).into_dyn());
|
255
|
+
}
|
256
|
+
Ok(arrays)
|
168
257
|
})
|
169
258
|
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
170
|
-
Ok(
|
171
|
-
}
|
172
|
-
|
173
|
-
async fn get_command(&self) -> Result<Array<f32, IxDyn>, ModelError> {
|
174
|
-
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
175
|
-
let obj = self.obj.clone();
|
176
|
-
let args = ();
|
177
|
-
let result = obj.call_method(py, "get_command", args, None)?;
|
178
|
-
let array = result.extract::<Vec<f32>>(py)?;
|
179
|
-
Ok(Array::from_vec(array).into_dyn())
|
180
|
-
})
|
181
|
-
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
182
|
-
Ok(args)
|
183
|
-
}
|
184
|
-
|
185
|
-
async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError> {
|
186
|
-
Ok(carry)
|
259
|
+
Ok(result)
|
187
260
|
}
|
188
261
|
|
189
262
|
async fn take_action(
|
190
263
|
&self,
|
191
|
-
joint_names: Vec<String>,
|
192
264
|
action: Array<f32, IxDyn>,
|
265
|
+
metadata: &ModelMetadata,
|
193
266
|
) -> Result<(), ModelError> {
|
194
267
|
Python::with_gil(|py| -> PyResult<()> {
|
195
268
|
let obj = self.obj.clone();
|
196
269
|
let action_1d = action
|
197
270
|
.into_dimensionality::<Ix1>()
|
198
271
|
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
199
|
-
let args = (
|
272
|
+
let args = (
|
273
|
+
PyArray1::from_array(py, &action_1d),
|
274
|
+
PyModelMetadata::from(metadata.clone()),
|
275
|
+
);
|
200
276
|
obj.call_method(py, "take_action", args, None)?;
|
201
277
|
Ok(())
|
202
278
|
})
|
@@ -216,12 +292,10 @@ struct PyModelRunner {
|
|
216
292
|
#[pymethods]
|
217
293
|
impl PyModelRunner {
|
218
294
|
#[new]
|
219
|
-
fn
|
220
|
-
let input_provider = Arc::new(PyModelProvider
|
221
|
-
obj: Arc::new(provider),
|
222
|
-
});
|
295
|
+
fn __new__(model_path: String, provider: Py<ModelProviderABC>) -> PyResult<Self> {
|
296
|
+
let input_provider = Arc::new(PyModelProvider::__new__(provider));
|
223
297
|
|
224
|
-
let runner = tokio::runtime::Runtime::new()
|
298
|
+
let runner = tokio::runtime::Runtime::new()?.block_on(async {
|
225
299
|
ModelRunner::new(model_path, input_provider)
|
226
300
|
.await
|
227
301
|
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
@@ -234,7 +308,7 @@ impl PyModelRunner {
|
|
234
308
|
|
235
309
|
fn init(&self) -> PyResult<Py<PyArrayDyn<f32>>> {
|
236
310
|
let runner = self.runner.clone();
|
237
|
-
let result = tokio::runtime::Runtime::new()
|
311
|
+
let result = tokio::runtime::Runtime::new()?.block_on(async {
|
238
312
|
runner
|
239
313
|
.init()
|
240
314
|
.await
|
@@ -247,17 +321,14 @@ impl PyModelRunner {
|
|
247
321
|
})
|
248
322
|
}
|
249
323
|
|
250
|
-
fn step(
|
251
|
-
&self,
|
252
|
-
carry: Py<PyArrayDyn<f32>>,
|
253
|
-
) -> PyResult<(Py<PyArrayDyn<f32>>, Py<PyArrayDyn<f32>>)> {
|
324
|
+
fn step(&self, carry: Py<PyArrayDyn<f32>>) -> PyResult<StepResult> {
|
254
325
|
let runner = self.runner.clone();
|
255
326
|
let carry_array = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
256
327
|
let carry_array = carry.bind(py);
|
257
328
|
Ok(carry_array.to_owned_array())
|
258
329
|
})?;
|
259
330
|
|
260
|
-
let result = tokio::runtime::Runtime::new()
|
331
|
+
let result = tokio::runtime::Runtime::new()?.block_on(async {
|
261
332
|
runner
|
262
333
|
.step(carry_array)
|
263
334
|
.await
|
@@ -279,7 +350,7 @@ impl PyModelRunner {
|
|
279
350
|
Ok(action_array.to_owned_array())
|
280
351
|
})?;
|
281
352
|
|
282
|
-
tokio::runtime::Runtime::new()
|
353
|
+
tokio::runtime::Runtime::new()?.block_on(async {
|
283
354
|
runner
|
284
355
|
.take_action(action_array)
|
285
356
|
.await
|
@@ -301,38 +372,56 @@ struct PyModelRuntime {
|
|
301
372
|
#[pymethods]
|
302
373
|
impl PyModelRuntime {
|
303
374
|
#[new]
|
304
|
-
fn
|
375
|
+
fn __new__(model_runner: PyModelRunner, dt: u64) -> PyResult<Self> {
|
305
376
|
Ok(Self {
|
306
377
|
runtime: Arc::new(Mutex::new(ModelRuntime::new(model_runner.runner, dt))),
|
307
378
|
})
|
308
379
|
}
|
309
380
|
|
310
|
-
fn set_slowdown_factor(&self, slowdown_factor: i32) {
|
311
|
-
let mut runtime = self
|
381
|
+
fn set_slowdown_factor(&self, slowdown_factor: i32) -> PyResult<()> {
|
382
|
+
let mut runtime = self
|
383
|
+
.runtime
|
384
|
+
.lock()
|
385
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
312
386
|
runtime.set_slowdown_factor(slowdown_factor);
|
387
|
+
Ok(())
|
313
388
|
}
|
314
389
|
|
315
|
-
fn set_magnitude_factor(&self, magnitude_factor: f32) {
|
316
|
-
let mut runtime = self
|
390
|
+
fn set_magnitude_factor(&self, magnitude_factor: f32) -> PyResult<()> {
|
391
|
+
let mut runtime = self
|
392
|
+
.runtime
|
393
|
+
.lock()
|
394
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
317
395
|
runtime.set_magnitude_factor(magnitude_factor);
|
396
|
+
Ok(())
|
318
397
|
}
|
319
398
|
|
320
399
|
fn start(&self) -> PyResult<()> {
|
321
|
-
let mut runtime = self
|
400
|
+
let mut runtime = self
|
401
|
+
.runtime
|
402
|
+
.lock()
|
403
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
322
404
|
runtime
|
323
405
|
.start()
|
324
406
|
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
325
407
|
}
|
326
408
|
|
327
|
-
fn stop(&self) {
|
328
|
-
let mut runtime = self
|
409
|
+
fn stop(&self) -> PyResult<()> {
|
410
|
+
let mut runtime = self
|
411
|
+
.runtime
|
412
|
+
.lock()
|
413
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
329
414
|
runtime.stop();
|
415
|
+
Ok(())
|
330
416
|
}
|
331
417
|
}
|
332
418
|
|
333
419
|
#[pymodule]
|
334
420
|
fn rust_bindings(m: &Bound<PyModule>) -> PyResult<()> {
|
335
421
|
m.add_function(wrap_pyfunction!(get_version, m)?)?;
|
422
|
+
m.add_class::<PyInputType>()?;
|
423
|
+
m.add_class::<PyModelMetadata>()?;
|
424
|
+
m.add_function(wrap_pyfunction!(metadata_from_json, m)?)?;
|
336
425
|
m.add_class::<ModelProviderABC>()?;
|
337
426
|
m.add_class::<PyModelRunner>()?;
|
338
427
|
m.add_class::<PyModelRuntime>()?;
|
Binary file
|