kinfer 0.4.2__cp311-cp311-macosx_11_0_arm64.whl → 0.5.1__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kinfer/export/__init__.py CHANGED
@@ -0,0 +1,3 @@
+"""Defines the export API."""
+
+from .serialize import *
kinfer/export/jax.py CHANGED
@@ -10,20 +10,27 @@ from jax._src.stages import Wrapped
 from jax.experimental import jax2tf
 from onnx.onnx_pb import ModelProto

-from kinfer.export.common import get_shape
+from kinfer.rust_bindings import PyInputType, PyModelMetadata

 logger = logging.getLogger(__name__)


 def export_fn(
     model: Wrapped,
+    metadata: PyModelMetadata,
     *,
-    num_joints: int | None = None,
-    num_commands: int | None = None,
-    carry_shape: tuple[int, ...] | None = None,
     opset: int = 13,
 ) -> ModelProto:
-    """Export a JAX function to ONNX."""
+    """Export a JAX function to ONNX.
+
+    Args:
+        model: The model to export.
+        metadata: The metadata for the model.
+        opset: The ONNX opset to use.
+
+    Returns:
+        The ONNX model as a `ModelProto`.
+    """
     if not isinstance(model, Wrapped):
         raise ValueError("Model must be a Wrapped function")

@@ -33,12 +40,7 @@ def export_fn(
     # Gets the dummy input tensors for exporting the model.
     tf_args = []
     for name in input_names:
-        shape = get_shape(
-            name,
-            num_joints=num_joints,
-            num_commands=num_commands,
-            carry_shape=carry_shape,
-        )
+        shape = PyInputType(name).get_shape(metadata)
         tf_args.append(tf.TensorSpec(shape, tf.float32, name=name))

     finalised_fn = finalise_fn(model)
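
The JAX exporter now takes a single `PyModelMetadata` and resolves per-input shapes through `PyInputType(name).get_shape(metadata)`. A minimal usage sketch of the new signature follows; the `PyModelMetadata` constructor keywords and the toy policy are assumptions for illustration, not taken from this diff.

# Illustrative sketch only: PyModelMetadata's constructor keywords are assumed, not shown in this diff.
import jax
import jax.numpy as jnp

from kinfer.export.jax import export_fn
from kinfer.rust_bindings import PyModelMetadata

metadata = PyModelMetadata(
    joint_names=["left_hip", "right_hip"],  # hypothetical robot with two joints
    num_commands=3,
    carry_size=(64,),
)

@jax.jit
def step_fn(joint_angles: jnp.ndarray, command: jnp.ndarray, carry: jnp.ndarray):
    # Toy policy: the command is ignored and the carry is passed through unchanged.
    return -joint_angles, carry

step_onnx = export_fn(step_fn, metadata, opset=13)  # returns an onnx ModelProto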
kinfer/export/pytorch.py CHANGED
@@ -12,23 +12,18 @@ import torch
 from onnx.onnx_pb import ModelProto
 from torch._C import FunctionSchema

-from kinfer.export.common import get_shape
+from kinfer.rust_bindings import PyInputType, PyModelMetadata


 def export_fn(
     model: torch.jit.ScriptFunction,
-    *,
-    num_joints: int | None = None,
-    num_commands: int | None = None,
-    carry_shape: tuple[int, ...] | None = None,
+    metadata: PyModelMetadata,
 ) -> ModelProto:
     """Exports a PyTorch function to ONNX.

     Args:
         model: The model to export.
-        num_joints: The number of joints in the model.
-        num_commands: The number of commands in the model.
-        carry_shape: The shape of the carry tensor.
+        metadata: The metadata for the model.

     Returns:
         The ONNX model as a `ModelProto`.
@@ -42,12 +37,7 @@ def export_fn(
     # Gets the dummy input tensors for exporting the model.
     args = []
     for name in input_names:
-        shape = get_shape(
-            name,
-            num_joints=num_joints,
-            num_commands=num_commands,
-            carry_shape=carry_shape,
-        )
+        shape = PyInputType(name).get_shape(metadata)
         args.append(torch.zeros(shape))

     buffer = io.BytesIO()
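
The PyTorch exporter mirrors the same change: dummy input shapes are looked up from the metadata rather than from `num_joints`/`num_commands`/`carry_shape`. A hedged sketch, reusing the assumed `PyModelMetadata` construction from the JAX example above:

# Illustrative sketch only: `metadata` is the assumed PyModelMetadata object from the JAX example.
import torch

from kinfer.export.pytorch import export_fn

@torch.jit.script
def step_fn(joint_angles: torch.Tensor, command: torch.Tensor, carry: torch.Tensor):
    # Toy policy: the command is ignored and the carry is passed through unchanged.
    return -joint_angles, carry

step_onnx = export_fn(step_fn, metadata)  # shapes for the dummy inputs come from `metadata`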
kinfer/export/serialize.py CHANGED
@@ -4,34 +4,30 @@ __all__ = [
     "pack",
 ]

-
 import io
+import logging
 import tarfile

 from onnx.onnx_pb import ModelProto

-from kinfer.common.types import Metadata
-from kinfer.export.common import get_shape
+from kinfer.rust_bindings import PyInputType, PyModelMetadata
+
+logger = logging.getLogger(__name__)


 def pack(
     init_fn: ModelProto,
     step_fn: ModelProto,
-    joint_names: list[str],
-    num_commands: int | None = None,
-    carry_shape: tuple[int, ...] | None = None,
+    metadata: PyModelMetadata,
 ) -> bytes:
     """Packs the initialization function and step function into a directory.

     Args:
         init_fn: The initialization function.
         step_fn: The step function.
-        joint_names: The list of joint names, in the order that the model
-            expects them to be provided.
-        num_commands: The number of commands in the model.
-        carry_shape: The shape of the carry tensor.
+        metadata: The metadata for the model.
     """
-    num_joints = len(joint_names)
+    num_joints = len(metadata.joint_names)  # type: ignore[attr-defined]

     # Checks the `init` function.
     if len(init_fn.graph.input) > 0:
@@ -40,20 +36,21 @@ def pack(
         raise ValueError(f"`init` function should have exactly 1 output! Got {len(init_fn.graph.output)}")
     init_carry = init_fn.graph.output[0]
     init_carry_shape = tuple(dim.dim_value for dim in init_carry.type.tensor_type.shape.dim)
-    if carry_shape is not None and init_carry_shape != carry_shape:
-        raise ValueError(f"Expected carry shape {carry_shape} for output `{init_carry.name}`, got {init_carry_shape}")
+
+    if metadata.carry_size != init_carry_shape:  # type: ignore[attr-defined]
+        logger.warning(
+            "Updating carry size from %s to %s to match the `init` function",
+            metadata.carry_size,  # type: ignore[attr-defined]
+            init_carry_shape,
+        )
+        metadata.carry_size = init_carry_shape  # type: ignore[attr-defined]

     # Checks the `step` function.
     for step_input in step_fn.graph.input:
         step_input_type = step_input.type.tensor_type
         shape = tuple(dim.dim_value for dim in step_input_type.shape.dim)
-        expected_shape = get_shape(
-            step_input.name,
-            num_joints=num_joints,
-            num_commands=num_commands,
-            carry_shape=carry_shape,
-        )
-        if shape != expected_shape:
+        expected_shape = PyInputType(step_input.name).get_shape(metadata)
+        if shape != tuple(expected_shape):
             raise ValueError(f"Expected shape {expected_shape} for input `{step_input.name}`, got {shape}")

     if len(step_fn.graph.output) != 2:
@@ -69,12 +66,6 @@ def pack(
     if output_carry_shape != init_carry_shape:
         raise ValueError(f"Expected carry shape {init_carry_shape} for output carry, got {output_carry_shape}")

-    # Builds the metadata object.
-    metadata = Metadata(
-        joint_names=joint_names,
-        num_commands=num_commands,
-    )
-
     buffer = io.BytesIO()

     with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
@@ -86,7 +77,7 @@ def pack(

         add_file_bytes("init_fn.onnx", init_fn.SerializeToString())
         add_file_bytes("step_fn.onnx", step_fn.SerializeToString())
-        add_file_bytes("metadata.json", metadata.model_dump_json().encode("utf-8"))
+        add_file_bytes("metadata.json", metadata.to_json().encode("utf-8"))

     buffer.seek(0)

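
With `pack` now driven entirely by the metadata object (and `metadata.json` produced via `metadata.to_json()`), packaging an exported policy reduces to the sketch below, continuing from the export examples above; the output file name is illustrative.

# Illustrative sketch: `step_onnx` and `metadata` continue the export examples above, and
# `init_onnx` would be produced the same way from a zero-argument init function.
from pathlib import Path

from kinfer.export import pack  # re-exported from kinfer.export.serialize

model_bytes = pack(init_onnx, step_onnx, metadata)  # gzipped tar with init_fn.onnx, step_fn.onnx, metadata.json
Path("policy.kinfer").write_bytes(model_bytes)  # file name/extension chosen for illustration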
kinfer/requirements.txt CHANGED
@@ -6,3 +6,8 @@ onnxruntime==1.20.0

 # Serialization
 pydantic
+
+# Plotting
+matplotlib
+numpy
+pathlib
kinfer/rust/Cargo.toml CHANGED
@@ -17,8 +17,11 @@ crate-type = ["cdylib", "rlib"]
 [dependencies]

 async-trait = "0.1"
+chrono = "0.4.41"
+crossbeam-channel = "0.5.15"
 flate2 = "1.0"
 futures-util = "0.3.30"
+log = "0.4"
 ndarray = "0.16.1"
 ort = { version = "2.0.0-rc.9", features = [ "load-dynamic" ] }
 serde = { version = "1.0", features = ["derive"] }
kinfer/rust/src/lib.rs CHANGED
@@ -1,5 +1,8 @@
+pub mod logger;
 pub mod model;
 pub mod runtime;
+pub mod types;

 pub use model::*;
 pub use runtime::*;
+pub use types::*;
kinfer/rust/src/logger.rs ADDED
@@ -0,0 +1,135 @@
+use std::{
+    fs::OpenOptions,
+    io::{BufWriter, Write},
+    path::{Path, PathBuf},
+    sync::atomic::{AtomicU64, Ordering},
+    thread,
+};
+
+use crossbeam_channel::{bounded, Sender};
+use log::{info, warn};
+use serde::Serialize;
+
+#[derive(Serialize)]
+struct NdjsonStep<'a> {
+    step_id: u64,
+    t_us: u64,
+    joint_angles: Option<&'a [f32]>,
+    joint_vels: Option<&'a [f32]>,
+    projected_g: Option<&'a [f32]>,
+    accel: Option<&'a [f32]>,
+    gyro: Option<&'a [f32]>,
+    command: Option<&'a [f32]>,
+    output: Option<&'a [f32]>,
+}
+
+// Channel capacity for non-blocking logging.
+// ~1000 entries at 50Hz is 20 seconds of buffering.
+// Warns if messages are dropped due to full buffer.
+const CHANNEL_CAP: usize = 1024;
+
+// Flush buffered writes every 100 log entries.
+// At 50Hz control frequency, this flushes every 2 seconds.
+const FLUSH_EVERY: u64 = 100;
+
+pub struct StepLogger {
+    tx: Option<Sender<Vec<u8>>>,
+    worker: Option<thread::JoinHandle<()>>,
+    next_id: AtomicU64,
+}
+
+impl StepLogger {
+    pub fn new(path: impl AsRef<Path>) -> std::io::Result<Self> {
+        let path: PathBuf = path.as_ref().into();
+        if let Some(parent) = path.parent() {
+            std::fs::create_dir_all(parent)?;
+        }
+        info!("kinfer: logging to NDJSON: {}", path.display());
+
+        // I/O objects created here, but moved into the worker thread.
+        let file = OpenOptions::new().create(true).append(true).open(&path)?;
+        let mut bw = BufWriter::new(file);
+
+        // Bounded channel -> back-pressure capped at CHANNEL_CAP lines
+        let (tx, rx) = bounded::<Vec<u8>>(CHANNEL_CAP);
+
+        let worker = thread::spawn(move || {
+            let mut line_ctr: u64 = 0;
+            for msg in rx {
+                // drains until all senders dropped
+                let _ = bw.write_all(&msg);
+                line_ctr += 1;
+                if line_ctr % FLUSH_EVERY == 0 {
+                    let _ = bw.flush();
+                }
+            }
+            // Final flush on graceful shutdown
+            let _ = bw.flush();
+        });
+
+        Ok(Self {
+            tx: Some(tx),
+            worker: Some(worker),
+            next_id: AtomicU64::new(0),
+        })
+    }
+
+    #[inline]
+    fn now_us() -> u128 {
+        std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap()
+            .as_micros()
+    }
+
+    /// Non-blocking; drops a line if the channel is full.
+    #[allow(clippy::too_many_arguments)]
+    pub fn log_step(
+        &self,
+        joint_angles: Option<&[f32]>,
+        joint_vels: Option<&[f32]>,
+        projected_g: Option<&[f32]>,
+        accel: Option<&[f32]>,
+        gyro: Option<&[f32]>,
+        command: Option<&[f32]>,
+        output: Option<&[f32]>,
+    ) {
+        let record = NdjsonStep {
+            step_id: self.next_id.fetch_add(1, Ordering::Relaxed),
+            t_us: Self::now_us() as u64,
+            joint_angles,
+            joint_vels,
+            projected_g,
+            accel,
+            gyro,
+            command,
+            output,
+        };
+
+        // Serialise directly into a Vec<u8>; then push newline and send.
+        if let Ok(mut line) = serde_json::to_vec(&record) {
+            line.push(b'\n');
+            if let Some(tx) = &self.tx {
+                if tx.try_send(line).is_err() {
+                    warn!(
+                        "kinfer: logging buffer full, dropped message (step_id: {})",
+                        record.step_id
+                    );
+                }
+            }
+        }
+    }
+}
+
+/// Ensure the worker drains and flushes before program exit.
+impl Drop for StepLogger {
+    fn drop(&mut self) {
+        if let Some(tx) = self.tx.take() {
+            drop(tx); // Drop sender to close channel
+        }
+        // Wait for worker to finish
+        if let Some(worker) = self.worker.take() {
+            let _ = worker.join();
+        }
+    }
+}
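
Each `log_step` call appends one JSON object per line, with the fields of `NdjsonStep` above. As wired up in `model.rs` below, the runtime writes these records to a timestamped `.ndjson` file when the `KINFER_LOG_PATH` environment variable is set; the matplotlib/numpy additions to `kinfer/requirements.txt` suggest inspecting them from Python. A hedged reading/plotting sketch, with an assumed log directory:

# Illustrative sketch: the log directory is an assumption; field names come from NdjsonStep above.
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

log_dir = Path("/tmp/kinfer_logs")  # e.g. after running with KINFER_LOG_PATH=/tmp/kinfer_logs
log_file = sorted(log_dir.glob("*.ndjson"))[-1]  # most recent timestamped log

records = [json.loads(line) for line in log_file.read_text().splitlines() if line]
steps = [r for r in records if r["joint_angles"] is not None]

t = np.array([r["t_us"] for r in steps]) * 1e-6  # microseconds since the UNIX epoch -> seconds
joint_angles = np.array([r["joint_angles"] for r in steps])

plt.plot(t - t[0], joint_angles)
plt.xlabel("time (s)")
plt.ylabel("joint angles")
plt.show()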
kinfer/rust/src/model.rs CHANGED
@@ -1,10 +1,11 @@
+use crate::logger::StepLogger;
+use crate::types::{InputType, ModelMetadata};
 use async_trait::async_trait;
+use chrono;
 use flate2::read::GzDecoder;
-use futures_util::future;
 use ndarray::{Array, IxDyn};
 use ort::session::Session;
 use ort::value::Value;
-use serde::Deserialize;
 use std::collections::HashMap;
 use std::io::Read;
 use std::path::Path;
@@ -13,18 +14,6 @@ use tar::Archive;
 use tokio::fs::File;
 use tokio::io::AsyncReadExt;

-#[derive(Debug, Deserialize)]
-struct ModelMetadata {
-    joint_names: Vec<String>,
-    num_commands: Option<usize>,
-}
-
-impl ModelMetadata {
-    fn model_validate_json(json: String) -> Result<Self, Box<dyn std::error::Error>> {
-        Ok(serde_json::from_str(&json)?)
-    }
-}
-
 #[derive(Debug, thiserror::Error)]
 pub enum ModelError {
     #[error("IO error: {0}")]
@@ -35,23 +24,16 @@ pub enum ModelError {

 #[async_trait]
 pub trait ModelProvider: Send + Sync {
-    async fn get_joint_angles(
-        &self,
-        joint_names: &[String],
-    ) -> Result<Array<f32, IxDyn>, ModelError>;
-    async fn get_joint_angular_velocities(
+    async fn get_inputs(
         &self,
-        joint_names: &[String],
-    ) -> Result<Array<f32, IxDyn>, ModelError>;
-    async fn get_projected_gravity(&self) -> Result<Array<f32, IxDyn>, ModelError>;
-    async fn get_accelerometer(&self) -> Result<Array<f32, IxDyn>, ModelError>;
-    async fn get_gyroscope(&self) -> Result<Array<f32, IxDyn>, ModelError>;
-    async fn get_command(&self) -> Result<Array<f32, IxDyn>, ModelError>;
-    async fn get_carry(&self, carry: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>, ModelError>;
+        input_types: &[InputType],
+        metadata: &ModelMetadata,
+    ) -> Result<HashMap<InputType, Array<f32, IxDyn>>, ModelError>;
+
     async fn take_action(
         &self,
-        joint_names: Vec<String>,
         action: Array<f32, IxDyn>,
+        metadata: &ModelMetadata,
     ) -> Result<(), ModelError>;
 }

@@ -60,6 +42,7 @@ pub struct ModelRunner {
     step_session: Session,
     metadata: ModelMetadata,
     provider: Arc<dyn ModelProvider>,
+    logger: Option<Arc<StepLogger>>,
 }

 impl ModelRunner {
@@ -138,11 +121,29 @@ impl ModelRunner {
         // Validate step_fn inputs and outputs
         Self::validate_step_fn(&step_session, &metadata, &carry_shape)?;

+        let logger = if let Ok(log_dir) = std::env::var("KINFER_LOG_PATH") {
+            let log_dir_path = std::path::Path::new(&log_dir);
+
+            // Create the directory if it doesn't exist
+            if !log_dir_path.exists() {
+                std::fs::create_dir_all(log_dir_path)?;
+            }
+
+            // Generate a timestamped filename
+            let timestamp = chrono::Utc::now().format("%Y-%m-%d_%H-%M-%S").to_string();
+            let log_file_path = log_dir_path.join(format!("{}.ndjson", timestamp));
+
+            Some(StepLogger::new(log_file_path).map(Arc::new)?)
+        } else {
+            None
+        };
+
         Ok(Self {
             init_session,
             step_session,
             metadata,
             provider: input_provider,
+            logger,
         })
     }

@@ -158,46 +159,15 @@ impl ModelRunner {
                     input.name
                 ))?;

-            match input.name.as_str() {
-                "joint_angles" | "joint_angular_velocities" => {
-                    let num_joints = metadata.joint_names.len();
-                    if *dims != vec![num_joints as i64] {
-                        return Err(format!(
-                            "Expected shape [{num_joints}] for input `{}`, got {:?}",
-                            input.name, dims
-                        )
-                        .into());
-                    }
-                }
-                "projected_gravity" | "accelerometer" | "gyroscope" => {
-                    if *dims != vec![3] {
-                        return Err(format!(
-                            "Expected shape [3] for input `{}`, got {:?}",
-                            input.name, dims
-                        )
-                        .into());
-                    }
-                }
-                "command" => {
-                    let num_commands = metadata.num_commands.ok_or("num_commands is not set")?;
-                    if *dims != vec![num_commands as i64] {
-                        return Err(format!(
-                            "Expected shape [{num_commands}] for input `{}`, got {:?}",
-                            input.name, dims
-                        )
-                        .into());
-                    }
-                }
-                "carry" => {
-                    if dims != carry_shape {
-                        return Err(format!(
-                            "Expected shape {:?} for input `carry`, got {:?}",
-                            carry_shape, dims
-                        )
-                        .into());
-                    }
-                }
-                _ => return Err(format!("Unknown input name: {}", input.name).into()),
+            let input_type = InputType::from_name(&input.name)?;
+            let expected_shape = input_type.get_shape(metadata);
+            let expected_shape_i64: Vec<i64> = expected_shape.iter().map(|&x| x as i64).collect();
+            if *dims != expected_shape_i64 {
+                return Err(format!(
+                    "Expected input shape {:?}, got {:?}",
+                    expected_shape_i64, dims
+                )
+                .into());
             }
         }

@@ -234,6 +204,13 @@ impl ModelRunner {
         Ok(())
     }

+    pub async fn get_inputs(
+        &self,
+        input_types: &[InputType],
+    ) -> Result<HashMap<InputType, Array<f32, IxDyn>>, ModelError> {
+        self.provider.get_inputs(input_types, &self.metadata).await
+    }
+
     pub async fn init(&self) -> Result<Array<f32, IxDyn>, Box<dyn std::error::Error>> {
         let input_values: Vec<(&str, Value)> = Vec::new();
         let outputs = self.init_session.run(input_values)?;
@@ -254,36 +231,53 @@ impl ModelRunner {
             .collect();

         // Calls the relevant getter methods in parallel.
-        let mut futures = Vec::new();
+        let mut input_types = Vec::new();
+        let mut inputs = HashMap::new();
         for name in &input_names {
             match name.as_str() {
                 "joint_angles" => {
-                    futures.push(self.provider.get_joint_angles(&self.metadata.joint_names))
+                    input_types.push(InputType::JointAngles);
+                }
+                "joint_angular_velocities" => {
+                    input_types.push(InputType::JointAngularVelocities);
+                }
+                "projected_gravity" => {
+                    input_types.push(InputType::ProjectedGravity);
+                }
+                "accelerometer" => {
+                    input_types.push(InputType::Accelerometer);
+                }
+                "gyroscope" => {
+                    input_types.push(InputType::Gyroscope);
+                }
+                "command" => {
+                    input_types.push(InputType::Command);
+                }
+                "time" => {
+                    input_types.push(InputType::Time);
+                }
+                "carry" => {
+                    inputs.insert(InputType::Carry, carry.clone());
                 }
-                "joint_angular_velocities" => futures.push(
-                    self.provider
-                        .get_joint_angular_velocities(&self.metadata.joint_names),
-                ),
-                "projected_gravity" => futures.push(self.provider.get_projected_gravity()),
-                "accelerometer" => futures.push(self.provider.get_accelerometer()),
-                "gyroscope" => futures.push(self.provider.get_gyroscope()),
-                "command" => futures.push(self.provider.get_command()),
-                "carry" => futures.push(self.provider.get_carry(carry.clone())),
                 _ => return Err(format!("Unknown input name: {}", name).into()),
             }
         }

-        let results = future::try_join_all(futures).await?;
-        let mut inputs = HashMap::new();
-        for (name, value) in input_names.iter().zip(results) {
-            inputs.insert(name.clone(), value);
-        }
+        // Gets the input values.
+        let result = self
+            .provider
+            .get_inputs(&input_types, &self.metadata)
+            .await?;
+
+        // Adds the input values to the input map.
+        inputs.extend(result);

         // Convert inputs to ONNX values
         let mut input_values: Vec<(&str, Value)> = Vec::new();
         for input in &self.step_session.inputs {
+            let input_type = InputType::from_name(&input.name)?;
             let input_data = inputs
-                .get(&input.name)
+                .get(&input_type)
                 .ok_or_else(|| format!("Missing input: {}", input.name))?;
             let input_value = Value::from_array(input_data.view())?.into_dyn();
             input_values.push((input.name.as_str(), input_value));
@@ -294,6 +288,39 @@ impl ModelRunner {
         let output_tensor = outputs[0].try_extract_tensor::<f32>()?;
         let carry_tensor = outputs[1].try_extract_tensor::<f32>()?;

+        // Log the step if needed
+        if let Some(lg) = &self.logger {
+            let joint_angles_opt = inputs
+                .get(&InputType::JointAngles)
+                .map(|a| a.as_slice().unwrap());
+            let joint_vels_opt = inputs
+                .get(&InputType::JointAngularVelocities)
+                .map(|a| a.as_slice().unwrap());
+            let projected_g_opt = inputs
+                .get(&InputType::ProjectedGravity)
+                .map(|a| a.as_slice().unwrap());
+            let accel_opt = inputs
+                .get(&InputType::Accelerometer)
+                .map(|a| a.as_slice().unwrap());
+            let gyro_opt = inputs
+                .get(&InputType::Gyroscope)
+                .map(|a| a.as_slice().unwrap());
+            let command_opt = inputs
+                .get(&InputType::Command)
+                .map(|a| a.as_slice().unwrap());
+            let output_opt = Some(output_tensor.as_slice().unwrap());
+
+            lg.log_step(
+                joint_angles_opt,
+                joint_vels_opt,
+                projected_g_opt,
+                accel_opt,
+                gyro_opt,
+                command_opt,
+                output_opt,
+            );
+        }
+
         Ok((
             output_tensor.view().to_owned(),
             carry_tensor.view().to_owned(),
@@ -304,15 +331,7 @@ impl ModelRunner {
         &self,
         action: Array<f32, IxDyn>,
     ) -> Result<(), Box<dyn std::error::Error>> {
-        self.provider
-            .take_action(self.metadata.joint_names.clone(), action)
-            .await?;
+        self.provider.take_action(action, &self.metadata).await?;
         Ok(())
     }
-
-    pub async fn get_joint_angles(&self) -> Result<Array<f32, IxDyn>, Box<dyn std::error::Error>> {
-        let joint_names = &self.metadata.joint_names;
-        let joint_angles = self.provider.get_joint_angles(joint_names).await?;
-        Ok(joint_angles)
-    }
 }