kinfer 0.0.3__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: kinfer
3
- Version: 0.0.3
3
+ Version: 0.0.4
4
4
  Summary: The kinfer project
5
5
  Home-page: https://github.com/kscalelabs/kinfer.git
6
6
  Author: K-Scale Labs
@@ -1,3 +1,3 @@
1
- __version__ = "0.0.3"
1
+ __version__ = "0.0.4"
2
2
 
3
3
  from . import export, inference
@@ -1,6 +1,8 @@
1
1
  """PyTorch model export utilities."""
2
2
 
3
3
  import inspect
4
+ import json
5
+ import logging
4
6
  import sys
5
7
  from dataclasses import fields, is_dataclass
6
8
  from io import BytesIO
@@ -30,7 +32,6 @@ def get_model_info(model: nn.Module) -> Dict[str, Any]:
30
32
  if name == "self":
31
33
  continue
32
34
  params_info[name] = {
33
- "annotation": str(param.annotation),
34
35
  "default": None if param.default is param.empty else str(param.default),
35
36
  }
36
37
 
@@ -53,19 +54,21 @@ def add_metadata_to_onnx(
53
54
  Returns:
54
55
  ONNX model with added metadata
55
56
  """
56
- # Add model metadata
57
- for key, value in metadata.items():
58
- meta = model_proto.metadata_props.add()
59
- meta.key = key
60
- meta.value = str(value)
57
+ # Build metadata dictionary
58
+ metadata_dict = metadata.copy()
61
59
 
62
60
  # Add configuration if provided
63
- if config is not None and is_dataclass(config):
64
- for field in fields(config):
65
- value = getattr(config, field.name)
66
- meta = model_proto.metadata_props.add()
67
- meta.key = field.name
68
- meta.value = str(value)
61
+ if config is not None:
62
+ if is_dataclass(config):
63
+ for field in fields(config):
64
+ metadata_dict[field.name] = getattr(config, field.name)
65
+ elif not isinstance(config, dict):
66
+ raise ValueError("config must be a dataclass or dict. Got: " + str(type(config)))
67
+
68
+ # Add metadata as JSON string
69
+ meta = model_proto.metadata_props.add()
70
+ meta.key = "kinfer_metadata"
71
+ meta.value = json.dumps(metadata_dict)
69
72
 
70
73
  return model_proto
71
74
 
@@ -165,6 +168,10 @@ def export_to_onnx(
165
168
 
166
169
  # Create example inputs if not provided
167
170
  if input_tensors is None:
171
+ logging.warning(
172
+ "No input_tensors provided. Attempting to automatically infer input shapes. "
173
+ "Note: Input shape inference is *highly* experimental and may not work correctly for all models."
174
+ )
168
175
  try:
169
176
  input_tensors = create_example_inputs(model)
170
177
  model_info["inferred_input_shapes"] = str(
@@ -198,6 +205,10 @@ def export_to_onnx(
198
205
  # Load as ONNX model
199
206
  model_proto = onnx.load_model(buffer)
200
207
 
208
+ # Add config dict to model info if provided
209
+ if isinstance(config, dict):
210
+ model_info.update(config)
211
+
201
212
  # Add metadata
202
213
  model_proto = add_metadata_to_onnx(model_proto, model_info, config)
203
214
 
@@ -1,5 +1,7 @@
1
1
  """ONNX model inference utilities for Python."""
2
2
 
3
+ import json
4
+ import logging
3
5
  from typing import Any, Dict, List, Union
4
6
 
5
7
  import numpy as np
@@ -28,8 +30,21 @@ class ONNXModel:
28
30
  model_path,
29
31
  )
30
32
 
31
- # Extract metadata
32
- self.metadata = {prop.key: prop.value for prop in self.model.metadata_props}
33
+ # Extract metadata and attempt to parse JSON values
34
+ self.metadata = {}
35
+ self.attached_metadata = {}
36
+ for prop in self.model.metadata_props:
37
+ if prop.key == "kinfer_metadata":
38
+ try:
39
+ self.metadata = json.loads(prop.value)
40
+ except json.JSONDecodeError:
41
+ logging.warning(
42
+ "Failed to parse kinfer_metadata value with JSON parser. Saving as string: %s",
43
+ prop.value,
44
+ )
45
+ self.metadata = prop.value
46
+
47
+ self.attached_metadata[prop.key] = prop.value
33
48
 
34
49
  # Get input and output details
35
50
  self.input_details = [{"name": x.name, "shape": x.shape, "type": x.type} for x in self.session.get_inputs()]
@@ -66,7 +81,7 @@ class ONNXModel:
66
81
  else:
67
82
  return {detail["name"]: arr for detail, arr in zip(self.output_details, outputs)}
68
83
 
69
- def get_metadata(self) -> Dict[str, str]:
84
+ def get_metadata(self) -> Dict[str, Any]:
70
85
  """Get model metadata.
71
86
 
72
87
  Returns:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: kinfer
3
- Version: 0.0.3
3
+ Version: 0.0.4
4
4
  Summary: The kinfer project
5
5
  Home-page: https://github.com/kscalelabs/kinfer.git
6
6
  Author: K-Scale Labs
@@ -67,8 +67,8 @@ def test_model_metadata(model_path: str) -> None:
67
67
  # Check if config parameters are in metadata
68
68
  assert "hidden_size" in metadata
69
69
  assert "num_layers" in metadata
70
- assert metadata["hidden_size"] == "64"
71
- assert metadata["num_layers"] == "2"
70
+ assert metadata["hidden_size"] == 64
71
+ assert metadata["num_layers"] == 2
72
72
 
73
73
 
74
74
  def test_model_inference(model_path: str) -> None:
@@ -123,8 +123,8 @@ def test_comprehensive_model_workflow(tmp_path: Path) -> None:
123
123
  metadata = onnx_model.get_metadata()
124
124
  assert "hidden_size" in metadata
125
125
  assert "num_layers" in metadata
126
- assert metadata["hidden_size"] == "64"
127
- assert metadata["num_layers"] == "2"
126
+ assert metadata["hidden_size"] == 64
127
+ assert metadata["num_layers"] == 2
128
128
 
129
129
  # Test input/output details
130
130
  input_details = onnx_model.get_input_details()
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes