kinfer 0.3.2__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kinfer/__init__.py +6 -0
- kinfer/export/__init__.py +1 -0
- kinfer/export/pytorch.py +128 -0
- kinfer/inference/__init__.py +1 -0
- kinfer/inference/python.py +92 -0
- kinfer/proto/__init__.py +40 -0
- kinfer/proto/kinfer_pb2.py +103 -0
- kinfer/proto/kinfer_pb2.pyi +1097 -0
- kinfer/py.typed +0 -0
- kinfer/requirements-dev.txt +8 -0
- kinfer/requirements.txt +9 -0
- kinfer/rust/Cargo.toml +20 -0
- kinfer/rust/build.rs +16 -0
- kinfer/rust/src/kinfer_proto.rs +14 -0
- kinfer/rust/src/lib.rs +14 -0
- kinfer/rust/src/main.rs +6 -0
- kinfer/rust/src/model.rs +153 -0
- kinfer/rust/src/onnx_serializer.rs +804 -0
- kinfer/rust/src/serializer.rs +221 -0
- kinfer/rust/src/tests/onnx_serializer_tests.rs +212 -0
- kinfer/rust_bindings/Cargo.toml +19 -0
- kinfer/rust_bindings/pyproject.toml +7 -0
- kinfer/rust_bindings/src/bin/stub_gen.rs +7 -0
- kinfer/rust_bindings/src/lib.rs +17 -0
- kinfer/rust_bindings.cpython-311-darwin.so +0 -0
- kinfer/rust_bindings.pyi +7 -0
- kinfer/serialize/__init__.py +36 -0
- kinfer/serialize/base.py +536 -0
- kinfer/serialize/json.py +399 -0
- kinfer/serialize/numpy.py +426 -0
- kinfer/serialize/pytorch.py +402 -0
- kinfer/serialize/schema.py +125 -0
- kinfer/serialize/types.py +17 -0
- kinfer/serialize/utils.py +177 -0
- kinfer-0.3.2.dist-info/LICENSE +21 -0
- kinfer-0.3.2.dist-info/METADATA +57 -0
- kinfer-0.3.2.dist-info/RECORD +39 -0
- kinfer-0.3.2.dist-info/WHEEL +5 -0
- kinfer-0.3.2.dist-info/top_level.txt +1 -0
kinfer/py.typed
ADDED
File without changes
|
kinfer/requirements.txt
ADDED
kinfer/rust/Cargo.toml
ADDED
@@ -0,0 +1,20 @@
|
|
1
|
+
[package]
|
2
|
+
|
3
|
+
name = "kinfer"
|
4
|
+
version = "0.1.0"
|
5
|
+
edition = "2021"
|
6
|
+
|
7
|
+
[dependencies]
|
8
|
+
|
9
|
+
futures-util = "0.3.30"
|
10
|
+
ndarray = "0.16.1"
|
11
|
+
ort = { version = "2.0.0-rc.6", features = [ "load-dynamic" ] }
|
12
|
+
prost = "0.12"
|
13
|
+
thiserror = "1.0"
|
14
|
+
|
15
|
+
[dev-dependencies]
|
16
|
+
rand = "0.8"
|
17
|
+
|
18
|
+
[build-dependencies]
|
19
|
+
|
20
|
+
prost-build = "0.12"
|
kinfer/rust/build.rs
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
use std::env;
|
2
|
+
use std::path::PathBuf;
|
3
|
+
|
4
|
+
fn main() {
|
5
|
+
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()).join("proto");
|
6
|
+
std::fs::create_dir_all(&out_dir).unwrap();
|
7
|
+
|
8
|
+
let mut config = prost_build::Config::new();
|
9
|
+
config.out_dir(&out_dir);
|
10
|
+
config.retain_enum_prefix();
|
11
|
+
config.enable_type_names();
|
12
|
+
|
13
|
+
config
|
14
|
+
.compile_protos(&["../proto/kinfer.proto"], &["../proto/"])
|
15
|
+
.unwrap();
|
16
|
+
}
|
@@ -0,0 +1,14 @@
|
|
1
|
+
pub mod proto {
|
2
|
+
include!(concat!(env!("OUT_DIR"), "/proto/kinfer.proto.rs"));
|
3
|
+
}
|
4
|
+
|
5
|
+
pub use proto::{
|
6
|
+
AudioFrameSchema, AudioFrameValue, CameraFrameSchema, CameraFrameValue, DType,
|
7
|
+
ImuAccelerometerValue, ImuGyroscopeValue, ImuMagnetometerValue, ImuSchema, ImuValue,
|
8
|
+
Io as ProtoIO, IoSchema as ProtoIOSchema, JointCommandValue, JointCommandsSchema,
|
9
|
+
JointCommandsValue, JointPositionUnit, JointPositionValue, JointPositionsSchema,
|
10
|
+
JointPositionsValue, JointTorqueUnit, JointTorqueValue, JointTorquesSchema, JointTorquesValue,
|
11
|
+
JointVelocitiesSchema, JointVelocitiesValue, JointVelocityUnit, JointVelocityValue,
|
12
|
+
ModelSchema, StateTensorSchema, StateTensorValue, TimestampSchema, TimestampValue,
|
13
|
+
Value as ProtoValue, ValueSchema, VectorCommandSchema, VectorCommandValue,
|
14
|
+
};
|
kinfer/rust/src/lib.rs
ADDED
@@ -0,0 +1,14 @@
|
|
1
|
+
pub mod kinfer_proto;
|
2
|
+
pub mod model;
|
3
|
+
pub mod onnx_serializer;
|
4
|
+
pub mod serializer;
|
5
|
+
|
6
|
+
pub use kinfer_proto::*;
|
7
|
+
pub use model::*;
|
8
|
+
pub use onnx_serializer::*;
|
9
|
+
pub use serializer::*;
|
10
|
+
|
11
|
+
#[cfg(test)]
mod tests {
    //! Test-only modules, compiled only when building with `cfg(test)`.
    //! The serializer tests live in `src/tests/onnx_serializer_tests.rs`.

    mod onnx_serializer_tests;
}
|
kinfer/rust/src/main.rs
ADDED
kinfer/rust/src/model.rs
ADDED
@@ -0,0 +1,153 @@
|
|
1
|
+
use crate::kinfer_proto::{ModelSchema, ProtoIO, ProtoIOSchema};
|
2
|
+
use crate::onnx_serializer::OnnxMultiSerializer;
|
3
|
+
use std::path::Path;
|
4
|
+
|
5
|
+
use ort::session::builder::GraphOptimizationLevel;
|
6
|
+
use prost::Message;
|
7
|
+
use ort::{session::Session, Error as OrtError};
|
8
|
+
|
9
|
+
pub fn load_onnx_model<P: AsRef<Path>>(model_path: P) -> Result<Session, OrtError> {
|
10
|
+
let model = Session::builder()?
|
11
|
+
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
12
|
+
.with_intra_threads(4)?
|
13
|
+
.commit_from_file(model_path)?;
|
14
|
+
|
15
|
+
Ok(model)
|
16
|
+
}
|
17
|
+
|
18
|
+
const KINFER_METADATA_KEY: &str = "kinfer_metadata";
|
19
|
+
|
20
|
+
pub struct ModelRunner {
|
21
|
+
session: Session,
|
22
|
+
attached_metadata: std::collections::HashMap<String, String>,
|
23
|
+
schema: ModelSchema,
|
24
|
+
input_serializer: OnnxMultiSerializer,
|
25
|
+
output_serializer: OnnxMultiSerializer,
|
26
|
+
}
|
27
|
+
|
28
|
+
impl ModelRunner {
|
29
|
+
pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self, Box<dyn std::error::Error>> {
|
30
|
+
let session = load_onnx_model(model_path)?;
|
31
|
+
let mut attached_metadata = std::collections::HashMap::new();
|
32
|
+
|
33
|
+
// Extract metadata and attempt to parse schema
|
34
|
+
let mut schema = None;
|
35
|
+
{
|
36
|
+
let metadata = session.metadata()?;
|
37
|
+
for prop in metadata.custom_keys()? {
|
38
|
+
if prop == KINFER_METADATA_KEY {
|
39
|
+
let schema_bytes = metadata.custom(prop.as_str())?;
|
40
|
+
if let Some(bytes) = schema_bytes {
|
41
|
+
schema = Some(ModelSchema::decode(&mut bytes.as_bytes())?);
|
42
|
+
}
|
43
|
+
} else {
|
44
|
+
attached_metadata.insert(
|
45
|
+
prop.to_string(),
|
46
|
+
metadata
|
47
|
+
.custom(prop.as_str())?
|
48
|
+
.map_or_else(String::new, |s| s.to_string()),
|
49
|
+
);
|
50
|
+
}
|
51
|
+
}
|
52
|
+
}
|
53
|
+
|
54
|
+
let schema: ModelSchema = schema.ok_or_else(|| "kinfer_metadata not found in model metadata")?;
|
55
|
+
|
56
|
+
// Use as_ref() to borrow the Option contents and clone after ok_or
|
57
|
+
let input_schema = schema
|
58
|
+
.input_schema
|
59
|
+
.as_ref()
|
60
|
+
.ok_or("Missing input schema")?
|
61
|
+
.clone();
|
62
|
+
let output_schema = schema
|
63
|
+
.output_schema
|
64
|
+
.as_ref()
|
65
|
+
.ok_or("Missing output schema")?
|
66
|
+
.clone();
|
67
|
+
|
68
|
+
// Create serializers for input and output
|
69
|
+
let input_serializer = OnnxMultiSerializer::new(input_schema);
|
70
|
+
let output_serializer = OnnxMultiSerializer::new(output_schema);
|
71
|
+
|
72
|
+
Ok(Self {
|
73
|
+
session,
|
74
|
+
attached_metadata,
|
75
|
+
schema,
|
76
|
+
input_serializer,
|
77
|
+
output_serializer,
|
78
|
+
})
|
79
|
+
}
|
80
|
+
|
81
|
+
pub fn run(&self, inputs: ProtoIO) -> Result<ProtoIO, Box<dyn std::error::Error>> {
|
82
|
+
// Serialize inputs to ONNX format
|
83
|
+
let inputs = self.input_serializer.serialize_io(inputs)?;
|
84
|
+
|
85
|
+
// Get input names from the session
|
86
|
+
let input_names = self
|
87
|
+
.session
|
88
|
+
.inputs
|
89
|
+
.iter()
|
90
|
+
.map(|input| input.name.as_str())
|
91
|
+
.collect::<Vec<_>>();
|
92
|
+
|
93
|
+
// Create input name-value pairs
|
94
|
+
let input_values = vec![(input_names[0], inputs)];
|
95
|
+
|
96
|
+
let outputs = self.session.run(input_values)?;
|
97
|
+
|
98
|
+
let output_values = outputs
|
99
|
+
.values()
|
100
|
+
.map(|v: ort::value::ValueRef<'_>| {
|
101
|
+
v.try_upgrade().map_err(|e| {
|
102
|
+
Box::new(std::io::Error::new(
|
103
|
+
std::io::ErrorKind::Other,
|
104
|
+
format!("Failed to upgrade value"),
|
105
|
+
)) as Box<dyn std::error::Error>
|
106
|
+
})
|
107
|
+
})
|
108
|
+
.collect::<Result<Vec<_>, _>>()?;
|
109
|
+
// Deserialize outputs from ONNX format
|
110
|
+
let outputs = self.output_serializer.deserialize_io(output_values)?;
|
111
|
+
|
112
|
+
Ok(outputs)
|
113
|
+
}
|
114
|
+
|
115
|
+
pub fn export_model<P: AsRef<Path>>(&self, model_path: P) -> Result<(), Box<dyn std::error::Error>> {
|
116
|
+
let model_bytes = self.session.model_as_bytes()?;
|
117
|
+
let mut model = ModelProto::decode(&mut model_bytes.as_slice())?;
|
118
|
+
model.set_metadata_props(self.schema.encode_to_vec())?;
|
119
|
+
std::fs::write(model_path, model.write_to_bytes()?)?;
|
120
|
+
Ok(())
|
121
|
+
}
|
122
|
+
|
123
|
+
pub fn input_schema(&self) -> Result<ProtoIOSchema, Box<dyn std::error::Error>> {
|
124
|
+
self.schema
|
125
|
+
.input_schema
|
126
|
+
.as_ref()
|
127
|
+
.ok_or_else(|| {
|
128
|
+
Box::new(std::io::Error::new(
|
129
|
+
std::io::ErrorKind::NotFound,
|
130
|
+
"Missing input schema",
|
131
|
+
)) as Box<dyn std::error::Error>
|
132
|
+
})
|
133
|
+
.map(|schema| schema.clone())
|
134
|
+
}
|
135
|
+
|
136
|
+
pub fn output_schema(&self) -> Result<ProtoIOSchema, Box<dyn std::error::Error>> {
|
137
|
+
self.schema
|
138
|
+
.output_schema
|
139
|
+
.as_ref()
|
140
|
+
.ok_or_else(|| {
|
141
|
+
Box::new(std::io::Error::new(
|
142
|
+
std::io::ErrorKind::NotFound,
|
143
|
+
"Missing output schema",
|
144
|
+
)) as Box<dyn std::error::Error>
|
145
|
+
})
|
146
|
+
.map(|schema| schema.clone())
|
147
|
+
}
|
148
|
+
}
|
149
|
+
|
150
|
+
/// Entry point that only prints a greeting and succeeds.
///
/// NOTE(review): nothing in this module calls it, and it sits in a library
/// module; it looks like leftover scaffolding — confirm before relying on it.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let greeting = "Hello, world!";
    println!("{greeting}");
    Ok(())
}
|