kinfer 0.3.1__cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kinfer/py.typed ADDED
File without changes
@@ -0,0 +1,8 @@
1
+ # requirements-dev.txt
2
+
3
+ black
4
+ darglint
5
+ mypy
6
+ mypy-protobuf
7
+ pytest
8
+ ruff
@@ -0,0 +1,9 @@
1
+ # requirements.txt
2
+
3
+ # Machine Learning
4
+ torch
5
+ onnx
6
+ onnxruntime
7
+
8
+ # Protocol Buffers
9
+ protobuf
kinfer/rust/Cargo.toml ADDED
@@ -0,0 +1,20 @@
1
+ [package]
2
+
3
+ name = "kinfer"
4
+ version = "0.1.0"
5
+ edition = "2021"
6
+
7
+ [dependencies]
8
+
9
+ futures-util = "0.3.30"
10
+ ndarray = "0.16.1"
11
+ ort = { version = "2.0.0-rc.6", features = [ "load-dynamic" ] }
12
+ prost = "0.12"
13
+ thiserror = "1.0"
14
+
15
+ [dev-dependencies]
16
+ rand = "0.8"
17
+
18
+ [build-dependencies]
19
+
20
+ prost-build = "0.12"
kinfer/rust/build.rs ADDED
@@ -0,0 +1,16 @@
1
+ use std::env;
2
+ use std::path::PathBuf;
3
+
4
+ fn main() {
5
+ let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()).join("proto");
6
+ std::fs::create_dir_all(&out_dir).unwrap();
7
+
8
+ let mut config = prost_build::Config::new();
9
+ config.out_dir(&out_dir);
10
+ config.retain_enum_prefix();
11
+ config.enable_type_names();
12
+
13
+ config
14
+ .compile_protos(&["../proto/kinfer.proto"], &["../proto/"])
15
+ .unwrap();
16
+ }
@@ -0,0 +1,14 @@
1
/// Protobuf message types generated by `prost-build` from `kinfer.proto`.
///
/// The build script writes the generated source into `$OUT_DIR/proto/`. It is spliced in here
/// verbatim with `include!`, so this module's contents are exactly the generated code.
pub mod proto {
    include!(concat!(env!("OUT_DIR"), "/proto/kinfer.proto.rs"));
}
4
+
5
+ pub use proto::{
6
+ AudioFrameSchema, AudioFrameValue, CameraFrameSchema, CameraFrameValue, DType,
7
+ ImuAccelerometerValue, ImuGyroscopeValue, ImuMagnetometerValue, ImuSchema, ImuValue,
8
+ Io as ProtoIO, IoSchema as ProtoIOSchema, JointCommandValue, JointCommandsSchema,
9
+ JointCommandsValue, JointPositionUnit, JointPositionValue, JointPositionsSchema,
10
+ JointPositionsValue, JointTorqueUnit, JointTorqueValue, JointTorquesSchema, JointTorquesValue,
11
+ JointVelocitiesSchema, JointVelocitiesValue, JointVelocityUnit, JointVelocityValue,
12
+ ModelSchema, StateTensorSchema, StateTensorValue, TimestampSchema, TimestampValue,
13
+ Value as ProtoValue, ValueSchema, VectorCommandSchema, VectorCommandValue,
14
+ };
kinfer/rust/src/lib.rs ADDED
@@ -0,0 +1,14 @@
1
+ pub mod kinfer_proto;
2
+ pub mod model;
3
+ pub mod onnx_serializer;
4
+ pub mod serializer;
5
+
6
+ pub use kinfer_proto::*;
7
+ pub use model::*;
8
+ pub use onnx_serializer::*;
9
+ pub use serializer::*;
10
+
11
// Unit tests; only compiled when building with `cfg(test)` (e.g. `cargo test`).
#[cfg(test)]
mod tests {
    // Tests for the ONNX serializer. The module body lives in its own file, resolved under
    // this `tests` module's directory by Rust's out-of-line module rules.
    mod onnx_serializer_tests;
}
@@ -0,0 +1,6 @@
1
+ use kinfer::*;
2
+
3
/// Placeholder binary entry point: prints a greeting and exits successfully.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let greeting = "Hello, world!";
    println!("{greeting}");
    Ok(())
}
@@ -0,0 +1,153 @@
1
+ use crate::kinfer_proto::{ModelSchema, ProtoIO, ProtoIOSchema};
2
+ use crate::onnx_serializer::OnnxMultiSerializer;
3
+ use std::path::Path;
4
+
5
+ use ort::session::builder::GraphOptimizationLevel;
6
+ use prost::Message;
7
+ use ort::{session::Session, Error as OrtError};
8
+
9
+ pub fn load_onnx_model<P: AsRef<Path>>(model_path: P) -> Result<Session, OrtError> {
10
+ let model = Session::builder()?
11
+ .with_optimization_level(GraphOptimizationLevel::Level3)?
12
+ .with_intra_threads(4)?
13
+ .commit_from_file(model_path)?;
14
+
15
+ Ok(model)
16
+ }
17
+
18
/// Custom ONNX metadata key whose value `ModelRunner::new` decodes as the [`ModelSchema`].
const KINFER_METADATA_KEY: &str = "kinfer_metadata";

/// Runs an ONNX model whose inputs and outputs are described by a kinfer schema embedded
/// in the model's custom metadata.
pub struct ModelRunner {
    /// ONNX Runtime session for the loaded model.
    session: Session,
    /// Every custom metadata entry of the model except the kinfer schema entry.
    attached_metadata: std::collections::HashMap<String, String>,
    /// Schema decoded from the model's kinfer metadata entry.
    schema: ModelSchema,
    /// Converts a [`ProtoIO`] into the value fed to the session.
    input_serializer: OnnxMultiSerializer,
    /// Converts the session's output values back into a [`ProtoIO`].
    output_serializer: OnnxMultiSerializer,
}
27
+
28
+ impl ModelRunner {
29
+ pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self, Box<dyn std::error::Error>> {
30
+ let session = load_onnx_model(model_path)?;
31
+ let mut attached_metadata = std::collections::HashMap::new();
32
+
33
+ // Extract metadata and attempt to parse schema
34
+ let mut schema = None;
35
+ {
36
+ let metadata = session.metadata()?;
37
+ for prop in metadata.custom_keys()? {
38
+ if prop == KINFER_METADATA_KEY {
39
+ let schema_bytes = metadata.custom(prop.as_str())?;
40
+ if let Some(bytes) = schema_bytes {
41
+ schema = Some(ModelSchema::decode(&mut bytes.as_bytes())?);
42
+ }
43
+ } else {
44
+ attached_metadata.insert(
45
+ prop.to_string(),
46
+ metadata
47
+ .custom(prop.as_str())?
48
+ .map_or_else(String::new, |s| s.to_string()),
49
+ );
50
+ }
51
+ }
52
+ }
53
+
54
+ let schema: ModelSchema = schema.ok_or_else(|| "kinfer_metadata not found in model metadata")?;
55
+
56
+ // Use as_ref() to borrow the Option contents and clone after ok_or
57
+ let input_schema = schema
58
+ .input_schema
59
+ .as_ref()
60
+ .ok_or("Missing input schema")?
61
+ .clone();
62
+ let output_schema = schema
63
+ .output_schema
64
+ .as_ref()
65
+ .ok_or("Missing output schema")?
66
+ .clone();
67
+
68
+ // Create serializers for input and output
69
+ let input_serializer = OnnxMultiSerializer::new(input_schema);
70
+ let output_serializer = OnnxMultiSerializer::new(output_schema);
71
+
72
+ Ok(Self {
73
+ session,
74
+ attached_metadata,
75
+ schema,
76
+ input_serializer,
77
+ output_serializer,
78
+ })
79
+ }
80
+
81
+ pub fn run(&self, inputs: ProtoIO) -> Result<ProtoIO, Box<dyn std::error::Error>> {
82
+ // Serialize inputs to ONNX format
83
+ let inputs = self.input_serializer.serialize_io(inputs)?;
84
+
85
+ // Get input names from the session
86
+ let input_names = self
87
+ .session
88
+ .inputs
89
+ .iter()
90
+ .map(|input| input.name.as_str())
91
+ .collect::<Vec<_>>();
92
+
93
+ // Create input name-value pairs
94
+ let input_values = vec![(input_names[0], inputs)];
95
+
96
+ let outputs = self.session.run(input_values)?;
97
+
98
+ let output_values = outputs
99
+ .values()
100
+ .map(|v: ort::value::ValueRef<'_>| {
101
+ v.try_upgrade().map_err(|e| {
102
+ Box::new(std::io::Error::new(
103
+ std::io::ErrorKind::Other,
104
+ format!("Failed to upgrade value"),
105
+ )) as Box<dyn std::error::Error>
106
+ })
107
+ })
108
+ .collect::<Result<Vec<_>, _>>()?;
109
+ // Deserialize outputs from ONNX format
110
+ let outputs = self.output_serializer.deserialize_io(output_values)?;
111
+
112
+ Ok(outputs)
113
+ }
114
+
115
+ pub fn export_model<P: AsRef<Path>>(&self, model_path: P) -> Result<(), Box<dyn std::error::Error>> {
116
+ let model_bytes = self.session.model_as_bytes()?;
117
+ let mut model = ModelProto::decode(&mut model_bytes.as_slice())?;
118
+ model.set_metadata_props(self.schema.encode_to_vec())?;
119
+ std::fs::write(model_path, model.write_to_bytes()?)?;
120
+ Ok(())
121
+ }
122
+
123
+ pub fn input_schema(&self) -> Result<ProtoIOSchema, Box<dyn std::error::Error>> {
124
+ self.schema
125
+ .input_schema
126
+ .as_ref()
127
+ .ok_or_else(|| {
128
+ Box::new(std::io::Error::new(
129
+ std::io::ErrorKind::NotFound,
130
+ "Missing input schema",
131
+ )) as Box<dyn std::error::Error>
132
+ })
133
+ .map(|schema| schema.clone())
134
+ }
135
+
136
+ pub fn output_schema(&self) -> Result<ProtoIOSchema, Box<dyn std::error::Error>> {
137
+ self.schema
138
+ .output_schema
139
+ .as_ref()
140
+ .ok_or_else(|| {
141
+ Box::new(std::io::Error::new(
142
+ std::io::ErrorKind::NotFound,
143
+ "Missing output schema",
144
+ )) as Box<dyn std::error::Error>
145
+ })
146
+ .map(|schema| schema.clone())
147
+ }
148
+ }
149
+
150
/// Placeholder that prints a greeting and returns `Ok(())`.
///
/// NOTE(review): this module is part of the library crate (it imports via `crate::`), so this
/// private `main` is not a program entry point and appears to be unused — consider removing.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    const GREETING: &str = "Hello, world!";
    println!("{GREETING}");
    Ok(())
}