kinfer-0.4.3-cp312-cp312-macosx_11_0_arm64.whl → kinfer-0.5.3-cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/export/__init__.py +3 -0
- kinfer/export/jax.py +13 -11
- kinfer/export/pytorch.py +4 -14
- kinfer/export/serialize.py +18 -27
- kinfer/requirements.txt +5 -0
- kinfer/rust/Cargo.toml +4 -0
- kinfer/rust/src/lib.rs +3 -0
- kinfer/rust/src/logger.rs +141 -0
- kinfer/rust/src/model.rs +126 -104
- kinfer/rust/src/runtime.rs +5 -2
- kinfer/rust/src/types.rs +96 -0
- kinfer/rust_bindings/rust_bindings.pyi +17 -8
- kinfer/rust_bindings/src/lib.rs +228 -157
- kinfer/rust_bindings.cpython-312-darwin.so +0 -0
- kinfer/rust_bindings.pyi +17 -8
- kinfer/scripts/plot_ndjson.py +177 -0
- {kinfer-0.4.3.dist-info → kinfer-0.5.3.dist-info}/METADATA +4 -1
- kinfer-0.5.3.dist-info/RECORD +26 -0
- {kinfer-0.4.3.dist-info → kinfer-0.5.3.dist-info}/WHEEL +1 -1
- kinfer/common/__init__.py +0 -0
- kinfer/common/types.py +0 -12
- kinfer/export/common.py +0 -44
- kinfer-0.4.3.dist-info/RECORD +0 -26
- {kinfer-0.4.3.dist-info → kinfer-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {kinfer-0.4.3.dist-info → kinfer-0.5.3.dist-info}/top_level.txt +0 -0
kinfer/export/__init__.py
CHANGED
kinfer/export/jax.py
CHANGED
@@ -10,20 +10,27 @@ from jax._src.stages import Wrapped
 from jax.experimental import jax2tf
 from onnx.onnx_pb import ModelProto
 
-from kinfer.
+from kinfer.rust_bindings import PyInputType, PyModelMetadata
 
 logger = logging.getLogger(__name__)
 
 
 def export_fn(
     model: Wrapped,
+    metadata: PyModelMetadata,
     *,
-    num_joints: int | None = None,
-    num_commands: int | None = None,
-    carry_shape: tuple[int, ...] | None = None,
     opset: int = 13,
 ) -> ModelProto:
-    """Export a JAX function to ONNX."""
+    """Export a JAX function to ONNX.
+
+    Args:
+        model: The model to export.
+        metadata: The metadata for the model.
+        opset: The ONNX opset to use.
+
+    Returns:
+        The ONNX model as a `ModelProto`.
+    """
     if not isinstance(model, Wrapped):
         raise ValueError("Model must be a Wrapped function")
 
@@ -33,12 +40,7 @@ def export_fn(
     # Gets the dummy input tensors for exporting the model.
     tf_args = []
     for name in input_names:
-        shape = get_shape(
-            name,
-            num_joints=num_joints,
-            num_commands=num_commands,
-            carry_shape=carry_shape,
-        )
+        shape = PyInputType(name).get_shape(metadata)
         tf_args.append(tf.TensorSpec(shape, tf.float32, name=name))
 
     finalised_fn = finalise_fn(model)
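For callers, the visible change is that the per-shape keyword arguments are gone: the exporter now takes a single `PyModelMetadata`. A rough migration sketch, assuming `metadata` is an already-constructed `PyModelMetadata` (its constructor is not part of this diff) and `policy_fn` is a `jax.jit`-wrapped placeholder policy:

import jax
import jax.numpy as jnp

from kinfer.export.jax import export_fn

# `metadata` is assumed to be an already-built PyModelMetadata; only its
# attribute names (joint_names, carry_size) appear in this diff.

@jax.jit
def policy_fn(joint_angles, command, carry):
    # Placeholder policy: zero action, carry passed through unchanged.
    return jnp.zeros_like(joint_angles), carry

# 0.4.x: export_fn(policy_fn, num_joints=..., num_commands=..., carry_shape=...)
# 0.5.x: shape information travels inside the metadata object.
onnx_model = export_fn(policy_fn, metadata, opset=13)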
kinfer/export/pytorch.py
CHANGED
@@ -12,23 +12,18 @@ import torch
 from onnx.onnx_pb import ModelProto
 from torch._C import FunctionSchema
 
-from kinfer.
+from kinfer.rust_bindings import PyInputType, PyModelMetadata
 
 
 def export_fn(
     model: torch.jit.ScriptFunction,
-
-    num_joints: int | None = None,
-    num_commands: int | None = None,
-    carry_shape: tuple[int, ...] | None = None,
+    metadata: PyModelMetadata,
 ) -> ModelProto:
     """Exports a PyTorch function to ONNX.
 
     Args:
         model: The model to export.
-
-        num_commands: The number of commands in the model.
-        carry_shape: The shape of the carry tensor.
+        metadata: The metadata for the model.
 
     Returns:
         The ONNX model as a `ModelProto`.
@@ -42,12 +37,7 @@ def export_fn(
     # Gets the dummy input tensors for exporting the model.
     args = []
     for name in input_names:
-        shape = get_shape(
-            name,
-            num_joints=num_joints,
-            num_commands=num_commands,
-            carry_shape=carry_shape,
-        )
+        shape = PyInputType(name).get_shape(metadata)
         args.append(torch.zeros(shape))
 
     buffer = io.BytesIO()
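The same lookup drives the PyTorch exporter. A minimal sketch of how dummy inputs are built from the metadata, using only the `PyInputType(name).get_shape(metadata)` call that appears in this diff; the input names are illustrative and `metadata` is again an assumed, already-built `PyModelMetadata`:

import torch

from kinfer.rust_bindings import PyInputType

# Input names are illustrative; the exporter derives them from the signature
# of the scripted function, and `metadata` is an assumed PyModelMetadata.
input_names = ["joint_angles", "joint_angular_velocities", "command", "carry"]

args = []
for name in input_names:
    shape = PyInputType(name).get_shape(metadata)  # shape comes from the metadata
    args.append(torch.zeros(shape))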
kinfer/export/serialize.py
CHANGED
@@ -4,34 +4,30 @@ __all__ = [
     "pack",
 ]
 
-
 import io
+import logging
 import tarfile
 
 from onnx.onnx_pb import ModelProto
 
-from kinfer.
-
+from kinfer.rust_bindings import PyInputType, PyModelMetadata
+
+logger = logging.getLogger(__name__)
 
 
 def pack(
     init_fn: ModelProto,
     step_fn: ModelProto,
-
-    num_commands: int | None = None,
-    carry_shape: tuple[int, ...] | None = None,
+    metadata: PyModelMetadata,
 ) -> bytes:
     """Packs the initialization function and step function into a directory.
 
     Args:
         init_fn: The initialization function.
         step_fn: The step function.
-
-            expects them to be provided.
-        num_commands: The number of commands in the model.
-        carry_shape: The shape of the carry tensor.
+        metadata: The metadata for the model.
     """
-    num_joints = len(joint_names)
+    num_joints = len(metadata.joint_names)  # type: ignore[attr-defined]
 
     # Checks the `init` function.
     if len(init_fn.graph.input) > 0:
@@ -40,20 +36,21 @@ def pack(
         raise ValueError(f"`init` function should have exactly 1 output! Got {len(init_fn.graph.output)}")
     init_carry = init_fn.graph.output[0]
     init_carry_shape = tuple(dim.dim_value for dim in init_carry.type.tensor_type.shape.dim)
-
-
+
+    if metadata.carry_size != init_carry_shape:  # type: ignore[attr-defined]
+        logger.warning(
+            "Updating carry size from %s to %s to match the `init` function",
+            metadata.carry_size,  # type: ignore[attr-defined]
+            init_carry_shape,
+        )
+        metadata.carry_size = init_carry_shape  # type: ignore[attr-defined]
 
     # Checks the `step` function.
     for step_input in step_fn.graph.input:
         step_input_type = step_input.type.tensor_type
         shape = tuple(dim.dim_value for dim in step_input_type.shape.dim)
-        expected_shape = get_shape(
-
-            num_joints=num_joints,
-            num_commands=num_commands,
-            carry_shape=carry_shape,
-        )
-        if shape != expected_shape:
+        expected_shape = PyInputType(step_input.name).get_shape(metadata)
+        if shape != tuple(expected_shape):
             raise ValueError(f"Expected shape {expected_shape} for input `{step_input.name}`, got {shape}")
 
     if len(step_fn.graph.output) != 2:
@@ -69,12 +66,6 @@ def pack(
     if output_carry_shape != init_carry_shape:
         raise ValueError(f"Expected carry shape {init_carry_shape} for output carry, got {output_carry_shape}")
 
-    # Builds the metadata object.
-    metadata = Metadata(
-        joint_names=joint_names,
-        num_commands=num_commands,
-    )
-
     buffer = io.BytesIO()
 
     with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
@@ -86,7 +77,7 @@ def pack(
 
         add_file_bytes("init_fn.onnx", init_fn.SerializeToString())
         add_file_bytes("step_fn.onnx", step_fn.SerializeToString())
-        add_file_bytes("metadata.json", metadata.
+        add_file_bytes("metadata.json", metadata.to_json().encode("utf-8"))
 
     buffer.seek(0)
 
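Downstream, `pack` now receives the same metadata object, validates the step-function input shapes against it, and serialises it with `to_json()` into `metadata.json`. A hedged sketch of calling it and inspecting the resulting archive; `init_onnx`, `step_onnx`, and `metadata` are assumed to come from the export step above:

import io
import json
import tarfile

from kinfer.export.serialize import pack

# init_onnx / step_onnx: ModelProto objects from the exporters above;
# metadata: the same assumed PyModelMetadata instance.
blob = pack(init_onnx, step_onnx, metadata)

# The result is a gzipped tar holding init_fn.onnx, step_fn.onnx and
# metadata.json (the output of PyModelMetadata.to_json()).
with tarfile.open(fileobj=io.BytesIO(blob), mode="r:gz") as tar:
    meta = json.load(tar.extractfile("metadata.json"))
print(sorted(meta))  # expected to include at least "joint_names" and "carry_size"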
kinfer/requirements.txt
CHANGED
kinfer/rust/Cargo.toml
CHANGED
@@ -17,10 +17,14 @@ crate-type = ["cdylib", "rlib"]
 [dependencies]
 
 async-trait = "0.1"
+chrono = "0.4.41"
+crossbeam-channel = "0.5.15"
 flate2 = "1.0"
 futures-util = "0.3.30"
+log = "0.4"
 ndarray = "0.16.1"
 ort = { version = "2.0.0-rc.9", features = [ "load-dynamic" ] }
+ort-sys = { version = "=2.0.0-rc.9" }
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 tar = "0.4"
kinfer/rust/src/lib.rs
CHANGED
kinfer/rust/src/logger.rs
ADDED
@@ -0,0 +1,141 @@
+use std::{
+    fs::OpenOptions,
+    io::{BufWriter, Write},
+    path::{Path, PathBuf},
+    sync::atomic::{AtomicU64, Ordering},
+    thread,
+};
+
+use crossbeam_channel::{bounded, Sender};
+use log::{info, warn};
+use serde::Serialize;
+
+#[derive(Serialize)]
+struct NdjsonStep<'a> {
+    step_id: u64,
+    t_us: u64,
+    joint_angles: Option<&'a [f32]>,
+    joint_vels: Option<&'a [f32]>,
+    initial_heading: Option<&'a [f32]>,
+    quaternion: Option<&'a [f32]>,
+    projected_g: Option<&'a [f32]>,
+    accel: Option<&'a [f32]>,
+    gyro: Option<&'a [f32]>,
+    command: Option<&'a [f32]>,
+    output: Option<&'a [f32]>,
+}
+
+// Channel capacity for non-blocking logging.
+// ~1000 entries at 50Hz is 20 seconds of buffering.
+// Warns if messages are dropped due to a full buffer.
+const CHANNEL_CAP: usize = 1024;
+
+// Flush buffered writes every 100 log entries.
+// At 50Hz control frequency, this flushes every 2 seconds.
+const FLUSH_EVERY: u64 = 100;
+
+pub struct StepLogger {
+    tx: Option<Sender<Vec<u8>>>,
+    worker: Option<thread::JoinHandle<()>>,
+    next_id: AtomicU64,
+}
+
+impl StepLogger {
+    pub fn new(path: impl AsRef<Path>) -> std::io::Result<Self> {
+        let path: PathBuf = path.as_ref().into();
+        if let Some(parent) = path.parent() {
+            std::fs::create_dir_all(parent)?;
+        }
+        info!("kinfer: logging to NDJSON: {}", path.display());
+
+        // I/O objects created here, but moved into the worker thread.
+        let file = OpenOptions::new().create(true).append(true).open(&path)?;
+        let mut bw = BufWriter::new(file);
+
+        // Bounded channel -> back-pressure capped at CHANNEL_CAP lines
+        let (tx, rx) = bounded::<Vec<u8>>(CHANNEL_CAP);
+
+        let worker = thread::spawn(move || {
+            let mut line_ctr: u64 = 0;
+            for msg in rx {
+                // drains until all senders dropped
+                let _ = bw.write_all(&msg);
+                line_ctr += 1;
+                if line_ctr % FLUSH_EVERY == 0 {
+                    let _ = bw.flush();
+                }
+            }
+            // Final flush on graceful shutdown
+            let _ = bw.flush();
+        });
+
+        Ok(Self {
+            tx: Some(tx),
+            worker: Some(worker),
+            next_id: AtomicU64::new(0),
+        })
+    }
+
+    #[inline]
+    fn now_us() -> u128 {
+        std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap()
+            .as_micros()
+    }
+
+    /// Non-blocking; drops a line if the channel is full.
+    #[allow(clippy::too_many_arguments)]
+    pub fn log_step(
+        &self,
+        joint_angles: Option<&[f32]>,
+        joint_vels: Option<&[f32]>,
+        initial_heading: Option<&[f32]>,
+        quaternion: Option<&[f32]>,
+        projected_g: Option<&[f32]>,
+        accel: Option<&[f32]>,
+        gyro: Option<&[f32]>,
+        command: Option<&[f32]>,
+        output: Option<&[f32]>,
+    ) {
+        let record = NdjsonStep {
+            step_id: self.next_id.fetch_add(1, Ordering::Relaxed),
+            t_us: Self::now_us() as u64,
+            joint_angles,
+            joint_vels,
+            initial_heading,
+            quaternion,
+            projected_g,
+            accel,
+            gyro,
+            command,
+            output,
+        };
+
+        // Serialise directly into a Vec<u8>; then push newline and send.
+        if let Ok(mut line) = serde_json::to_vec(&record) {
+            line.push(b'\n');
+            if let Some(tx) = &self.tx {
+                if tx.try_send(line).is_err() {
+                    warn!(
+                        "kinfer: logging buffer full, dropped message (step_id: {})",
+                        record.step_id
+                    );
+                }
+            }
+        }
+    }
+}
+
+/// Ensure the worker drains and flushes before program exit.
+impl Drop for StepLogger {
+    fn drop(&mut self) {
+        if let Some(tx) = self.tx.take() {
+            drop(tx); // Drop sender to close channel
+        }
+        // Wait for worker to finish
+        if let Some(worker) = self.worker.take() {
+            let _ = worker.join();
+        }
+    }
+}
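Each `log_step` call above becomes one line of JSON with the field names of `NdjsonStep`; inputs that were not present are serialised as null. A small sketch of consuming such a log from Python (the file path is illustrative); the new `kinfer/scripts/plot_ndjson.py` listed in the file summary presumably reads the same format:

import json

# Read an NDJSON step log written by StepLogger (path is illustrative).
with open("logs/2024-01-01_00-00-00.ndjson") as f:
    steps = [json.loads(line) for line in f if line.strip()]

for step in steps[:5]:
    # Fields mirror NdjsonStep: step_id, t_us, joint_angles, joint_vels,
    # initial_heading, quaternion, projected_g, accel, gyro, command, output.
    print(step["step_id"], step["t_us"], step.get("joint_angles"))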
kinfer/rust/src/model.rs
CHANGED
@@ -1,10 +1,11 @@
+use crate::logger::StepLogger;
+use crate::types::{InputType, ModelMetadata};
 use async_trait::async_trait;
+use chrono;
 use flate2::read::GzDecoder;
-use futures_util::future;
 use ndarray::{Array, IxDyn};
 use ort::session::Session;
 use ort::value::Value;
-use serde::Deserialize;
 use std::collections::HashMap;
 use std::io::Read;
 use std::path::Path;
@@ -13,18 +14,6 @@ use tar::Archive;
 use tokio::fs::File;
 use tokio::io::AsyncReadExt;
 
-#[derive(Debug, Deserialize)]
-struct ModelMetadata {
-    joint_names: Vec<String>,
-    num_commands: Option<usize>,
-}
-
-impl ModelMetadata {
-    fn model_validate_json(json: String) -> Result<Self, Box<dyn std::error::Error>> {
-        Ok(serde_json::from_str(&json)?)
-    }
-}
-
 #[derive(Debug, thiserror::Error)]
 pub enum ModelError {
     #[error("IO error: {0}")]
@@ -35,24 +24,16 @@ pub enum ModelError {
 
 #[async_trait]
 pub trait ModelProvider: Send + Sync {
-    async fn
-        &self,
-        joint_names: &[String],
-    ) -> Result<Array<f32, IxDyn>, ModelError>;
-    async fn get_joint_angular_velocities(
+    async fn get_inputs(
         &self,
-
-
-
-
-    async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError>;
-    async fn get_command(&self) -> Result<Array<f32, IxDyn>, ModelError>;
-    async fn get_time(&self) -> Result<Array<f32, IxDyn>, ModelError>;
-    async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError>;
+        input_types: &[InputType],
+        metadata: &ModelMetadata,
+    ) -> Result<HashMap<InputType, Array<f32, IxDyn>>, ModelError>;
+
     async fn take_action(
         &self,
-        joint_names: Vec<String>,
         action: Array<f32, IxDyn>,
+        metadata: &ModelMetadata,
     ) -> Result<(), ModelError>;
 }
 
@@ -61,6 +42,7 @@ pub struct ModelRunner {
     step_session: Session,
     metadata: ModelMetadata,
     provider: Arc<dyn ModelProvider>,
+    logger: Option<Arc<StepLogger>>,
 }
 
 impl ModelRunner {
@@ -139,11 +121,29 @@ impl ModelRunner {
         // Validate step_fn inputs and outputs
        Self::validate_step_fn(&step_session, &metadata, &carry_shape)?;
 
+        let logger = if let Ok(log_dir) = std::env::var("KINFER_LOG_PATH") {
+            let log_dir_path = std::path::Path::new(&log_dir);
+
+            // Create the directory if it doesn't exist
+            if !log_dir_path.exists() {
+                std::fs::create_dir_all(log_dir_path)?;
+            }
+
+            // Generate a timestamped filename
+            let timestamp = chrono::Utc::now().format("%Y-%m-%d_%H-%M-%S").to_string();
+            let log_file_path = log_dir_path.join(format!("{}.ndjson", timestamp));
+
+            Some(StepLogger::new(log_file_path).map(Arc::new)?)
+        } else {
+            None
+        };
+
         Ok(Self {
             init_session,
             step_session,
             metadata,
             provider: input_provider,
+            logger,
         })
     }
 
@@ -159,55 +159,15 @@ impl ModelRunner {
                 input.name
             ))?;
 
-
-
-
-
-
-
-
-
-
-                    }
-                }
-                "projected_gravity" | "accelerometer" | "gyroscope" => {
-                    if *dims != vec![3] {
-                        return Err(format!(
-                            "Expected shape [3] for input `{}`, got {:?}",
-                            input.name, dims
-                        )
-                        .into());
-                    }
-                }
-                "command" => {
-                    let num_commands = metadata.num_commands.ok_or("num_commands is not set")?;
-                    if *dims != vec![num_commands as i64] {
-                        return Err(format!(
-                            "Expected shape [{num_commands}] for input `{}`, got {:?}",
-                            input.name, dims
-                        )
-                        .into());
-                    }
-                }
-                "time" => {
-                    if *dims != vec![1] {
-                        return Err(format!(
-                            "Expected shape [1] for input `{}`, got {:?}",
-                            input.name, dims
-                        )
-                        .into());
-                    }
-                }
-                "carry" => {
-                    if dims != carry_shape {
-                        return Err(format!(
-                            "Expected shape {:?} for input `carry`, got {:?}",
-                            carry_shape, dims
-                        )
-                        .into());
-                    }
-                }
-                _ => return Err(format!("Unknown input name: {}", input.name).into()),
+            let input_type = InputType::from_name(&input.name)?;
+            let expected_shape = input_type.get_shape(metadata);
+            let expected_shape_i64: Vec<i64> = expected_shape.iter().map(|&x| x as i64).collect();
+            if *dims != expected_shape_i64 {
+                return Err(format!(
+                    "Expected input shape {:?}, got {:?}",
+                    expected_shape_i64, dims
+                )
+                .into());
             }
         }
 
@@ -244,6 +204,13 @@ impl ModelRunner {
         Ok(())
     }
 
+    pub async fn get_inputs(
+        &self,
+        input_types: &[InputType],
+    ) -> Result<HashMap<InputType, Array<f32, IxDyn>>, ModelError> {
+        self.provider.get_inputs(input_types, &self.metadata).await
+    }
+
     pub async fn init(&self) -> Result<Array<f32, IxDyn>, Box<dyn std::error::Error>> {
         let input_values: Vec<(&str, Value)> = Vec::new();
         let outputs = self.init_session.run(input_values)?;
@@ -264,37 +231,59 @@ impl ModelRunner {
             .collect();
 
         // Calls the relevant getter methods in parallel.
-        let mut
+        let mut input_types = Vec::new();
+        let mut inputs = HashMap::new();
         for name in &input_names {
             match name.as_str() {
                 "joint_angles" => {
-
+                    input_types.push(InputType::JointAngles);
+                }
+                "joint_angular_velocities" => {
+                    input_types.push(InputType::JointAngularVelocities);
+                }
+                "initial_heading" => {
+                    input_types.push(InputType::InitialHeading);
+                }
+                "quaternion" => {
+                    input_types.push(InputType::Quaternion);
+                }
+                "projected_gravity" => {
+                    input_types.push(InputType::ProjectedGravity);
+                }
+                "accelerometer" => {
+                    input_types.push(InputType::Accelerometer);
+                }
+                "gyroscope" => {
+                    input_types.push(InputType::Gyroscope);
+                }
+                "command" => {
+                    input_types.push(InputType::Command);
+                }
+                "time" => {
+                    input_types.push(InputType::Time);
+                }
+                "carry" => {
+                    inputs.insert(InputType::Carry, carry.clone());
                 }
-                "joint_angular_velocities" => futures.push(
-                    self.provider
-                        .get_joint_angular_velocities(&self.metadata.joint_names),
-                ),
-                "projected_gravity" => futures.push(self.provider.get_projected_gravity()),
-                "accelerometer" => futures.push(self.provider.get_accelerometer()),
-                "gyroscope" => futures.push(self.provider.get_gyroscope()),
-                "command" => futures.push(self.provider.get_command()),
-                "carry" => futures.push(self.provider.get_carry(carry.clone())),
-                "time" => futures.push(self.provider.get_time()),
                 _ => return Err(format!("Unknown input name: {}", name).into()),
             }
         }
 
-
-        let
-
-
-
+        // Gets the input values.
+        let result = self
+            .provider
+            .get_inputs(&input_types, &self.metadata)
+            .await?;
+
+        // Adds the input values to the input map.
+        inputs.extend(result);
 
         // Convert inputs to ONNX values
         let mut input_values: Vec<(&str, Value)> = Vec::new();
         for input in &self.step_session.inputs {
+            let input_type = InputType::from_name(&input.name)?;
             let input_data = inputs
-                .get(&
+                .get(&input_type)
                 .ok_or_else(|| format!("Missing input: {}", input.name))?;
             let input_value = Value::from_array(input_data.view())?.into_dyn();
             input_values.push((input.name.as_str(), input_value));
@@ -305,6 +294,47 @@ impl ModelRunner {
         let output_tensor = outputs[0].try_extract_tensor::<f32>()?;
         let carry_tensor = outputs[1].try_extract_tensor::<f32>()?;
 
+        // Log the step if needed
+        if let Some(lg) = &self.logger {
+            let joint_angles_opt = inputs
+                .get(&InputType::JointAngles)
+                .map(|a| a.as_slice().unwrap());
+            let joint_vels_opt = inputs
+                .get(&InputType::JointAngularVelocities)
+                .map(|a| a.as_slice().unwrap());
+            let initial_heading_opt = inputs
+                .get(&InputType::InitialHeading)
+                .map(|a| a.as_slice().unwrap());
+            let quaternion_opt = inputs
+                .get(&InputType::Quaternion)
+                .map(|a| a.as_slice().unwrap());
+            let projected_g_opt = inputs
+                .get(&InputType::ProjectedGravity)
+                .map(|a| a.as_slice().unwrap());
+            let accel_opt = inputs
+                .get(&InputType::Accelerometer)
+                .map(|a| a.as_slice().unwrap());
+            let gyro_opt = inputs
+                .get(&InputType::Gyroscope)
+                .map(|a| a.as_slice().unwrap());
+            let command_opt = inputs
+                .get(&InputType::Command)
+                .map(|a| a.as_slice().unwrap());
+            let output_opt = Some(output_tensor.as_slice().unwrap());
+
+            lg.log_step(
+                joint_angles_opt,
+                joint_vels_opt,
+                initial_heading_opt,
+                quaternion_opt,
+                projected_g_opt,
+                accel_opt,
+                gyro_opt,
+                command_opt,
+                output_opt,
+            );
+        }
+
         Ok((
             output_tensor.view().to_owned(),
             carry_tensor.view().to_owned(),
@@ -315,15 +345,7 @@ impl ModelRunner {
         &self,
         action: Array<f32, IxDyn>,
     ) -> Result<(), Box<dyn std::error::Error>> {
-        self.provider
-            .take_action(self.metadata.joint_names.clone(), action)
-            .await?;
+        self.provider.take_action(action, &self.metadata).await?;
         Ok(())
     }
-
-    pub async fn get_joint_angles(&self) -> Result<Array<f32, IxDyn>, Box<dyn std::error::Error>> {
-        let joint_names = &self.metadata.joint_names;
-        let joint_angles = self.provider.get_joint_angles(joint_names).await?;
-        Ok(joint_angles)
-    }
 }
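Logging is opt-in: the runner only builds a `StepLogger` when the `KINFER_LOG_PATH` environment variable is set, and it writes one timestamped `.ndjson` file per run into that directory. A hedged sketch of enabling it from Python; the model-running API itself is not shown in this diff:

import os
from pathlib import Path

# Point kinfer at a log directory before constructing the model runner; the Rust
# side creates the directory if needed and writes one "<UTC timestamp>.ndjson"
# file per run (timestamp format "%Y-%m-%d_%H-%M-%S").
log_dir = Path("/tmp/kinfer_logs")
os.environ["KINFER_LOG_PATH"] = str(log_dir)

# ... construct and step the model runner here (that API is not part of this diff) ...

for log_file in sorted(log_dir.glob("*.ndjson")):
    print("step log:", log_file)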