kinfer 0.5.4__cp313-cp313-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,486 @@
1
+ use async_trait::async_trait;
2
+ use kinfer::model::{ModelError, ModelProvider, ModelRunner};
3
+ use kinfer::runtime::ModelRuntime;
4
+ use kinfer::types::{InputType, ModelMetadata};
5
+ use ndarray::{Array, Ix1, IxDyn};
6
+ use numpy::{PyArray1, PyArrayDyn, PyArrayMethods};
7
+ use pyo3::exceptions::PyNotImplementedError;
8
+ use pyo3::prelude::*;
9
+ use pyo3::types::{PyAny, PyAnyMethods};
10
+ use pyo3::{pymodule, types::PyModule, Bound, PyResult, Python};
11
+ use pyo3_stub_gen::define_stub_info_gatherer;
12
+ use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
13
+ use std::collections::HashMap;
14
+ use std::hash::Hash;
15
+ use std::sync::Arc;
16
+ use std::sync::Mutex;
17
+
18
+ type StepResult = (Py<PyArrayDyn<f32>>, Py<PyArrayDyn<f32>>);
19
+
20
/// String-only error used to carry a failure message out of `py.allow_threads` closures.
///
/// The runner methods convert model errors with `.map_err(|e| SendError(e.to_string()))`
/// while the GIL is released and turn the message back into a Python `RuntimeError`
/// afterwards. The payload is a plain `String`, which is already `Send + Sync`, so the
/// auto traits apply and no `unsafe impl` is required.
#[derive(Debug)]
struct SendError(String);

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for SendError {}
32
+
33
+ #[pyfunction]
34
+ #[gen_stub_pyfunction]
35
+ fn get_version() -> String {
36
+ env!("CARGO_PKG_VERSION").to_string()
37
+ }
38
+
39
+ #[pyclass]
40
+ #[gen_stub_pyclass]
41
+ #[derive(Debug, Clone, PartialEq, Eq, Hash)]
42
+ struct PyInputType {
43
+ pub input_type: InputType,
44
+ }
45
+
46
+ impl From<InputType> for PyInputType {
47
+ fn from(input_type: InputType) -> Self {
48
+ Self { input_type }
49
+ }
50
+ }
51
+
52
+ impl From<PyInputType> for InputType {
53
+ fn from(input_type: PyInputType) -> Self {
54
+ input_type.input_type
55
+ }
56
+ }
57
+
58
+ #[gen_stub_pymethods]
59
+ #[pymethods]
60
+ impl PyInputType {
61
+ #[new]
62
+ fn __new__(input_type: &str) -> PyResult<Self> {
63
+ let input_type = InputType::from_name(input_type).map_or_else(
64
+ |_| {
65
+ Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
66
+ "Invalid input type: {} (must be one of {})",
67
+ input_type,
68
+ InputType::get_names().join(", "),
69
+ )))
70
+ },
71
+ Ok,
72
+ )?;
73
+ Ok(Self { input_type })
74
+ }
75
+
76
+ fn get_name(&self) -> String {
77
+ self.input_type.get_name().to_string()
78
+ }
79
+
80
+ fn get_shape(&self, metadata: PyModelMetadata) -> Vec<usize> {
81
+ self.input_type.get_shape(&metadata.into())
82
+ }
83
+
84
+ fn __repr__(&self) -> String {
85
+ format!("InputType({})", self.get_name())
86
+ }
87
+
88
+ fn __eq__(&self, other: Bound<'_, PyAny>) -> PyResult<bool> {
89
+ if let Ok(other) = other.extract::<PyInputType>() {
90
+ Ok(self == &other)
91
+ } else {
92
+ Ok(false)
93
+ }
94
+ }
95
+ }
96
+
97
+ #[pyclass]
98
+ #[gen_stub_pyclass]
99
+ #[derive(Debug, Clone, PartialEq, Eq, Hash)]
100
+ struct PyModelMetadata {
101
+ #[pyo3(get, set)]
102
+ pub joint_names: Vec<String>,
103
+ #[pyo3(get, set)]
104
+ pub num_commands: Option<usize>,
105
+ #[pyo3(get, set)]
106
+ pub carry_size: Vec<usize>,
107
+ }
108
+
109
+ #[pymethods]
110
+ #[gen_stub_pymethods]
111
+ impl PyModelMetadata {
112
+ #[new]
113
+ fn __new__(
114
+ joint_names: Vec<String>,
115
+ num_commands: Option<usize>,
116
+ carry_size: Vec<usize>,
117
+ ) -> Self {
118
+ Self {
119
+ joint_names,
120
+ num_commands,
121
+ carry_size,
122
+ }
123
+ }
124
+
125
+ fn to_json(&self) -> PyResult<String> {
126
+ let metadata = ModelMetadata {
127
+ joint_names: self.joint_names.clone(),
128
+ num_commands: self.num_commands,
129
+ carry_size: self.carry_size.clone(),
130
+ }
131
+ .to_json()
132
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
133
+ Ok(metadata)
134
+ }
135
+
136
+ fn __repr__(&self) -> PyResult<String> {
137
+ let json = self.to_json()?;
138
+ Ok(format!("ModelMetadata({:?})", json))
139
+ }
140
+
141
+ fn __eq__(&self, other: Bound<'_, PyAny>) -> PyResult<bool> {
142
+ if let Ok(other) = other.extract::<PyModelMetadata>() {
143
+ Ok(self == &other)
144
+ } else {
145
+ Ok(false)
146
+ }
147
+ }
148
+ }
149
+
150
+ #[pyfunction]
151
+ #[gen_stub_pyfunction]
152
+ fn metadata_from_json(json: &str) -> PyResult<PyModelMetadata> {
153
+ let metadata = ModelMetadata::model_validate_json(json.to_string()).map_err(|e| {
154
+ PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Invalid model metadata: {}", e))
155
+ })?;
156
+ Ok(PyModelMetadata::from(metadata))
157
+ }
158
+
159
+ impl From<ModelMetadata> for PyModelMetadata {
160
+ fn from(metadata: ModelMetadata) -> Self {
161
+ Self {
162
+ joint_names: metadata.joint_names,
163
+ num_commands: metadata.num_commands,
164
+ carry_size: metadata.carry_size,
165
+ }
166
+ }
167
+ }
168
+
169
+ impl From<&ModelMetadata> for PyModelMetadata {
170
+ fn from(metadata: &ModelMetadata) -> Self {
171
+ Self {
172
+ joint_names: metadata.joint_names.clone(),
173
+ num_commands: metadata.num_commands,
174
+ carry_size: metadata.carry_size.clone(),
175
+ }
176
+ }
177
+ }
178
+
179
+ impl From<PyModelMetadata> for ModelMetadata {
180
+ fn from(metadata: PyModelMetadata) -> Self {
181
+ Self {
182
+ joint_names: metadata.joint_names,
183
+ num_commands: metadata.num_commands,
184
+ carry_size: metadata.carry_size,
185
+ }
186
+ }
187
+ }
188
+
189
+ #[pyclass(subclass)]
190
+ #[gen_stub_pyclass]
191
+ struct ModelProviderABC;
192
+
193
+ #[gen_stub_pymethods]
194
+ #[pymethods]
195
+ impl ModelProviderABC {
196
+ #[new]
197
+ fn __new__() -> Self {
198
+ ModelProviderABC
199
+ }
200
+
201
+ fn get_inputs<'py>(
202
+ &self,
203
+ input_types: Vec<String>,
204
+ metadata: PyModelMetadata,
205
+ ) -> PyResult<HashMap<String, Bound<'py, PyArrayDyn<f32>>>> {
206
+ Err(PyNotImplementedError::new_err(format!(
207
+ "Must override get_inputs with {} input types {:?} and metadata {:?}",
208
+ input_types.len(),
209
+ input_types,
210
+ metadata
211
+ )))
212
+ }
213
+
214
+ fn take_action(
215
+ &self,
216
+ action: Bound<'_, PyArray1<f32>>,
217
+ metadata: PyModelMetadata,
218
+ ) -> PyResult<()> {
219
+ let n = action.len()?;
220
+ if metadata.joint_names.len() != n {
221
+ return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
222
+ "Expected {} joints, got {} action elements",
223
+ metadata.joint_names.len(),
224
+ n
225
+ )));
226
+ }
227
+ Err(PyNotImplementedError::new_err(format!(
228
+ "Must override take_action with {} action elements",
229
+ n
230
+ )))
231
+ }
232
+ }
233
+
234
+ #[gen_stub_pyclass]
235
+ #[pyclass]
236
+ #[derive(Clone)]
237
+ struct PyModelProvider {
238
+ obj: Arc<Py<ModelProviderABC>>,
239
+ }
240
+
241
+ #[pymethods]
242
+ impl PyModelProvider {
243
+ #[new]
244
+ fn __new__(obj: Py<ModelProviderABC>) -> Self {
245
+ Self { obj: Arc::new(obj) }
246
+ }
247
+ }
248
+
249
+ #[async_trait]
250
+ impl ModelProvider for PyModelProvider {
251
+ async fn get_inputs(
252
+ &self,
253
+ input_types: &[InputType],
254
+ metadata: &ModelMetadata,
255
+ ) -> Result<HashMap<InputType, Array<f32, IxDyn>>, ModelError> {
256
+ let input_names: Vec<String> = input_types
257
+ .iter()
258
+ .map(|t| t.get_name().to_string())
259
+ .collect();
260
+ let result = Python::with_gil(|py| -> PyResult<HashMap<InputType, Array<f32, IxDyn>>> {
261
+ let obj = self.obj.clone();
262
+ let args = (input_names.clone(), PyModelMetadata::from(metadata.clone()));
263
+ let result = obj.call_method(py, "get_inputs", args, None)?;
264
+ let dict: HashMap<String, Vec<f32>> = result.extract(py)?;
265
+ let mut arrays = HashMap::new();
266
+ for (i, name) in input_names.iter().enumerate() {
267
+ let array = dict.get(name).ok_or_else(|| {
268
+ PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
269
+ "Missing input: {}",
270
+ name
271
+ ))
272
+ })?;
273
+ arrays.insert(input_types[i], Array::from_vec(array.clone()).into_dyn());
274
+ }
275
+ Ok(arrays)
276
+ })
277
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
278
+ Ok(result)
279
+ }
280
+
281
+ async fn take_action(
282
+ &self,
283
+ action: Array<f32, IxDyn>,
284
+ metadata: &ModelMetadata,
285
+ ) -> Result<(), ModelError> {
286
+ Python::with_gil(|py| -> PyResult<()> {
287
+ let obj = self.obj.clone();
288
+ let action_1d = action
289
+ .into_dimensionality::<Ix1>()
290
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
291
+ let args = (
292
+ PyArray1::from_array(py, &action_1d),
293
+ PyModelMetadata::from(metadata.clone()),
294
+ );
295
+ obj.call_method(py, "take_action", args, None)?;
296
+ Ok(())
297
+ })
298
+ .map_err(|e| ModelError::Provider(e.to_string()))?;
299
+ Ok(())
300
+ }
301
+ }
302
+
303
+ #[gen_stub_pyclass]
304
+ #[pyclass]
305
+ #[derive(Clone)]
306
+ struct PyModelRunner {
307
+ runner: Arc<ModelRunner>,
308
+ runtime: Arc<tokio::runtime::Runtime>,
309
+ }
310
+
311
+ #[gen_stub_pymethods]
312
+ #[pymethods]
313
+ impl PyModelRunner {
314
+ #[new]
315
+ fn __new__(model_path: String, provider: Py<ModelProviderABC>) -> PyResult<Self> {
316
+ let input_provider = Arc::new(PyModelProvider::__new__(provider));
317
+
318
+ // Create a single runtime to be reused for all operations
319
+ let runtime = Arc::new(tokio::runtime::Runtime::new()
320
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?);
321
+
322
+ let runner = runtime.block_on(async {
323
+ ModelRunner::new(model_path, input_provider)
324
+ .await
325
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
326
+ })?;
327
+
328
+ Ok(Self {
329
+ runner: Arc::new(runner),
330
+ runtime,
331
+ })
332
+ }
333
+
334
+ // Reuse runtime and release GIL
335
+ fn init(&self) -> PyResult<Py<PyArrayDyn<f32>>> {
336
+ let runner = self.runner.clone();
337
+ let runtime = self.runtime.clone();
338
+
339
+ let result = Python::with_gil(|py| {
340
+ // Release GIL during async operation
341
+ py.allow_threads(|| {
342
+ runtime.block_on(async {
343
+ runner
344
+ .init()
345
+ .await
346
+ .map_err(|e| SendError(e.to_string()))
347
+ })
348
+ })
349
+ })
350
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.0))?;
351
+
352
+ Python::with_gil(|py| {
353
+ let array = numpy::PyArray::from_array(py, &result);
354
+ Ok(array.into())
355
+ })
356
+ }
357
+
358
+ // Reuse runtime and release GIL
359
+ fn step(&self, carry: Py<PyArrayDyn<f32>>) -> PyResult<StepResult> {
360
+ let runner = self.runner.clone();
361
+ let runtime = self.runtime.clone();
362
+
363
+ // Extract the carry array from Python with GIL
364
+ let carry_array = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
365
+ let carry_array = carry.bind(py);
366
+ Ok(carry_array.to_owned_array())
367
+ })?;
368
+
369
+ // Release GIL during computation
370
+ let result = Python::with_gil(|py| {
371
+ py.allow_threads(|| {
372
+ runtime.block_on(async {
373
+ runner
374
+ .step(carry_array)
375
+ .await
376
+ .map_err(|e| SendError(e.to_string()))
377
+ })
378
+ })
379
+ })
380
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.0))?;
381
+
382
+ // Reacquire the GIL to convert results back to Python objects
383
+ Python::with_gil(|py| {
384
+ let (output, carry) = result;
385
+ let output_array = numpy::PyArray::from_array(py, &output);
386
+ let carry_array = numpy::PyArray::from_array(py, &carry);
387
+ Ok((output_array.into(), carry_array.into()))
388
+ })
389
+ }
390
+
391
+ // Reuse runtime and release GIL
392
+ fn take_action(&self, action: Py<PyArrayDyn<f32>>) -> PyResult<()> {
393
+ let runner = self.runner.clone();
394
+ let runtime = self.runtime.clone();
395
+
396
+ // Extract action data with GIL
397
+ let action_array = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
398
+ let action_array = action.bind(py);
399
+ Ok(action_array.to_owned_array())
400
+ })?;
401
+
402
+ // Release GIL during computation
403
+ Python::with_gil(|py| {
404
+ py.allow_threads(|| {
405
+ runtime.block_on(async {
406
+ runner
407
+ .take_action(action_array)
408
+ .await
409
+ .map_err(|e| SendError(e.to_string()))
410
+ })
411
+ })
412
+ })
413
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.0))?;
414
+
415
+ Ok(())
416
+ }
417
+ }
418
+
419
+ #[gen_stub_pyclass]
420
+ #[pyclass]
421
+ #[derive(Clone)]
422
+ struct PyModelRuntime {
423
+ runtime: Arc<Mutex<ModelRuntime>>,
424
+ }
425
+
426
+ #[gen_stub_pymethods]
427
+ #[pymethods]
428
+ impl PyModelRuntime {
429
+ #[new]
430
+ fn __new__(model_runner: PyModelRunner, dt: u64) -> PyResult<Self> {
431
+ Ok(Self {
432
+ runtime: Arc::new(Mutex::new(ModelRuntime::new(model_runner.runner, dt))),
433
+ })
434
+ }
435
+
436
+ fn set_slowdown_factor(&self, slowdown_factor: i32) -> PyResult<()> {
437
+ let mut runtime = self
438
+ .runtime
439
+ .lock()
440
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
441
+ runtime.set_slowdown_factor(slowdown_factor);
442
+ Ok(())
443
+ }
444
+
445
+ fn set_magnitude_factor(&self, magnitude_factor: f32) -> PyResult<()> {
446
+ let mut runtime = self
447
+ .runtime
448
+ .lock()
449
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
450
+ runtime.set_magnitude_factor(magnitude_factor);
451
+ Ok(())
452
+ }
453
+
454
+ fn start(&self) -> PyResult<()> {
455
+ let mut runtime = self
456
+ .runtime
457
+ .lock()
458
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
459
+ runtime
460
+ .start()
461
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
462
+ }
463
+
464
+ fn stop(&self) -> PyResult<()> {
465
+ let mut runtime = self
466
+ .runtime
467
+ .lock()
468
+ .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
469
+ runtime.stop();
470
+ Ok(())
471
+ }
472
+ }
473
+
474
+ #[pymodule]
475
+ fn rust_bindings(m: &Bound<PyModule>) -> PyResult<()> {
476
+ m.add_function(wrap_pyfunction!(get_version, m)?)?;
477
+ m.add_class::<PyInputType>()?;
478
+ m.add_class::<PyModelMetadata>()?;
479
+ m.add_function(wrap_pyfunction!(metadata_from_json, m)?)?;
480
+ m.add_class::<ModelProviderABC>()?;
481
+ m.add_class::<PyModelRunner>()?;
482
+ m.add_class::<PyModelRuntime>()?;
483
+ Ok(())
484
+ }
485
+
486
+ define_stub_info_gatherer!(stub_info);
@@ -0,0 +1,46 @@
1
+ # This file is automatically generated by pyo3_stub_gen
2
+ # ruff: noqa: E501, F401
3
+
4
+ import builtins
5
+ import numpy
6
+ import numpy.typing
7
+ import typing
8
+
9
class ModelProviderABC:
    """Base class for input providers; subclasses must override get_inputs and take_action."""

    def __new__(cls) -> ModelProviderABC: ...
    def get_inputs(
        self, input_types: typing.Sequence[builtins.str], metadata: PyModelMetadata
    ) -> builtins.dict[builtins.str, numpy.typing.NDArray[numpy.float32]]: ...
    def take_action(
        self, action: numpy.typing.NDArray[numpy.float32], metadata: PyModelMetadata
    ) -> None: ...
13
+
14
class PyInputType:
    """A named model input kind; constructing it with an unknown name raises ValueError."""

    def __new__(cls, input_type: builtins.str) -> PyInputType: ...
    def get_name(self) -> builtins.str: ...
    def get_shape(self, metadata: PyModelMetadata) -> builtins.list[builtins.int]: ...
    def __repr__(self) -> builtins.str: ...
    def __eq__(self, other: typing.Any) -> builtins.bool: ...
20
+
21
class PyModelMetadata:
    """Model metadata exposing read/write joint_names, num_commands and carry_size."""

    joint_names: builtins.list[builtins.str]
    num_commands: typing.Optional[builtins.int]
    carry_size: builtins.list[builtins.int]

    def __new__(
        cls,
        joint_names: typing.Sequence[builtins.str],
        num_commands: typing.Optional[builtins.int],
        carry_size: typing.Sequence[builtins.int],
    ) -> PyModelMetadata: ...
    def to_json(self) -> builtins.str: ...
    def __repr__(self) -> builtins.str: ...
    def __eq__(self, other: typing.Any) -> builtins.bool: ...
26
+
27
# PyModelProvider is an internal adapter constructed by PyModelRunner. The
# rust_bindings module does not register it, so it is intentionally not declared here.
29
+
30
class PyModelRunner:
    """Loads a kinfer model and runs it, with a ModelProviderABC supplying inputs."""

    def __new__(cls, model_path: builtins.str, provider: ModelProviderABC) -> PyModelRunner: ...
    def init(self) -> numpy.typing.NDArray[numpy.float32]: ...
    def step(
        self, carry: numpy.typing.NDArray[numpy.float32]
    ) -> tuple[numpy.typing.NDArray[numpy.float32], numpy.typing.NDArray[numpy.float32]]: ...
    def take_action(self, action: numpy.typing.NDArray[numpy.float32]) -> None: ...
35
+
36
class PyModelRuntime:
    """Drives a PyModelRunner through kinfer's ModelRuntime."""

    def __new__(cls, model_runner: PyModelRunner, dt: builtins.int) -> PyModelRuntime: ...
    def set_slowdown_factor(self, slowdown_factor: builtins.int) -> None: ...
    def set_magnitude_factor(self, magnitude_factor: builtins.float) -> None: ...
    def start(self) -> None: ...
    def stop(self) -> None: ...
42
+
43
def get_version() -> builtins.str:
    """Return the version string of the compiled bindings."""
    ...

def metadata_from_json(json: builtins.str) -> PyModelMetadata:
    """Parse model metadata from JSON; raises ValueError if the JSON is invalid."""
    ...
46
+
@@ -0,0 +1,177 @@
1
+ """Plot NDJSON logs saved by kinfer."""
2
+
3
+ import argparse
4
+ import json
5
+ import logging
6
+ from pathlib import Path
7
+ from typing import Optional, Union
8
+
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+
12
# Configure root logging at import time (INFO level, timestamped lines) and create
# this module's logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
15
+
16
+
17
def read_ndjson(filepath: Union[str, Path]) -> list[dict]:
    """Read an NDJSON file and return its records in file order.

    Each non-blank line is parsed as one JSON value; whitespace-only lines are skipped.

    Args:
        filepath: Path to the NDJSON file.

    Returns:
        The parsed records.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                data.append(json.loads(line))
    return data
26
+
27
+
28
def skip_initial_data(data: list[dict], skip_seconds: float) -> list[dict]:
    """Drop the records logged during the first ``skip_seconds`` seconds.

    Elapsed time is measured from the first record's ``t_us`` (microseconds). The result
    starts at the first record whose elapsed time is at least ``skip_seconds``. If no
    record qualifies, the result is an empty list. Empty ``data`` or a non-positive
    ``skip_seconds`` returns ``data`` itself.
    """
    if not data or skip_seconds <= 0.0:
        return data

    origin = data[0]["t_us"]
    elapsed = [(record["t_us"] - origin) / 1e6 for record in data]  # microseconds -> seconds
    first_kept = next((idx for idx, sec in enumerate(elapsed) if sec >= skip_seconds), None)

    if first_kept is None:
        logger.info("All data points are within the skip period (%.2f seconds). No data to plot.", skip_seconds)
        return []

    logger.info("Skipped first %.2f seconds (%d data points)", skip_seconds, first_kept)
    return data[first_kept:]
50
+
51
+
52
def _field_series(data: list[dict], times: list[float], key: str) -> tuple[np.ndarray, np.ndarray]:
    """Collect one logged field together with the times of the records that carry it.

    Records where ``key`` is missing or ``None`` are skipped together with their entry in
    ``times``, so the two returned arrays stay aligned.

    Returns:
        ``(t, values)``, where ``t`` has shape ``(k,)`` and ``values`` has shape ``(k, m)``.
        Scalar fields become a single column. Both arrays are empty when no record
        carries the field.
    """
    kept = [(t, record[key]) for t, record in zip(times, data) if record.get(key) is not None]
    if not kept:
        return np.empty(0), np.empty((0, 0))
    t_kept, values = zip(*kept)
    return np.asarray(t_kept, dtype=float), np.asarray(values, dtype=float).reshape(len(kept), -1)


def _plot_field(ax, t, values, title: str, ylabel: str, labels: Optional[list] = None, **line_kwargs) -> None:
    """Draw each column of ``values`` against ``t`` on ``ax``; leave ``ax`` untouched if empty.

    When ``labels`` is given, column ``i`` is labelled ``labels[i]`` and a legend is added.
    Columns beyond ``len(labels)`` are left unlabelled.
    """
    if len(values) == 0:
        return
    for i in range(values.shape[1]):
        label = labels[i] if labels is not None and i < len(labels) else None
        ax.plot(t, values[:, i], label=label, **line_kwargs)
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel(ylabel)
    if labels is not None:
        ax.legend()
    ax.grid(True, alpha=0.3)


def plot_data(data: list[dict], save_path: Optional[Union[str, Path]] = None) -> None:
    """Plot the logged signals on a 3x2 grid against seconds since the first record.

    Panels: joint angles, joint velocities, projected gravity, acceleration, command and
    output. A record whose field is missing or ``None`` is left out of that panel only,
    and the remaining samples keep their own timestamps rather than being shifted.

    Args:
        data: Parsed NDJSON records; every record must contain ``t_us`` (microseconds).
        save_path: If given, save the figure there at 300 dpi and close it; otherwise
            show it interactively.
    """
    if not data:
        logger.info("No data to plot")
        return

    t_start = data[0]["t_us"]
    times = [(d["t_us"] - t_start) / 1e6 for d in data]  # microseconds -> seconds

    fig, axes = plt.subplots(3, 2, figsize=(15, 12))
    fig.suptitle("Robot Data Over Time", fontsize=16)
    xyz = ["X", "Y", "Z"]

    t, v = _field_series(data, times, "joint_angles")
    _plot_field(axes[0, 0], t, v, "Joint Angles", "Angle (rad)", alpha=0.7, linewidth=0.8)

    t, v = _field_series(data, times, "joint_vels")
    _plot_field(axes[0, 1], t, v, "Joint Velocities", "Velocity (rad/s)", alpha=0.7, linewidth=0.8)

    t, v = _field_series(data, times, "projected_g")
    _plot_field(axes[1, 0], t, v, "Projected Gravity", "Acceleration (m/s²)", labels=xyz, linewidth=1.5)

    t, v = _field_series(data, times, "accel")
    _plot_field(axes[1, 1], t, v, "Acceleration", "Acceleration (m/s²)", labels=xyz, linewidth=1.5)

    t, v = _field_series(data, times, "command")
    command_labels = [f"Cmd {i}" for i in range(v.shape[1])]
    _plot_field(axes[2, 0], t, v, "Command", "Command Value", labels=command_labels, linewidth=1.2)

    t, v = _field_series(data, times, "output")
    _plot_field(axes[2, 1], t, v, "Output", "Output Value", alpha=0.7, linewidth=0.8)

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        logger.info("Plot saved to: %s", save_path)
        plt.close(fig)
    else:
        plt.show()
142
+
143
+
144
def main() -> None:
    """Command-line entry point: parse arguments, load the NDJSON log, and plot it.

    With ``--save``, the figure is written to ``<log dir>/plots/<log stem>.png``; otherwise
    it is shown interactively. Exits with status 1 if ``filepath`` is not an existing file.
    """
    parser = argparse.ArgumentParser(description="Plot NDJSON logs saved by kinfer")
    parser.add_argument("filepath", help="Path to the NDJSON file to plot")
    parser.add_argument("--skip", type=float, default=0.0, help="Skip the first n seconds of data")
    parser.add_argument("--save", action="store_true", help="Save the plot to a PNG file in a plots folder")
    args = parser.parse_args()

    filepath = args.filepath
    input_path = Path(filepath)
    # is_file(), not exists(): a directory would pass exists() and then fail in open().
    if not input_path.is_file():
        logger.error("File not found: %s", filepath)
        raise SystemExit(1)

    logger.info("Reading data from %s...", filepath)
    data = read_ndjson(filepath)
    logger.info("Loaded %d data points", len(data))

    filtered_data = skip_initial_data(data, args.skip)

    save_path = None
    if args.save:
        plots_dir = input_path.parent / "plots"
        plots_dir.mkdir(exist_ok=True)
        # Same base name as the log, with a .png extension.
        save_path = str(plots_dir / (input_path.stem + ".png"))

    plot_data(filtered_data, save_path)


if __name__ == "__main__":
    main()