kinfer 0.4.3__cp311-cp311-macosx_11_0_arm64.whl → 0.5.1__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/export/__init__.py +3 -0
- kinfer/export/jax.py +13 -11
- kinfer/export/pytorch.py +4 -14
- kinfer/export/serialize.py +18 -27
- kinfer/requirements.txt +5 -0
- kinfer/rust/Cargo.toml +3 -0
- kinfer/rust/src/lib.rs +3 -0
- kinfer/rust/src/logger.rs +135 -0
- kinfer/rust/src/model.rs +112 -104
- kinfer/rust/src/runtime.rs +5 -2
- kinfer/rust/src/types.rs +86 -0
- kinfer/rust_bindings/rust_bindings.pyi +17 -8
- kinfer/rust_bindings/src/lib.rs +228 -157
- kinfer/rust_bindings.cpython-311-darwin.so +0 -0
- kinfer/rust_bindings.pyi +17 -8
- kinfer/scripts/plot_ndjson.py +177 -0
- {kinfer-0.4.3.dist-info → kinfer-0.5.1.dist-info}/METADATA +4 -1
- kinfer-0.5.1.dist-info/RECORD +26 -0
- {kinfer-0.4.3.dist-info → kinfer-0.5.1.dist-info}/WHEEL +1 -1
- kinfer/common/__init__.py +0 -0
- kinfer/common/types.py +0 -12
- kinfer/export/common.py +0 -44
- kinfer-0.4.3.dist-info/RECORD +0 -26
- {kinfer-0.4.3.dist-info → kinfer-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {kinfer-0.4.3.dist-info → kinfer-0.5.1.dist-info}/top_level.txt +0 -0
kinfer/rust/src/types.rs
ADDED
@@ -0,0 +1,86 @@
|
|
1
|
+
use serde::Deserialize;
|
2
|
+
use serde::Serialize;
|
3
|
+
|
4
|
+
#[derive(Debug, Deserialize, Serialize, Clone)]
|
5
|
+
pub struct ModelMetadata {
|
6
|
+
pub joint_names: Vec<String>,
|
7
|
+
pub num_commands: Option<usize>,
|
8
|
+
pub carry_size: Vec<usize>,
|
9
|
+
}
|
10
|
+
|
11
|
+
impl ModelMetadata {
|
12
|
+
pub fn model_validate_json(json: String) -> Result<Self, Box<dyn std::error::Error>> {
|
13
|
+
Ok(serde_json::from_str(&json)?)
|
14
|
+
}
|
15
|
+
|
16
|
+
pub fn to_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
17
|
+
Ok(serde_json::to_string(self)?)
|
18
|
+
}
|
19
|
+
}
|
20
|
+
|
21
|
+
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
|
22
|
+
pub enum InputType {
|
23
|
+
JointAngles,
|
24
|
+
JointAngularVelocities,
|
25
|
+
ProjectedGravity,
|
26
|
+
Accelerometer,
|
27
|
+
Gyroscope,
|
28
|
+
Command,
|
29
|
+
Time,
|
30
|
+
Carry,
|
31
|
+
}
|
32
|
+
|
33
|
+
impl InputType {
|
34
|
+
pub fn get_name(&self) -> &str {
|
35
|
+
match self {
|
36
|
+
InputType::JointAngles => "joint_angles",
|
37
|
+
InputType::JointAngularVelocities => "joint_angular_velocities",
|
38
|
+
InputType::ProjectedGravity => "projected_gravity",
|
39
|
+
InputType::Accelerometer => "accelerometer",
|
40
|
+
InputType::Gyroscope => "gyroscope",
|
41
|
+
InputType::Command => "command",
|
42
|
+
InputType::Time => "time",
|
43
|
+
InputType::Carry => "carry",
|
44
|
+
}
|
45
|
+
}
|
46
|
+
|
47
|
+
pub fn get_shape(&self, metadata: &ModelMetadata) -> Vec<usize> {
|
48
|
+
match self {
|
49
|
+
InputType::JointAngles => vec![metadata.joint_names.len()],
|
50
|
+
InputType::JointAngularVelocities => vec![metadata.joint_names.len()],
|
51
|
+
InputType::ProjectedGravity => vec![3],
|
52
|
+
InputType::Accelerometer => vec![3],
|
53
|
+
InputType::Gyroscope => vec![3],
|
54
|
+
InputType::Command => vec![metadata.num_commands.unwrap_or(0)],
|
55
|
+
InputType::Time => vec![1],
|
56
|
+
InputType::Carry => metadata.carry_size.clone(),
|
57
|
+
}
|
58
|
+
}
|
59
|
+
|
60
|
+
pub fn from_name(name: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
61
|
+
match name {
|
62
|
+
"joint_angles" => Ok(InputType::JointAngles),
|
63
|
+
"joint_angular_velocities" => Ok(InputType::JointAngularVelocities),
|
64
|
+
"projected_gravity" => Ok(InputType::ProjectedGravity),
|
65
|
+
"accelerometer" => Ok(InputType::Accelerometer),
|
66
|
+
"gyroscope" => Ok(InputType::Gyroscope),
|
67
|
+
"command" => Ok(InputType::Command),
|
68
|
+
"time" => Ok(InputType::Time),
|
69
|
+
"carry" => Ok(InputType::Carry),
|
70
|
+
_ => Err(format!("Unknown input type: {}", name).into()),
|
71
|
+
}
|
72
|
+
}
|
73
|
+
|
74
|
+
pub fn get_names() -> Vec<&'static str> {
|
75
|
+
vec![
|
76
|
+
"joint_angles",
|
77
|
+
"joint_angular_velocities",
|
78
|
+
"projected_gravity",
|
79
|
+
"accelerometer",
|
80
|
+
"gyroscope",
|
81
|
+
"command",
|
82
|
+
"time",
|
83
|
+
"carry",
|
84
|
+
]
|
85
|
+
}
|
86
|
+
}
|
@@ -8,14 +8,21 @@ import typing
|
|
8
8
|
|
9
9
|
class ModelProviderABC:
|
10
10
|
def __new__(cls) -> ModelProviderABC: ...
|
11
|
-
def
|
12
|
-
def
|
13
|
-
|
14
|
-
|
15
|
-
def
|
16
|
-
def
|
17
|
-
def
|
18
|
-
def
|
11
|
+
def get_inputs(self, input_types:typing.Sequence[builtins.str], metadata:PyModelMetadata) -> builtins.dict[builtins.str, numpy.typing.NDArray[numpy.float32]]: ...
|
12
|
+
def take_action(self, action:numpy.typing.NDArray[numpy.float32], metadata:PyModelMetadata) -> None: ...
|
13
|
+
|
14
|
+
class PyInputType:
|
15
|
+
def __new__(cls, input_type:builtins.str) -> PyInputType: ...
|
16
|
+
def get_name(self) -> builtins.str: ...
|
17
|
+
def get_shape(self, metadata:PyModelMetadata) -> builtins.list[builtins.int]: ...
|
18
|
+
def __repr__(self) -> builtins.str: ...
|
19
|
+
def __eq__(self, other:typing.Any) -> builtins.bool: ...
|
20
|
+
|
21
|
+
class PyModelMetadata:
|
22
|
+
def __new__(self, joint_names:typing.Sequence[builtins.str], num_commands:typing.Optional[builtins.int], carry_size:typing.Sequence[builtins.int]) -> PyModelMetadata: ...
|
23
|
+
def to_json(self) -> builtins.str: ...
|
24
|
+
def __repr__(self) -> builtins.str: ...
|
25
|
+
def __eq__(self, other:typing.Any) -> builtins.bool: ...
|
19
26
|
|
20
27
|
class PyModelProvider:
|
21
28
|
...
|
@@ -35,3 +42,5 @@ class PyModelRuntime:
|
|
35
42
|
|
36
43
|
def get_version() -> builtins.str: ...
|
37
44
|
|
45
|
+
def metadata_from_json(json:builtins.str) -> PyModelMetadata: ...
|
46
|
+
|
kinfer/rust_bindings/src/lib.rs
CHANGED
@@ -1,89 +1,210 @@
|
|
1
1
|
use async_trait::async_trait;
|
2
2
|
use kinfer::model::{ModelError, ModelProvider, ModelRunner};
|
3
3
|
use kinfer::runtime::ModelRuntime;
|
4
|
+
use kinfer::types::{InputType, ModelMetadata};
|
4
5
|
use ndarray::{Array, Ix1, IxDyn};
|
5
6
|
use numpy::{PyArray1, PyArrayDyn, PyArrayMethods};
|
6
7
|
use pyo3::exceptions::PyNotImplementedError;
|
7
8
|
use pyo3::prelude::*;
|
9
|
+
use pyo3::types::{PyAny, PyAnyMethods};
|
8
10
|
use pyo3::{pymodule, types::PyModule, Bound, PyResult, Python};
|
9
11
|
use pyo3_stub_gen::define_stub_info_gatherer;
|
10
12
|
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
|
13
|
+
use std::collections::HashMap;
|
14
|
+
use std::hash::Hash;
|
11
15
|
use std::sync::Arc;
|
12
16
|
use std::sync::Mutex;
|
13
|
-
|
17
|
+
|
18
|
+
type StepResult = (Py<PyArrayDyn<f32>>, Py<PyArrayDyn<f32>>);
|
19
|
+
|
14
20
|
#[pyfunction]
|
15
21
|
#[gen_stub_pyfunction]
|
16
22
|
fn get_version() -> String {
|
17
23
|
env!("CARGO_PKG_VERSION").to_string()
|
18
24
|
}
|
19
25
|
|
20
|
-
#[pyclass
|
26
|
+
#[pyclass]
|
21
27
|
#[gen_stub_pyclass]
|
22
|
-
|
28
|
+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
29
|
+
struct PyInputType {
|
30
|
+
pub input_type: InputType,
|
31
|
+
}
|
32
|
+
|
33
|
+
impl From<InputType> for PyInputType {
|
34
|
+
fn from(input_type: InputType) -> Self {
|
35
|
+
Self { input_type }
|
36
|
+
}
|
37
|
+
}
|
38
|
+
|
39
|
+
impl From<PyInputType> for InputType {
|
40
|
+
fn from(input_type: PyInputType) -> Self {
|
41
|
+
input_type.input_type
|
42
|
+
}
|
43
|
+
}
|
23
44
|
|
24
45
|
#[gen_stub_pymethods]
|
25
46
|
#[pymethods]
|
26
|
-
impl
|
47
|
+
impl PyInputType {
|
27
48
|
#[new]
|
28
|
-
fn
|
29
|
-
|
49
|
+
fn __new__(input_type: &str) -> PyResult<Self> {
|
50
|
+
let input_type = InputType::from_name(input_type).map_or_else(
|
51
|
+
|_| {
|
52
|
+
Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
|
53
|
+
"Invalid input type: {} (must be one of {})",
|
54
|
+
input_type,
|
55
|
+
InputType::get_names().join(", "),
|
56
|
+
)))
|
57
|
+
},
|
58
|
+
Ok,
|
59
|
+
)?;
|
60
|
+
Ok(Self { input_type })
|
30
61
|
}
|
31
62
|
|
32
|
-
fn
|
33
|
-
|
34
|
-
joint_names: Vec<String>,
|
35
|
-
) -> PyResult<Bound<'py, PyArray1<f32>>> {
|
36
|
-
let n = joint_names.len();
|
37
|
-
Err(PyNotImplementedError::new_err(format!(
|
38
|
-
"Must override get_joint_angles with {} joint names",
|
39
|
-
n
|
40
|
-
)))
|
63
|
+
fn get_name(&self) -> String {
|
64
|
+
self.input_type.get_name().to_string()
|
41
65
|
}
|
42
66
|
|
43
|
-
fn
|
44
|
-
&
|
67
|
+
fn get_shape(&self, metadata: PyModelMetadata) -> Vec<usize> {
|
68
|
+
self.input_type.get_shape(&metadata.into())
|
69
|
+
}
|
70
|
+
|
71
|
+
fn __repr__(&self) -> String {
|
72
|
+
format!("InputType({})", self.get_name())
|
73
|
+
}
|
74
|
+
|
75
|
+
fn __eq__(&self, other: Bound<'_, PyAny>) -> PyResult<bool> {
|
76
|
+
if let Ok(other) = other.extract::<PyInputType>() {
|
77
|
+
Ok(self == &other)
|
78
|
+
} else {
|
79
|
+
Ok(false)
|
80
|
+
}
|
81
|
+
}
|
82
|
+
}
|
83
|
+
|
84
|
+
#[pyclass]
|
85
|
+
#[gen_stub_pyclass]
|
86
|
+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
87
|
+
struct PyModelMetadata {
|
88
|
+
#[pyo3(get, set)]
|
89
|
+
pub joint_names: Vec<String>,
|
90
|
+
#[pyo3(get, set)]
|
91
|
+
pub num_commands: Option<usize>,
|
92
|
+
#[pyo3(get, set)]
|
93
|
+
pub carry_size: Vec<usize>,
|
94
|
+
}
|
95
|
+
|
96
|
+
#[pymethods]
|
97
|
+
#[gen_stub_pymethods]
|
98
|
+
impl PyModelMetadata {
|
99
|
+
#[new]
|
100
|
+
fn __new__(
|
45
101
|
joint_names: Vec<String>,
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
102
|
+
num_commands: Option<usize>,
|
103
|
+
carry_size: Vec<usize>,
|
104
|
+
) -> Self {
|
105
|
+
Self {
|
106
|
+
joint_names,
|
107
|
+
num_commands,
|
108
|
+
carry_size,
|
109
|
+
}
|
52
110
|
}
|
53
111
|
|
54
|
-
fn
|
55
|
-
|
56
|
-
|
57
|
-
|
112
|
+
fn to_json(&self) -> PyResult<String> {
|
113
|
+
let metadata = ModelMetadata {
|
114
|
+
joint_names: self.joint_names.clone(),
|
115
|
+
num_commands: self.num_commands,
|
116
|
+
carry_size: self.carry_size.clone(),
|
117
|
+
}
|
118
|
+
.to_json()
|
119
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
120
|
+
Ok(metadata)
|
58
121
|
}
|
59
122
|
|
60
|
-
fn
|
61
|
-
|
62
|
-
|
63
|
-
|
123
|
+
/// Renders as `ModelMetadata(<json>)`, with the JSON text shown in Rust debug
/// (quoted and escaped) form.
fn __repr__(&self) -> PyResult<String> {
    self.to_json()
        .map(|json| format!("ModelMetadata({:?})", json))
}
|
127
|
+
|
128
|
+
fn __eq__(&self, other: Bound<'_, PyAny>) -> PyResult<bool> {
|
129
|
+
if let Ok(other) = other.extract::<PyModelMetadata>() {
|
130
|
+
Ok(self == &other)
|
131
|
+
} else {
|
132
|
+
Ok(false)
|
133
|
+
}
|
64
134
|
}
|
135
|
+
}
|
65
136
|
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
137
|
+
#[pyfunction]
|
138
|
+
#[gen_stub_pyfunction]
|
139
|
+
fn metadata_from_json(json: &str) -> PyResult<PyModelMetadata> {
|
140
|
+
let metadata = ModelMetadata::model_validate_json(json.to_string()).map_err(|e| {
|
141
|
+
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Invalid model metadata: {}", e))
|
142
|
+
})?;
|
143
|
+
Ok(PyModelMetadata::from(metadata))
|
144
|
+
}
|
145
|
+
|
146
|
+
impl From<ModelMetadata> for PyModelMetadata {
|
147
|
+
fn from(metadata: ModelMetadata) -> Self {
|
148
|
+
Self {
|
149
|
+
joint_names: metadata.joint_names,
|
150
|
+
num_commands: metadata.num_commands,
|
151
|
+
carry_size: metadata.carry_size,
|
152
|
+
}
|
70
153
|
}
|
154
|
+
}
|
71
155
|
|
72
|
-
|
73
|
-
|
156
|
+
impl From<&ModelMetadata> for PyModelMetadata {
|
157
|
+
fn from(metadata: &ModelMetadata) -> Self {
|
158
|
+
Self {
|
159
|
+
joint_names: metadata.joint_names.clone(),
|
160
|
+
num_commands: metadata.num_commands,
|
161
|
+
carry_size: metadata.carry_size.clone(),
|
162
|
+
}
|
74
163
|
}
|
164
|
+
}
|
75
165
|
|
76
|
-
|
77
|
-
|
166
|
+
impl From<PyModelMetadata> for ModelMetadata {
|
167
|
+
fn from(metadata: PyModelMetadata) -> Self {
|
168
|
+
Self {
|
169
|
+
joint_names: metadata.joint_names,
|
170
|
+
num_commands: metadata.num_commands,
|
171
|
+
carry_size: metadata.carry_size,
|
172
|
+
}
|
78
173
|
}
|
174
|
+
}
|
79
175
|
|
80
|
-
|
176
|
+
#[pyclass(subclass)]
|
177
|
+
#[gen_stub_pyclass]
|
178
|
+
struct ModelProviderABC;
|
179
|
+
|
180
|
+
#[gen_stub_pymethods]
|
181
|
+
#[pymethods]
|
182
|
+
impl ModelProviderABC {
|
183
|
+
#[new]
|
184
|
+
fn __new__() -> Self {
|
185
|
+
ModelProviderABC
|
186
|
+
}
|
187
|
+
|
188
|
+
fn get_inputs<'py>(
|
81
189
|
&self,
|
82
|
-
|
83
|
-
|
190
|
+
input_types: Vec<String>,
|
191
|
+
metadata: PyModelMetadata,
|
192
|
+
) -> PyResult<HashMap<String, Bound<'py, PyArrayDyn<f32>>>> {
|
193
|
+
Err(PyNotImplementedError::new_err(format!(
|
194
|
+
"Must override get_inputs with {} input types {:?} and metadata {:?}",
|
195
|
+
input_types.len(),
|
196
|
+
input_types,
|
197
|
+
metadata
|
198
|
+
)))
|
199
|
+
}
|
200
|
+
|
201
|
+
fn take_action(
|
202
|
+
&self,
|
203
|
+
action: Bound<'_, PyArray1<f32>>,
|
204
|
+
metadata: PyModelMetadata,
|
84
205
|
) -> PyResult<()> {
|
85
206
|
let n = action.len()?;
|
86
|
-
assert_eq!(joint_names.len(), n);
|
207
|
+
assert_eq!(metadata.joint_names.len(), n); // TODO: this is wrong
|
87
208
|
Err(PyNotImplementedError::new_err(format!(
|
88
209
|
"Must override take_action with {} action",
|
89
210
|
n
|
@@ -96,127 +217,62 @@ impl ModelProviderABC {
|
|
96
217
|
#[derive(Clone)]
|
97
218
|
struct PyModelProvider {
|
98
219
|
obj: Arc<Py<ModelProviderABC>>,
|
99
|
-
start_time: Instant,
|
100
220
|
}
|
101
221
|
|
102
222
|
#[pymethods]
|
103
223
|
impl PyModelProvider {
|
104
224
|
#[new]
|
105
|
-
fn
|
106
|
-
Self {
|
107
|
-
obj: Arc::new(obj),
|
108
|
-
start_time: Instant::now(),
|
109
|
-
}
|
225
|
+
fn __new__(obj: Py<ModelProviderABC>) -> Self {
|
226
|
+
Self { obj: Arc::new(obj) }
|
110
227
|
}
|
111
228
|
}
|
112
229
|
|
113
230
|
#[async_trait]
|
114
231
|
impl ModelProvider for PyModelProvider {
|
115
|
-
async fn
|
116
|
-
&self,
|
117
|
-
joint_names: &[String],
|
118
|
-
) -> Result<Array<f32, IxDyn>, ModelError> {
|
119
|
-
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
120
|
-
let obj = self.obj.clone();
|
121
|
-
let args = (joint_names,);
|
122
|
-
let result = obj.call_method(py, "get_joint_angles", args, None)?;
|
123
|
-
let array = result.extract::<Vec<f32>>(py)?;
|
124
|
-
Ok(Array::from_vec(array).into_dyn())
|
125
|
-
})
|
126
|
-
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
127
|
-
Ok(args)
|
128
|
-
}
|
129
|
-
|
130
|
-
async fn get_joint_angular_velocities(
|
232
|
+
async fn get_inputs(
|
131
233
|
&self,
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
})
|
141
|
-
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
142
|
-
Ok(args)
|
143
|
-
}
|
144
|
-
|
145
|
-
async fn get_projected_gravity(&self) -> Result<Array<f32, IxDyn>, ModelError> {
|
146
|
-
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
147
|
-
let obj = self.obj.clone();
|
148
|
-
let args = ();
|
149
|
-
let result = obj.call_method(py, "get_projected_gravity", args, None)?;
|
150
|
-
let array = result.extract::<Vec<f32>>(py)?;
|
151
|
-
Ok(Array::from_vec(array).into_dyn())
|
152
|
-
})
|
153
|
-
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
154
|
-
Ok(args)
|
155
|
-
}
|
156
|
-
|
157
|
-
async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError> {
|
158
|
-
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
159
|
-
let obj = self.obj.clone();
|
160
|
-
let args = ();
|
161
|
-
let result = obj.call_method(py, "get_accelerometer", args, None)?;
|
162
|
-
let array = result.extract::<Vec<f32>>(py)?;
|
163
|
-
Ok(Array::from_vec(array).into_dyn())
|
164
|
-
})
|
165
|
-
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
166
|
-
Ok(args)
|
167
|
-
}
|
168
|
-
|
169
|
-
async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError> {
|
170
|
-
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
171
|
-
let obj = self.obj.clone();
|
172
|
-
let args = ();
|
173
|
-
let result = obj.call_method(py, "get_gyroscope", args, None)?;
|
174
|
-
let array = result.extract::<Vec<f32>>(py)?;
|
175
|
-
Ok(Array::from_vec(array).into_dyn())
|
176
|
-
})
|
177
|
-
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
178
|
-
Ok(args)
|
179
|
-
}
|
180
|
-
|
181
|
-
async fn get_command(&self) -> Result<Array<f32, IxDyn>, ModelError> {
|
182
|
-
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
183
|
-
let obj = self.obj.clone();
|
184
|
-
let args = ();
|
185
|
-
let result = obj.call_method(py, "get_command", args, None)?;
|
186
|
-
let array = result.extract::<Vec<f32>>(py)?;
|
187
|
-
Ok(Array::from_vec(array).into_dyn())
|
188
|
-
})
|
189
|
-
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
190
|
-
Ok(args)
|
191
|
-
}
|
192
|
-
|
193
|
-
async fn get_time(&self) -> Result<Array<f32, IxDyn>, ModelError> {
|
194
|
-
let args = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
234
|
+
input_types: &[InputType],
|
235
|
+
metadata: &ModelMetadata,
|
236
|
+
) -> Result<HashMap<InputType, Array<f32, IxDyn>>, ModelError> {
|
237
|
+
let input_names: Vec<String> = input_types
|
238
|
+
.iter()
|
239
|
+
.map(|t| t.get_name().to_string())
|
240
|
+
.collect();
|
241
|
+
let result = Python::with_gil(|py| -> PyResult<HashMap<InputType, Array<f32, IxDyn>>> {
|
195
242
|
let obj = self.obj.clone();
|
196
|
-
let args = ();
|
197
|
-
let result = obj.call_method(py, "
|
198
|
-
let
|
199
|
-
|
243
|
+
let args = (input_names.clone(), PyModelMetadata::from(metadata.clone()));
|
244
|
+
let result = obj.call_method(py, "get_inputs", args, None)?;
|
245
|
+
let dict: HashMap<String, Vec<f32>> = result.extract(py)?;
|
246
|
+
let mut arrays = HashMap::new();
|
247
|
+
for (i, name) in input_names.iter().enumerate() {
|
248
|
+
let array = dict.get(name).ok_or_else(|| {
|
249
|
+
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
250
|
+
"Missing input: {}",
|
251
|
+
name
|
252
|
+
))
|
253
|
+
})?;
|
254
|
+
arrays.insert(input_types[i], Array::from_vec(array.clone()).into_dyn());
|
255
|
+
}
|
256
|
+
Ok(arrays)
|
200
257
|
})
|
201
258
|
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
202
|
-
Ok(
|
203
|
-
}
|
204
|
-
|
205
|
-
async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError> {
|
206
|
-
Ok(carry)
|
259
|
+
Ok(result)
|
207
260
|
}
|
208
261
|
|
209
262
|
async fn take_action(
|
210
263
|
&self,
|
211
|
-
joint_names: Vec<String>,
|
212
264
|
action: Array<f32, IxDyn>,
|
265
|
+
metadata: &ModelMetadata,
|
213
266
|
) -> Result<(), ModelError> {
|
214
267
|
Python::with_gil(|py| -> PyResult<()> {
|
215
268
|
let obj = self.obj.clone();
|
216
269
|
let action_1d = action
|
217
270
|
.into_dimensionality::<Ix1>()
|
218
271
|
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
219
|
-
let args = (
|
272
|
+
let args = (
|
273
|
+
PyArray1::from_array(py, &action_1d),
|
274
|
+
PyModelMetadata::from(metadata.clone()),
|
275
|
+
);
|
220
276
|
obj.call_method(py, "take_action", args, None)?;
|
221
277
|
Ok(())
|
222
278
|
})
|
@@ -236,10 +292,10 @@ struct PyModelRunner {
|
|
236
292
|
#[pymethods]
|
237
293
|
impl PyModelRunner {
|
238
294
|
#[new]
|
239
|
-
fn
|
240
|
-
let input_provider = Arc::new(PyModelProvider::
|
295
|
+
fn __new__(model_path: String, provider: Py<ModelProviderABC>) -> PyResult<Self> {
|
296
|
+
let input_provider = Arc::new(PyModelProvider::__new__(provider));
|
241
297
|
|
242
|
-
let runner = tokio::runtime::Runtime::new()
|
298
|
+
let runner = tokio::runtime::Runtime::new()?.block_on(async {
|
243
299
|
ModelRunner::new(model_path, input_provider)
|
244
300
|
.await
|
245
301
|
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
@@ -252,7 +308,7 @@ impl PyModelRunner {
|
|
252
308
|
|
253
309
|
fn init(&self) -> PyResult<Py<PyArrayDyn<f32>>> {
|
254
310
|
let runner = self.runner.clone();
|
255
|
-
let result = tokio::runtime::Runtime::new()
|
311
|
+
let result = tokio::runtime::Runtime::new()?.block_on(async {
|
256
312
|
runner
|
257
313
|
.init()
|
258
314
|
.await
|
@@ -265,17 +321,14 @@ impl PyModelRunner {
|
|
265
321
|
})
|
266
322
|
}
|
267
323
|
|
268
|
-
fn step(
|
269
|
-
&self,
|
270
|
-
carry: Py<PyArrayDyn<f32>>,
|
271
|
-
) -> PyResult<(Py<PyArrayDyn<f32>>, Py<PyArrayDyn<f32>>)> {
|
324
|
+
fn step(&self, carry: Py<PyArrayDyn<f32>>) -> PyResult<StepResult> {
|
272
325
|
let runner = self.runner.clone();
|
273
326
|
let carry_array = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
274
327
|
let carry_array = carry.bind(py);
|
275
328
|
Ok(carry_array.to_owned_array())
|
276
329
|
})?;
|
277
330
|
|
278
|
-
let result = tokio::runtime::Runtime::new()
|
331
|
+
let result = tokio::runtime::Runtime::new()?.block_on(async {
|
279
332
|
runner
|
280
333
|
.step(carry_array)
|
281
334
|
.await
|
@@ -297,7 +350,7 @@ impl PyModelRunner {
|
|
297
350
|
Ok(action_array.to_owned_array())
|
298
351
|
})?;
|
299
352
|
|
300
|
-
tokio::runtime::Runtime::new()
|
353
|
+
tokio::runtime::Runtime::new()?.block_on(async {
|
301
354
|
runner
|
302
355
|
.take_action(action_array)
|
303
356
|
.await
|
@@ -319,38 +372,56 @@ struct PyModelRuntime {
|
|
319
372
|
#[pymethods]
|
320
373
|
impl PyModelRuntime {
|
321
374
|
#[new]
|
322
|
-
fn
|
375
|
+
fn __new__(model_runner: PyModelRunner, dt: u64) -> PyResult<Self> {
|
323
376
|
Ok(Self {
|
324
377
|
runtime: Arc::new(Mutex::new(ModelRuntime::new(model_runner.runner, dt))),
|
325
378
|
})
|
326
379
|
}
|
327
380
|
|
328
|
-
fn set_slowdown_factor(&self, slowdown_factor: i32) {
|
329
|
-
let mut runtime = self
|
381
|
+
fn set_slowdown_factor(&self, slowdown_factor: i32) -> PyResult<()> {
|
382
|
+
let mut runtime = self
|
383
|
+
.runtime
|
384
|
+
.lock()
|
385
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
330
386
|
runtime.set_slowdown_factor(slowdown_factor);
|
387
|
+
Ok(())
|
331
388
|
}
|
332
389
|
|
333
|
-
fn set_magnitude_factor(&self, magnitude_factor: f32) {
|
334
|
-
let mut runtime = self
|
390
|
+
fn set_magnitude_factor(&self, magnitude_factor: f32) -> PyResult<()> {
|
391
|
+
let mut runtime = self
|
392
|
+
.runtime
|
393
|
+
.lock()
|
394
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
335
395
|
runtime.set_magnitude_factor(magnitude_factor);
|
396
|
+
Ok(())
|
336
397
|
}
|
337
398
|
|
338
399
|
fn start(&self) -> PyResult<()> {
|
339
|
-
let mut runtime = self
|
400
|
+
let mut runtime = self
|
401
|
+
.runtime
|
402
|
+
.lock()
|
403
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
340
404
|
runtime
|
341
405
|
.start()
|
342
406
|
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
343
407
|
}
|
344
408
|
|
345
|
-
fn stop(&self) {
|
346
|
-
let mut runtime = self
|
409
|
+
fn stop(&self) -> PyResult<()> {
|
410
|
+
let mut runtime = self
|
411
|
+
.runtime
|
412
|
+
.lock()
|
413
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
347
414
|
runtime.stop();
|
415
|
+
Ok(())
|
348
416
|
}
|
349
417
|
}
|
350
418
|
|
351
419
|
#[pymodule]
|
352
420
|
fn rust_bindings(m: &Bound<PyModule>) -> PyResult<()> {
|
353
421
|
m.add_function(wrap_pyfunction!(get_version, m)?)?;
|
422
|
+
m.add_class::<PyInputType>()?;
|
423
|
+
m.add_class::<PyModelMetadata>()?;
|
424
|
+
m.add_function(wrap_pyfunction!(metadata_from_json, m)?)?;
|
354
425
|
m.add_class::<ModelProviderABC>()?;
|
355
426
|
m.add_class::<PyModelRunner>()?;
|
356
427
|
m.add_class::<PyModelRuntime>()?;
|
Binary file
|