kinfer 0.5.4__cp313-cp313-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/__init__.py +16 -0
- kinfer/export/__init__.py +3 -0
- kinfer/export/jax.py +55 -0
- kinfer/export/pytorch.py +53 -0
- kinfer/export/serialize.py +84 -0
- kinfer/py.typed +0 -0
- kinfer/requirements.txt +13 -0
- kinfer/rust/Cargo.toml +36 -0
- kinfer/rust/src/lib.rs +8 -0
- kinfer/rust/src/logger.rs +141 -0
- kinfer/rust/src/model.rs +354 -0
- kinfer/rust/src/runtime.rs +107 -0
- kinfer/rust/src/types.rs +96 -0
- kinfer/rust_bindings/Cargo.toml +26 -0
- kinfer/rust_bindings/pyproject.toml +7 -0
- kinfer/rust_bindings/rust_bindings.pyi +46 -0
- kinfer/rust_bindings/src/bin/stub_gen.rs +7 -0
- kinfer/rust_bindings/src/lib.rs +486 -0
- kinfer/rust_bindings.cpython-313-x86_64-linux-gnu.so +0 -0
- kinfer/rust_bindings.pyi +46 -0
- kinfer/scripts/plot_ndjson.py +177 -0
- kinfer-0.5.4.dist-info/METADATA +63 -0
- kinfer-0.5.4.dist-info/RECORD +26 -0
- kinfer-0.5.4.dist-info/WHEEL +5 -0
- kinfer-0.5.4.dist-info/licenses/LICENSE +21 -0
- kinfer-0.5.4.dist-info/top_level.txt +1 -0
@@ -0,0 +1,486 @@
|
|
1
|
+
use async_trait::async_trait;
|
2
|
+
use kinfer::model::{ModelError, ModelProvider, ModelRunner};
|
3
|
+
use kinfer::runtime::ModelRuntime;
|
4
|
+
use kinfer::types::{InputType, ModelMetadata};
|
5
|
+
use ndarray::{Array, Ix1, IxDyn};
|
6
|
+
use numpy::{PyArray1, PyArrayDyn, PyArrayMethods};
|
7
|
+
use pyo3::exceptions::PyNotImplementedError;
|
8
|
+
use pyo3::prelude::*;
|
9
|
+
use pyo3::types::{PyAny, PyAnyMethods};
|
10
|
+
use pyo3::{pymodule, types::PyModule, Bound, PyResult, Python};
|
11
|
+
use pyo3_stub_gen::define_stub_info_gatherer;
|
12
|
+
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
|
13
|
+
use std::collections::HashMap;
|
14
|
+
use std::hash::Hash;
|
15
|
+
use std::sync::Arc;
|
16
|
+
use std::sync::Mutex;
|
17
|
+
|
18
|
+
// Pair of owned NumPy f32 arrays returned by `PyModelRunner::step`, which builds it as `(output, carry)`.
type StepResult = (Py<PyArrayDyn<f32>>, Py<PyArrayDyn<f32>>);
|
19
|
+
|
20
|
+
/// String-only error used to carry a failure message out of closures passed
/// to `Python::allow_threads`.
///
/// The only field is a `String`, so the compiler already derives `Send` and
/// `Sync` for this type automatically; no `unsafe impl` is required.
#[derive(Debug)]
struct SendError(String);

impl std::fmt::Display for SendError {
    /// Writes the wrapped message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
|
32
|
+
|
33
|
+
#[pyfunction]
|
34
|
+
#[gen_stub_pyfunction]
|
35
|
+
fn get_version() -> String {
|
36
|
+
env!("CARGO_PKG_VERSION").to_string()
|
37
|
+
}
|
38
|
+
|
39
|
+
#[pyclass]
|
40
|
+
#[gen_stub_pyclass]
|
41
|
+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
42
|
+
struct PyInputType {
|
43
|
+
pub input_type: InputType,
|
44
|
+
}
|
45
|
+
|
46
|
+
impl From<InputType> for PyInputType {
|
47
|
+
fn from(input_type: InputType) -> Self {
|
48
|
+
Self { input_type }
|
49
|
+
}
|
50
|
+
}
|
51
|
+
|
52
|
+
impl From<PyInputType> for InputType {
|
53
|
+
fn from(input_type: PyInputType) -> Self {
|
54
|
+
input_type.input_type
|
55
|
+
}
|
56
|
+
}
|
57
|
+
|
58
|
+
#[gen_stub_pymethods]
|
59
|
+
#[pymethods]
|
60
|
+
impl PyInputType {
|
61
|
+
#[new]
|
62
|
+
fn __new__(input_type: &str) -> PyResult<Self> {
|
63
|
+
let input_type = InputType::from_name(input_type).map_or_else(
|
64
|
+
|_| {
|
65
|
+
Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
|
66
|
+
"Invalid input type: {} (must be one of {})",
|
67
|
+
input_type,
|
68
|
+
InputType::get_names().join(", "),
|
69
|
+
)))
|
70
|
+
},
|
71
|
+
Ok,
|
72
|
+
)?;
|
73
|
+
Ok(Self { input_type })
|
74
|
+
}
|
75
|
+
|
76
|
+
fn get_name(&self) -> String {
|
77
|
+
self.input_type.get_name().to_string()
|
78
|
+
}
|
79
|
+
|
80
|
+
fn get_shape(&self, metadata: PyModelMetadata) -> Vec<usize> {
|
81
|
+
self.input_type.get_shape(&metadata.into())
|
82
|
+
}
|
83
|
+
|
84
|
+
fn __repr__(&self) -> String {
|
85
|
+
format!("InputType({})", self.get_name())
|
86
|
+
}
|
87
|
+
|
88
|
+
fn __eq__(&self, other: Bound<'_, PyAny>) -> PyResult<bool> {
|
89
|
+
if let Ok(other) = other.extract::<PyInputType>() {
|
90
|
+
Ok(self == &other)
|
91
|
+
} else {
|
92
|
+
Ok(false)
|
93
|
+
}
|
94
|
+
}
|
95
|
+
}
|
96
|
+
|
97
|
+
#[pyclass]
|
98
|
+
#[gen_stub_pyclass]
|
99
|
+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
100
|
+
struct PyModelMetadata {
|
101
|
+
#[pyo3(get, set)]
|
102
|
+
pub joint_names: Vec<String>,
|
103
|
+
#[pyo3(get, set)]
|
104
|
+
pub num_commands: Option<usize>,
|
105
|
+
#[pyo3(get, set)]
|
106
|
+
pub carry_size: Vec<usize>,
|
107
|
+
}
|
108
|
+
|
109
|
+
#[pymethods]
|
110
|
+
#[gen_stub_pymethods]
|
111
|
+
impl PyModelMetadata {
|
112
|
+
#[new]
|
113
|
+
fn __new__(
|
114
|
+
joint_names: Vec<String>,
|
115
|
+
num_commands: Option<usize>,
|
116
|
+
carry_size: Vec<usize>,
|
117
|
+
) -> Self {
|
118
|
+
Self {
|
119
|
+
joint_names,
|
120
|
+
num_commands,
|
121
|
+
carry_size,
|
122
|
+
}
|
123
|
+
}
|
124
|
+
|
125
|
+
fn to_json(&self) -> PyResult<String> {
|
126
|
+
let metadata = ModelMetadata {
|
127
|
+
joint_names: self.joint_names.clone(),
|
128
|
+
num_commands: self.num_commands,
|
129
|
+
carry_size: self.carry_size.clone(),
|
130
|
+
}
|
131
|
+
.to_json()
|
132
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
133
|
+
Ok(metadata)
|
134
|
+
}
|
135
|
+
|
136
|
+
fn __repr__(&self) -> PyResult<String> {
|
137
|
+
let json = self.to_json()?;
|
138
|
+
Ok(format!("ModelMetadata({:?})", json))
|
139
|
+
}
|
140
|
+
|
141
|
+
fn __eq__(&self, other: Bound<'_, PyAny>) -> PyResult<bool> {
|
142
|
+
if let Ok(other) = other.extract::<PyModelMetadata>() {
|
143
|
+
Ok(self == &other)
|
144
|
+
} else {
|
145
|
+
Ok(false)
|
146
|
+
}
|
147
|
+
}
|
148
|
+
}
|
149
|
+
|
150
|
+
#[pyfunction]
|
151
|
+
#[gen_stub_pyfunction]
|
152
|
+
fn metadata_from_json(json: &str) -> PyResult<PyModelMetadata> {
|
153
|
+
let metadata = ModelMetadata::model_validate_json(json.to_string()).map_err(|e| {
|
154
|
+
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Invalid model metadata: {}", e))
|
155
|
+
})?;
|
156
|
+
Ok(PyModelMetadata::from(metadata))
|
157
|
+
}
|
158
|
+
|
159
|
+
impl From<ModelMetadata> for PyModelMetadata {
|
160
|
+
fn from(metadata: ModelMetadata) -> Self {
|
161
|
+
Self {
|
162
|
+
joint_names: metadata.joint_names,
|
163
|
+
num_commands: metadata.num_commands,
|
164
|
+
carry_size: metadata.carry_size,
|
165
|
+
}
|
166
|
+
}
|
167
|
+
}
|
168
|
+
|
169
|
+
impl From<&ModelMetadata> for PyModelMetadata {
|
170
|
+
fn from(metadata: &ModelMetadata) -> Self {
|
171
|
+
Self {
|
172
|
+
joint_names: metadata.joint_names.clone(),
|
173
|
+
num_commands: metadata.num_commands,
|
174
|
+
carry_size: metadata.carry_size.clone(),
|
175
|
+
}
|
176
|
+
}
|
177
|
+
}
|
178
|
+
|
179
|
+
impl From<PyModelMetadata> for ModelMetadata {
|
180
|
+
fn from(metadata: PyModelMetadata) -> Self {
|
181
|
+
Self {
|
182
|
+
joint_names: metadata.joint_names,
|
183
|
+
num_commands: metadata.num_commands,
|
184
|
+
carry_size: metadata.carry_size,
|
185
|
+
}
|
186
|
+
}
|
187
|
+
}
|
188
|
+
|
189
|
+
#[pyclass(subclass)]
|
190
|
+
#[gen_stub_pyclass]
|
191
|
+
struct ModelProviderABC;
|
192
|
+
|
193
|
+
#[gen_stub_pymethods]
|
194
|
+
#[pymethods]
|
195
|
+
impl ModelProviderABC {
|
196
|
+
#[new]
|
197
|
+
fn __new__() -> Self {
|
198
|
+
ModelProviderABC
|
199
|
+
}
|
200
|
+
|
201
|
+
fn get_inputs<'py>(
|
202
|
+
&self,
|
203
|
+
input_types: Vec<String>,
|
204
|
+
metadata: PyModelMetadata,
|
205
|
+
) -> PyResult<HashMap<String, Bound<'py, PyArrayDyn<f32>>>> {
|
206
|
+
Err(PyNotImplementedError::new_err(format!(
|
207
|
+
"Must override get_inputs with {} input types {:?} and metadata {:?}",
|
208
|
+
input_types.len(),
|
209
|
+
input_types,
|
210
|
+
metadata
|
211
|
+
)))
|
212
|
+
}
|
213
|
+
|
214
|
+
fn take_action(
|
215
|
+
&self,
|
216
|
+
action: Bound<'_, PyArray1<f32>>,
|
217
|
+
metadata: PyModelMetadata,
|
218
|
+
) -> PyResult<()> {
|
219
|
+
let n = action.len()?;
|
220
|
+
if metadata.joint_names.len() != n {
|
221
|
+
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
|
222
|
+
"Expected {} joints, got {} action elements",
|
223
|
+
metadata.joint_names.len(),
|
224
|
+
n
|
225
|
+
)));
|
226
|
+
}
|
227
|
+
Err(PyNotImplementedError::new_err(format!(
|
228
|
+
"Must override take_action with {} action elements",
|
229
|
+
n
|
230
|
+
)))
|
231
|
+
}
|
232
|
+
}
|
233
|
+
|
234
|
+
#[gen_stub_pyclass]
|
235
|
+
#[pyclass]
|
236
|
+
#[derive(Clone)]
|
237
|
+
struct PyModelProvider {
|
238
|
+
obj: Arc<Py<ModelProviderABC>>,
|
239
|
+
}
|
240
|
+
|
241
|
+
#[pymethods]
|
242
|
+
impl PyModelProvider {
|
243
|
+
#[new]
|
244
|
+
fn __new__(obj: Py<ModelProviderABC>) -> Self {
|
245
|
+
Self { obj: Arc::new(obj) }
|
246
|
+
}
|
247
|
+
}
|
248
|
+
|
249
|
+
#[async_trait]
|
250
|
+
impl ModelProvider for PyModelProvider {
|
251
|
+
async fn get_inputs(
|
252
|
+
&self,
|
253
|
+
input_types: &[InputType],
|
254
|
+
metadata: &ModelMetadata,
|
255
|
+
) -> Result<HashMap<InputType, Array<f32, IxDyn>>, ModelError> {
|
256
|
+
let input_names: Vec<String> = input_types
|
257
|
+
.iter()
|
258
|
+
.map(|t| t.get_name().to_string())
|
259
|
+
.collect();
|
260
|
+
let result = Python::with_gil(|py| -> PyResult<HashMap<InputType, Array<f32, IxDyn>>> {
|
261
|
+
let obj = self.obj.clone();
|
262
|
+
let args = (input_names.clone(), PyModelMetadata::from(metadata.clone()));
|
263
|
+
let result = obj.call_method(py, "get_inputs", args, None)?;
|
264
|
+
let dict: HashMap<String, Vec<f32>> = result.extract(py)?;
|
265
|
+
let mut arrays = HashMap::new();
|
266
|
+
for (i, name) in input_names.iter().enumerate() {
|
267
|
+
let array = dict.get(name).ok_or_else(|| {
|
268
|
+
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
269
|
+
"Missing input: {}",
|
270
|
+
name
|
271
|
+
))
|
272
|
+
})?;
|
273
|
+
arrays.insert(input_types[i], Array::from_vec(array.clone()).into_dyn());
|
274
|
+
}
|
275
|
+
Ok(arrays)
|
276
|
+
})
|
277
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
278
|
+
Ok(result)
|
279
|
+
}
|
280
|
+
|
281
|
+
async fn take_action(
|
282
|
+
&self,
|
283
|
+
action: Array<f32, IxDyn>,
|
284
|
+
metadata: &ModelMetadata,
|
285
|
+
) -> Result<(), ModelError> {
|
286
|
+
Python::with_gil(|py| -> PyResult<()> {
|
287
|
+
let obj = self.obj.clone();
|
288
|
+
let action_1d = action
|
289
|
+
.into_dimensionality::<Ix1>()
|
290
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
291
|
+
let args = (
|
292
|
+
PyArray1::from_array(py, &action_1d),
|
293
|
+
PyModelMetadata::from(metadata.clone()),
|
294
|
+
);
|
295
|
+
obj.call_method(py, "take_action", args, None)?;
|
296
|
+
Ok(())
|
297
|
+
})
|
298
|
+
.map_err(|e| ModelError::Provider(e.to_string()))?;
|
299
|
+
Ok(())
|
300
|
+
}
|
301
|
+
}
|
302
|
+
|
303
|
+
#[gen_stub_pyclass]
|
304
|
+
#[pyclass]
|
305
|
+
#[derive(Clone)]
|
306
|
+
struct PyModelRunner {
|
307
|
+
runner: Arc<ModelRunner>,
|
308
|
+
runtime: Arc<tokio::runtime::Runtime>,
|
309
|
+
}
|
310
|
+
|
311
|
+
#[gen_stub_pymethods]
|
312
|
+
#[pymethods]
|
313
|
+
impl PyModelRunner {
|
314
|
+
#[new]
|
315
|
+
fn __new__(model_path: String, provider: Py<ModelProviderABC>) -> PyResult<Self> {
|
316
|
+
let input_provider = Arc::new(PyModelProvider::__new__(provider));
|
317
|
+
|
318
|
+
// Create a single runtime to be reused for all operations
|
319
|
+
let runtime = Arc::new(tokio::runtime::Runtime::new()
|
320
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?);
|
321
|
+
|
322
|
+
let runner = runtime.block_on(async {
|
323
|
+
ModelRunner::new(model_path, input_provider)
|
324
|
+
.await
|
325
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
326
|
+
})?;
|
327
|
+
|
328
|
+
Ok(Self {
|
329
|
+
runner: Arc::new(runner),
|
330
|
+
runtime,
|
331
|
+
})
|
332
|
+
}
|
333
|
+
|
334
|
+
// Reuse runtime and release GIL
|
335
|
+
fn init(&self) -> PyResult<Py<PyArrayDyn<f32>>> {
|
336
|
+
let runner = self.runner.clone();
|
337
|
+
let runtime = self.runtime.clone();
|
338
|
+
|
339
|
+
let result = Python::with_gil(|py| {
|
340
|
+
// Release GIL during async operation
|
341
|
+
py.allow_threads(|| {
|
342
|
+
runtime.block_on(async {
|
343
|
+
runner
|
344
|
+
.init()
|
345
|
+
.await
|
346
|
+
.map_err(|e| SendError(e.to_string()))
|
347
|
+
})
|
348
|
+
})
|
349
|
+
})
|
350
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.0))?;
|
351
|
+
|
352
|
+
Python::with_gil(|py| {
|
353
|
+
let array = numpy::PyArray::from_array(py, &result);
|
354
|
+
Ok(array.into())
|
355
|
+
})
|
356
|
+
}
|
357
|
+
|
358
|
+
// Reuse runtime and release GIL
|
359
|
+
fn step(&self, carry: Py<PyArrayDyn<f32>>) -> PyResult<StepResult> {
|
360
|
+
let runner = self.runner.clone();
|
361
|
+
let runtime = self.runtime.clone();
|
362
|
+
|
363
|
+
// Extract the carry array from Python with GIL
|
364
|
+
let carry_array = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
365
|
+
let carry_array = carry.bind(py);
|
366
|
+
Ok(carry_array.to_owned_array())
|
367
|
+
})?;
|
368
|
+
|
369
|
+
// Release GIL during computation
|
370
|
+
let result = Python::with_gil(|py| {
|
371
|
+
py.allow_threads(|| {
|
372
|
+
runtime.block_on(async {
|
373
|
+
runner
|
374
|
+
.step(carry_array)
|
375
|
+
.await
|
376
|
+
.map_err(|e| SendError(e.to_string()))
|
377
|
+
})
|
378
|
+
})
|
379
|
+
})
|
380
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.0))?;
|
381
|
+
|
382
|
+
// Reacquire the GIL to convert results back to Python objects
|
383
|
+
Python::with_gil(|py| {
|
384
|
+
let (output, carry) = result;
|
385
|
+
let output_array = numpy::PyArray::from_array(py, &output);
|
386
|
+
let carry_array = numpy::PyArray::from_array(py, &carry);
|
387
|
+
Ok((output_array.into(), carry_array.into()))
|
388
|
+
})
|
389
|
+
}
|
390
|
+
|
391
|
+
// Reuse runtime and release GIL
|
392
|
+
fn take_action(&self, action: Py<PyArrayDyn<f32>>) -> PyResult<()> {
|
393
|
+
let runner = self.runner.clone();
|
394
|
+
let runtime = self.runtime.clone();
|
395
|
+
|
396
|
+
// Extract action data with GIL
|
397
|
+
let action_array = Python::with_gil(|py| -> PyResult<Array<f32, IxDyn>> {
|
398
|
+
let action_array = action.bind(py);
|
399
|
+
Ok(action_array.to_owned_array())
|
400
|
+
})?;
|
401
|
+
|
402
|
+
// Release GIL during computation
|
403
|
+
Python::with_gil(|py| {
|
404
|
+
py.allow_threads(|| {
|
405
|
+
runtime.block_on(async {
|
406
|
+
runner
|
407
|
+
.take_action(action_array)
|
408
|
+
.await
|
409
|
+
.map_err(|e| SendError(e.to_string()))
|
410
|
+
})
|
411
|
+
})
|
412
|
+
})
|
413
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.0))?;
|
414
|
+
|
415
|
+
Ok(())
|
416
|
+
}
|
417
|
+
}
|
418
|
+
|
419
|
+
#[gen_stub_pyclass]
|
420
|
+
#[pyclass]
|
421
|
+
#[derive(Clone)]
|
422
|
+
struct PyModelRuntime {
|
423
|
+
runtime: Arc<Mutex<ModelRuntime>>,
|
424
|
+
}
|
425
|
+
|
426
|
+
#[gen_stub_pymethods]
|
427
|
+
#[pymethods]
|
428
|
+
impl PyModelRuntime {
|
429
|
+
#[new]
|
430
|
+
fn __new__(model_runner: PyModelRunner, dt: u64) -> PyResult<Self> {
|
431
|
+
Ok(Self {
|
432
|
+
runtime: Arc::new(Mutex::new(ModelRuntime::new(model_runner.runner, dt))),
|
433
|
+
})
|
434
|
+
}
|
435
|
+
|
436
|
+
fn set_slowdown_factor(&self, slowdown_factor: i32) -> PyResult<()> {
|
437
|
+
let mut runtime = self
|
438
|
+
.runtime
|
439
|
+
.lock()
|
440
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
441
|
+
runtime.set_slowdown_factor(slowdown_factor);
|
442
|
+
Ok(())
|
443
|
+
}
|
444
|
+
|
445
|
+
fn set_magnitude_factor(&self, magnitude_factor: f32) -> PyResult<()> {
|
446
|
+
let mut runtime = self
|
447
|
+
.runtime
|
448
|
+
.lock()
|
449
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
450
|
+
runtime.set_magnitude_factor(magnitude_factor);
|
451
|
+
Ok(())
|
452
|
+
}
|
453
|
+
|
454
|
+
fn start(&self) -> PyResult<()> {
|
455
|
+
let mut runtime = self
|
456
|
+
.runtime
|
457
|
+
.lock()
|
458
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
459
|
+
runtime
|
460
|
+
.start()
|
461
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
462
|
+
}
|
463
|
+
|
464
|
+
fn stop(&self) -> PyResult<()> {
|
465
|
+
let mut runtime = self
|
466
|
+
.runtime
|
467
|
+
.lock()
|
468
|
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
|
469
|
+
runtime.stop();
|
470
|
+
Ok(())
|
471
|
+
}
|
472
|
+
}
|
473
|
+
|
474
|
+
#[pymodule]
|
475
|
+
fn rust_bindings(m: &Bound<PyModule>) -> PyResult<()> {
|
476
|
+
m.add_function(wrap_pyfunction!(get_version, m)?)?;
|
477
|
+
m.add_class::<PyInputType>()?;
|
478
|
+
m.add_class::<PyModelMetadata>()?;
|
479
|
+
m.add_function(wrap_pyfunction!(metadata_from_json, m)?)?;
|
480
|
+
m.add_class::<ModelProviderABC>()?;
|
481
|
+
m.add_class::<PyModelRunner>()?;
|
482
|
+
m.add_class::<PyModelRuntime>()?;
|
483
|
+
Ok(())
|
484
|
+
}
|
485
|
+
|
486
|
+
// Invokes pyo3_stub_gen's gatherer macro with the name `stub_info`; presumably consumed by src/bin/stub_gen.rs -- TODO confirm.
define_stub_info_gatherer!(stub_info);
|
Binary file
|
kinfer/rust_bindings.pyi
ADDED
@@ -0,0 +1,46 @@
|
|
1
|
+
# This file is automatically generated by pyo3_stub_gen
|
2
|
+
# ruff: noqa: E501, F401
|
3
|
+
|
4
|
+
import builtins
|
5
|
+
import numpy
|
6
|
+
import numpy.typing
|
7
|
+
import typing
|
8
|
+
|
9
|
+
class ModelProviderABC:
    """Base class for providers; subclasses override both methods."""

    def __new__(cls) -> ModelProviderABC: ...
    def get_inputs(
        self,
        input_types: typing.Sequence[builtins.str],
        metadata: PyModelMetadata,
    ) -> builtins.dict[builtins.str, numpy.typing.NDArray[numpy.float32]]: ...
    def take_action(
        self,
        action: numpy.typing.NDArray[numpy.float32],
        metadata: PyModelMetadata,
    ) -> None: ...
|
13
|
+
|
14
|
+
class PyInputType:
    """Model input type, constructed from its name."""

    def __new__(cls, input_type: builtins.str) -> PyInputType: ...
    def get_name(self) -> builtins.str: ...
    def get_shape(self, metadata: PyModelMetadata) -> builtins.list[builtins.int]: ...
    def __repr__(self) -> builtins.str: ...
    def __eq__(self, other: typing.Any) -> builtins.bool: ...
|
20
|
+
|
21
|
+
class PyModelMetadata:
    """Model metadata: joint names, optional command count and carry shape."""

    # Read/write attributes (declared with `#[pyo3(get, set)]` on the Rust struct).
    joint_names: builtins.list[builtins.str]
    num_commands: typing.Optional[builtins.int]
    carry_size: builtins.list[builtins.int]

    # `__new__` receives the class, not an instance, so its first parameter is `cls`.
    def __new__(
        cls,
        joint_names: typing.Sequence[builtins.str],
        num_commands: typing.Optional[builtins.int],
        carry_size: typing.Sequence[builtins.int],
    ) -> PyModelMetadata: ...
    def to_json(self) -> builtins.str: ...
    def __repr__(self) -> builtins.str: ...
    def __eq__(self, other: typing.Any) -> builtins.bool: ...
|
26
|
+
|
27
|
+
class PyModelProvider:
    # No methods are listed for this class in the generated stub.
    pass
|
29
|
+
|
30
|
+
class PyModelRunner:
    """Runs a model loaded from `model_path`, using `provider` for inputs and actions."""

    def __new__(cls, model_path: builtins.str, provider: ModelProviderABC) -> PyModelRunner: ...
    def init(self) -> numpy.typing.NDArray[numpy.float32]: ...
    def step(
        self,
        carry: numpy.typing.NDArray[numpy.float32],
    ) -> tuple[numpy.typing.NDArray[numpy.float32], numpy.typing.NDArray[numpy.float32]]: ...
    def take_action(self, action: numpy.typing.NDArray[numpy.float32]) -> None: ...
|
35
|
+
|
36
|
+
class PyModelRuntime:
    """Runtime built around a `PyModelRunner`, with start/stop control."""

    def __new__(cls, model_runner: PyModelRunner, dt: builtins.int) -> PyModelRuntime: ...
    def set_slowdown_factor(self, slowdown_factor: builtins.int) -> None: ...
    def set_magnitude_factor(self, magnitude_factor: builtins.float) -> None: ...
    def start(self) -> None: ...
    def stop(self) -> None: ...
|
42
|
+
|
43
|
+
# Returns the version string of the compiled Rust bindings crate.
def get_version() -> builtins.str: ...
|
44
|
+
|
45
|
+
# Parses a JSON string into a PyModelMetadata; the Rust side raises ValueError on invalid input.
def metadata_from_json(json:builtins.str) -> PyModelMetadata: ...
|
46
|
+
|
@@ -0,0 +1,177 @@
|
|
1
|
+
"""Plot NDJSON logs saved by kinfer."""
|
2
|
+
|
3
|
+
import argparse
|
4
|
+
import json
|
5
|
+
import logging
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Optional, Union
|
8
|
+
|
9
|
+
import matplotlib.pyplot as plt
|
10
|
+
import numpy as np
|
11
|
+
|
12
|
+
# Module-level logger; the root logger is configured at INFO with a
# timestamped "time - name - level - message" layout.
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
|
15
|
+
|
16
|
+
|
17
|
+
def read_ndjson(filepath: str) -> list[dict]:
    """Read an NDJSON file and return its parsed records.

    Args:
        filepath: Path to a newline-delimited JSON file.

    Returns:
        One parsed object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8 rather than with the platform locale, so
    # non-ASCII content reads the same on every system.
    with open(filepath, "r", encoding="utf-8") as f:
        return [json.loads(stripped) for line in f if (stripped := line.strip())]
|
26
|
+
|
27
|
+
|
28
|
+
def skip_initial_data(data: list[dict], skip_seconds: float) -> list[dict]:
    """Drop the records that fall within the first `skip_seconds` of the log.

    Time is measured from the first record's `t_us` field (microseconds).
    The input is returned unchanged when it is empty or `skip_seconds` is not
    positive; an empty list is returned when every record is inside the
    skipped window.
    """
    if not data or skip_seconds <= 0.0:
        return data

    # Elapsed seconds of every record relative to the first one.
    stamps = [record["t_us"] for record in data]
    origin = stamps[0]
    elapsed = [(stamp - origin) / 1e6 for stamp in stamps]

    # Index of the first record at or past the cut-off, if there is one.
    start_idx = next((idx for idx, seconds in enumerate(elapsed) if seconds >= skip_seconds), None)
    if start_idx is None:
        logger.info("All data points are within the skip period (%.2f seconds). No data to plot.", skip_seconds)
        return []

    remaining = data[start_idx:]
    logger.info("Skipped first %.2f seconds (%d data points)", skip_seconds, start_idx)
    return remaining
|
50
|
+
|
51
|
+
|
52
|
+
def _extract_series(data: list[dict], t_start: float, key: str) -> tuple:
    """Return `(times_s, values)` for the records whose `key` field is present and not None."""
    times = []
    values = []
    for record in data:
        value = record.get(key)
        if value is not None:
            # Keep the timestamp of this record so time and value stay aligned
            # even when other records lack the field.
            times.append((record["t_us"] - t_start) / 1e6)
            values.append(value)
    series = np.array(values)
    if series.ndim == 1 and len(series) > 0:
        # Scalar-per-record fields become a single column.
        series = series.reshape(-1, 1)
    return times, series


def plot_data(data: list[dict], save_path: Optional[Union[str, Path]] = None) -> None:
    """Plot all data fields from the NDJSON records on a 3x2 grid.

    Each of `joint_angles`, `joint_vels`, `projected_g`, `accel`, `command`
    and `output` gets its own panel with one line per column. A record whose
    field is missing or None is skipped for that panel only, and the remaining
    samples are plotted against their own timestamps.

    Args:
        data: Parsed records; each needs a `t_us` timestamp in microseconds.
        save_path: When given, the figure is saved there (300 dpi) and closed;
            otherwise it is shown interactively.
    """
    if not data:
        logger.info("No data to plot")
        return

    t_start = data[0]["t_us"]
    xyz = ("X", "Y", "Z")

    def axis_label(i: int) -> str:
        # Fall back to a generic label instead of failing on a fourth column.
        return xyz[i] if i < len(xyz) else f"Axis {i}"

    def command_label(i: int) -> str:
        return f"Cmd {i}"

    thin = {"alpha": 0.7, "linewidth": 0.8}
    # (field, grid position, title, y label, per-column label function, line style)
    panels = [
        ("joint_angles", (0, 0), "Joint Angles", "Angle (rad)", None, thin),
        ("joint_vels", (0, 1), "Joint Velocities", "Velocity (rad/s)", None, thin),
        ("projected_g", (1, 0), "Projected Gravity", "Acceleration (m/s²)", axis_label, {"linewidth": 1.5}),
        ("accel", (1, 1), "Acceleration", "Acceleration (m/s²)", axis_label, {"linewidth": 1.5}),
        ("command", (2, 0), "Command", "Command Value", command_label, {"linewidth": 1.2}),
        ("output", (2, 1), "Output", "Output Value", None, thin),
    ]

    fig, axes = plt.subplots(3, 2, figsize=(15, 12))
    fig.suptitle("Robot Data Over Time", fontsize=16)

    for key, position, title, ylabel, label_for, style in panels:
        times, series = _extract_series(data, t_start, key)
        if len(series) == 0:
            continue
        ax = axes[position]
        for i in range(series.shape[1]):
            if label_for is None:
                ax.plot(times, series[:, i], **style)
            else:
                ax.plot(times, series[:, i], label=label_for(i), **style)
        ax.set_title(title)
        ax.set_xlabel("Time (s)")
        ax.set_ylabel(ylabel)
        if label_for is not None:
            ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        logger.info("Plot saved to: %s", save_path)
        plt.close()
    else:
        plt.show()
|
142
|
+
|
143
|
+
|
144
|
+
def main() -> None:
    """Command-line entry point: load an NDJSON log, optionally trim its start, and plot it.

    With `--save`, the figure is written to `plots/<input stem>.png` next to
    the input file; otherwise it is shown interactively.
    """
    parser = argparse.ArgumentParser(description="Plot NDJSON logs saved by kinfer")
    parser.add_argument("filepath", help="Path to the NDJSON file to plot")
    parser.add_argument("--skip", type=float, default=0.0, help="Skip the first n seconds of data")
    parser.add_argument("--save", action="store_true", help="Save the plot to a PNG file in a plots folder")
    args = parser.parse_args()

    source = Path(args.filepath)
    if not source.exists():
        logger.info("File not found: %s", args.filepath)
        return

    logger.info("Reading data from %s...", args.filepath)
    records = read_ndjson(args.filepath)
    logger.info("Loaded %d data points", len(records))

    remaining = skip_initial_data(records, args.skip)

    save_path = None
    if args.save:
        # Save under a sibling "plots" folder, swapping the extension for .png.
        plots_dir = source.parent / "plots"
        plots_dir.mkdir(exist_ok=True)
        save_path = str(plots_dir / (source.stem + ".png"))

    plot_data(remaining, save_path)


if __name__ == "__main__":
    main()
|