amesa-inference-dev 0.20.5.dev13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amesa_inference_dev-0.20.5.dev13/PKG-INFO +119 -0
- amesa_inference_dev-0.20.5.dev13/README.md +102 -0
- amesa_inference_dev-0.20.5.dev13/amesa_inference/__init__.py +20 -0
- amesa_inference_dev-0.20.5.dev13/amesa_inference/inference_engine.py +189 -0
- amesa_inference_dev-0.20.5.dev13/amesa_inference/inference_network_mgr.py +44 -0
- amesa_inference_dev-0.20.5.dev13/amesa_inference/onnx_helper.py +152 -0
- amesa_inference_dev-0.20.5.dev13/amesa_inference/onnx_inference.py +198 -0
- amesa_inference_dev-0.20.5.dev13/amesa_inference/onnx_selector_processor.py +388 -0
- amesa_inference_dev-0.20.5.dev13/amesa_inference/onnx_skill_processor.py +209 -0
- amesa_inference_dev-0.20.5.dev13/amesa_inference/onnx_skill_processor_factory.py +146 -0
- amesa_inference_dev-0.20.5.dev13/amesa_inference/skill_processor_base.py +551 -0
- amesa_inference_dev-0.20.5.dev13/amesa_inference_dev.egg-info/PKG-INFO +119 -0
- amesa_inference_dev-0.20.5.dev13/amesa_inference_dev.egg-info/SOURCES.txt +29 -0
- amesa_inference_dev-0.20.5.dev13/amesa_inference_dev.egg-info/dependency_links.txt +1 -0
- amesa_inference_dev-0.20.5.dev13/amesa_inference_dev.egg-info/requires.txt +6 -0
- amesa_inference_dev-0.20.5.dev13/amesa_inference_dev.egg-info/top_level.txt +1 -0
- amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/__init__.c +5063 -0
- amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/inference_engine.c +14582 -0
- amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/inference_network_mgr.c +8440 -0
- amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/onnx_helper.c +13729 -0
- amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/onnx_inference.c +12742 -0
- amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/onnx_selector_processor.c +19732 -0
- amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/onnx_skill_processor.c +15402 -0
- amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/onnx_skill_processor_factory.c +12074 -0
- amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/skill_processor_base.c +26156 -0
- amesa_inference_dev-0.20.5.dev13/pyproject.toml +118 -0
- amesa_inference_dev-0.20.5.dev13/setup.cfg +4 -0
- amesa_inference_dev-0.20.5.dev13/setup.py +164 -0
- amesa_inference_dev-0.20.5.dev13/tests/test_inference_engine.py +531 -0
- amesa_inference_dev-0.20.5.dev13/tests/test_onnx_inference.py +95 -0
- amesa_inference_dev-0.20.5.dev13/tests/test_selector_agent_inference.py +150 -0
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: amesa-inference-dev
|
|
3
|
+
Version: 0.20.5.dev13
|
|
4
|
+
Summary: Agent inference package using ONNX models without Ray or PyTorch dependencies
|
|
5
|
+
Author-email: Hunter Park <hunter@amesa.com>
|
|
6
|
+
Requires-Python: >=3.10, <3.13
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Requires-Dist: amesa-api-dev>=0.20.5.dev13
|
|
9
|
+
Requires-Dist: amesa-core-dev>=0.20.5.dev13
|
|
10
|
+
Requires-Dist: numpy<2.0.0
|
|
11
|
+
Requires-Dist: onnxruntime
|
|
12
|
+
Requires-Dist: onnx
|
|
13
|
+
Requires-Dist: ruff
|
|
14
|
+
Dynamic: description
|
|
15
|
+
Dynamic: description-content-type
|
|
16
|
+
Dynamic: requires-python
|
|
17
|
+
|
|
18
|
+
# Amesa Inference
|
|
19
|
+
|
|
20
|
+
A lightweight inference package for running Amesa agents using ONNX models without Ray or PyTorch dependencies.
|
|
21
|
+
|
|
22
|
+
## Overview
|
|
23
|
+
|
|
24
|
+
`amesa_inference` provides a standalone inference engine for running trained Amesa agents. It uses ONNX Runtime for model inference, making it suitable for deployment scenarios where you want to avoid heavy dependencies like Ray and PyTorch.
|
|
25
|
+
|
|
26
|
+
## Features
|
|
27
|
+
|
|
28
|
+
- **ONNX-based inference**: Uses ONNX Runtime for efficient model inference
|
|
29
|
+
- **No Ray or PyTorch dependencies**: Lightweight package suitable for production deployment
|
|
30
|
+
- **Network management**: Supports both local and remote objects (skills, perceptors, controllers)
|
|
31
|
+
- **Compatible API**: Similar interface to `Trainer.package()` for easy migration
|
|
32
|
+
|
|
33
|
+
## Installation
|
|
34
|
+
|
|
35
|
+
```bash
|
|
36
|
+
pip install amesa-inference
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
## Usage
|
|
40
|
+
|
|
41
|
+
### Basic Inference
|
|
42
|
+
|
|
43
|
+
```python
|
|
44
|
+
from amesa_inference import InferenceEngine
|
|
45
|
+
from composabl_core import Agent
|
|
46
|
+
|
|
47
|
+
# Create inference engine (only license needed for license validation)
|
|
48
|
+
engine = InferenceEngine(license="your-license-key")
|
|
49
|
+
|
|
50
|
+
# Load agent
|
|
51
|
+
agent = Agent.load("path/to/agent")
|
|
52
|
+
await engine.load_agent(agent)
|
|
53
|
+
|
|
54
|
+
# Package agent for inference (similar to Trainer.package())
|
|
55
|
+
await engine.package()
|
|
56
|
+
|
|
57
|
+
# Run inference
|
|
58
|
+
observation = {...} # Your observation from the simulator
|
|
59
|
+
action = await engine.execute(observation)
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
### With Remote Objects
|
|
63
|
+
|
|
64
|
+
The inference engine supports remote skills, perceptors, and controllers, just like the Trainer:
|
|
65
|
+
|
|
66
|
+
```python
|
|
67
|
+
from composabl_inference import InferenceEngine
|
|
68
|
+
|
|
69
|
+
# Optional: provide custom config for NetworkMgr (e.g., for remote targets)
|
|
70
|
+
config = {
|
|
71
|
+
"target": {
|
|
72
|
+
"local": {
|
|
73
|
+
"address": "localhost:1337",
|
|
74
|
+
},
|
|
75
|
+
},
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
engine = InferenceEngine(license="your-license-key", config=config)
|
|
79
|
+
await engine.load_agent("path/to/agent")
|
|
80
|
+
await engine.package()
|
|
81
|
+
|
|
82
|
+
# The skill processor will automatically handle remote objects
|
|
83
|
+
action = engine.execute(observation)
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
### Cleanup
|
|
87
|
+
|
|
88
|
+
```python
|
|
89
|
+
# Clean up resources
|
|
90
|
+
await engine.close()
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
## Architecture
|
|
94
|
+
|
|
95
|
+
### Components
|
|
96
|
+
|
|
97
|
+
1. **InferenceEngine**: Main entry point for inference operations
|
|
98
|
+
2. **NetworkMgr**: Manages network connections (non-Ray version)
|
|
99
|
+
3. **ONNXInferenceEngine**: Handles ONNX model loading and inference
|
|
100
|
+
4. **ONNXSkillProcessor**: Processes skills using ONNX models instead of PyTorch
|
|
101
|
+
|
|
102
|
+
### Differences from Trainer
|
|
103
|
+
|
|
104
|
+
- Uses ONNX Runtime instead of PyTorch for model inference
|
|
105
|
+
- NetworkMgr is not a Ray actor (runs in the same process)
|
|
106
|
+
- No Ray initialization required
|
|
107
|
+
- Lighter weight, suitable for production deployment
|
|
108
|
+
|
|
109
|
+
## Requirements
|
|
110
|
+
|
|
111
|
+
- Python >= 3.10
|
|
112
|
+
- composabl-core
|
|
113
|
+
- composabl-api
|
|
114
|
+
- onnxruntime
|
|
115
|
+
- numpy
|
|
116
|
+
|
|
117
|
+
## License
|
|
118
|
+
|
|
119
|
+
Proprietary and confidential - Copyright (C) Amesa, Inc
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
# Amesa Inference
|
|
2
|
+
|
|
3
|
+
A lightweight inference package for running Amesa agents using ONNX models without Ray or PyTorch dependencies.
|
|
4
|
+
|
|
5
|
+
## Overview
|
|
6
|
+
|
|
7
|
+
`amesa_inference` provides a standalone inference engine for running trained Amesa agents. It uses ONNX Runtime for model inference, making it suitable for deployment scenarios where you want to avoid heavy dependencies like Ray and PyTorch.
|
|
8
|
+
|
|
9
|
+
## Features
|
|
10
|
+
|
|
11
|
+
- **ONNX-based inference**: Uses ONNX Runtime for efficient model inference
|
|
12
|
+
- **No Ray or PyTorch dependencies**: Lightweight package suitable for production deployment
|
|
13
|
+
- **Network management**: Supports both local and remote objects (skills, perceptors, controllers)
|
|
14
|
+
- **Compatible API**: Similar interface to `Trainer.package()` for easy migration
|
|
15
|
+
|
|
16
|
+
## Installation
|
|
17
|
+
|
|
18
|
+
```bash
|
|
19
|
+
pip install amesa-inference
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
## Usage
|
|
23
|
+
|
|
24
|
+
### Basic Inference
|
|
25
|
+
|
|
26
|
+
```python
|
|
27
|
+
from composabl_inference import InferenceEngine
|
|
28
|
+
from composabl_core import Agent
|
|
29
|
+
|
|
30
|
+
# Create inference engine (only license needed for license validation)
|
|
31
|
+
engine = InferenceEngine(license="your-license-key")
|
|
32
|
+
|
|
33
|
+
# Load agent
|
|
34
|
+
agent = Agent.load("path/to/agent")
|
|
35
|
+
await engine.load_agent(agent)
|
|
36
|
+
|
|
37
|
+
# Package agent for inference (similar to Trainer.package())
|
|
38
|
+
await engine.package()
|
|
39
|
+
|
|
40
|
+
# Run inference
|
|
41
|
+
observation = {...} # Your observation from the simulator
|
|
42
|
+
action = await engine.execute(observation)
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
### With Remote Objects
|
|
46
|
+
|
|
47
|
+
The inference engine supports remote skills, perceptors, and controllers, just like the Trainer:
|
|
48
|
+
|
|
49
|
+
```python
|
|
50
|
+
from composabl_inference import InferenceEngine
|
|
51
|
+
|
|
52
|
+
# Optional: provide custom config for NetworkMgr (e.g., for remote targets)
|
|
53
|
+
config = {
|
|
54
|
+
"target": {
|
|
55
|
+
"local": {
|
|
56
|
+
"address": "localhost:1337",
|
|
57
|
+
},
|
|
58
|
+
},
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
engine = InferenceEngine(license="your-license-key", config=config)
|
|
62
|
+
await engine.load_agent("path/to/agent")
|
|
63
|
+
await engine.package()
|
|
64
|
+
|
|
65
|
+
# The skill processor will automatically handle remote objects
|
|
66
|
+
action = engine.execute(observation)
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
### Cleanup
|
|
70
|
+
|
|
71
|
+
```python
|
|
72
|
+
# Clean up resources
|
|
73
|
+
await engine.close()
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
## Architecture
|
|
77
|
+
|
|
78
|
+
### Components
|
|
79
|
+
|
|
80
|
+
1. **InferenceEngine**: Main entry point for inference operations
|
|
81
|
+
2. **NetworkMgr**: Manages network connections (non-Ray version)
|
|
82
|
+
3. **ONNXInferenceEngine**: Handles ONNX model loading and inference
|
|
83
|
+
4. **ONNXSkillProcessor**: Processes skills using ONNX models instead of PyTorch
|
|
84
|
+
|
|
85
|
+
### Differences from Trainer
|
|
86
|
+
|
|
87
|
+
- Uses ONNX Runtime instead of PyTorch for model inference
|
|
88
|
+
- NetworkMgr is not a Ray actor (runs in the same process)
|
|
89
|
+
- No Ray initialization required
|
|
90
|
+
- Lighter weight, suitable for production deployment
|
|
91
|
+
|
|
92
|
+
## Requirements
|
|
93
|
+
|
|
94
|
+
- Python >= 3.10
|
|
95
|
+
- composabl-core
|
|
96
|
+
- composabl-api
|
|
97
|
+
- onnxruntime
|
|
98
|
+
- numpy
|
|
99
|
+
|
|
100
|
+
## License
|
|
101
|
+
|
|
102
|
+
Proprietary and confidential - Copyright (C) Amesa, Inc
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright (C) Amesa, Inc - All Rights Reserved
|
|
2
|
+
# Unauthorized copying of this file, via any medium is strictly prohibited
|
|
3
|
+
# Proprietary and confidential
|
|
4
|
+
|
|
5
|
+
from amesa_core.networking.network_mgr import NetworkMgr
|
|
6
|
+
|
|
7
|
+
from amesa_inference.inference_engine import InferenceEngine
|
|
8
|
+
from amesa_inference.onnx_skill_processor import ONNXSkillProcessor
|
|
9
|
+
from amesa_inference.onnx_selector_processor import ONNXSelectorProcessor
|
|
10
|
+
from amesa_inference.onnx_skill_processor_factory import create_onnx_skill_processor
|
|
11
|
+
from amesa_inference.skill_processor_base import BaseSkillProcessor
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"InferenceEngine",
|
|
15
|
+
"NetworkMgr",
|
|
16
|
+
"BaseSkillProcessor",
|
|
17
|
+
"ONNXSkillProcessor",
|
|
18
|
+
"ONNXSelectorProcessor",
|
|
19
|
+
"create_onnx_skill_processor",
|
|
20
|
+
]
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
# Copyright (C) Amesa, Inc - All Rights Reserved
|
|
2
|
+
# Unauthorized copying of this file, via any medium is strictly prohibited
|
|
3
|
+
# Proprietary and confidential
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, Optional, Union
|
|
6
|
+
|
|
7
|
+
from amesa_core.settings import TrainerMode
|
|
8
|
+
from amesa_core.utils import license_util
|
|
9
|
+
import amesa_core.utils.logger as logger_util
|
|
10
|
+
from amesa_core.agent.agent import Agent
|
|
11
|
+
from amesa_core.agent.skill.skill import Skill
|
|
12
|
+
from amesa_core.config.trainer_config import TrainerConfig
|
|
13
|
+
from amesa_core.networking.config.skill_processor_context import (
|
|
14
|
+
SkillProcessorContext,
|
|
15
|
+
)
|
|
16
|
+
from amesa_core.networking.network_mgr import NetworkMgr
|
|
17
|
+
from amesa_inference.inference_network_mgr import InferenceNetworkMgr
|
|
18
|
+
from amesa_core.settings import settings
|
|
19
|
+
|
|
20
|
+
from amesa_inference.skill_processor_base import BaseSkillProcessor
|
|
21
|
+
|
|
22
|
+
logger = logger_util.get_logger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class InferenceEngine:
    """
    Inference engine for running agent inference with and without remote objects.
    Uses ONNX models instead of PyTorch/Ray for inference.

    Supports all skill types including:
    - Basic skills (ONNXSkillProcessor)
    - Selector skills (ONNXSelectorProcessor) - automatically orchestrates child skills
    """

    def __init__(
        self,
        license: Optional[str] = None,
        config: Optional[Dict] = None,
        network_mgr: Optional[NetworkMgr] = None,
    ):
        """
        Initialize the inference engine.

        Args:
            license: License key for Amesa (required for license validation)
            config: Optional configuration dictionary for NetworkMgr (defaults to local target)
            network_mgr: Optional pre-configured NetworkMgr instance. If not provided,
                creates an InferenceNetworkMgr.

        Raises:
            Whatever ``license_util.validate`` raises when the license is invalid.
        """
        # Set license in settings for license validation
        if license is not None:
            settings.AMESA_LICENSE = license

        # validate() returns (type, module, data); only the type is used here,
        # for logging. Discard the rest explicitly instead of binding unused names.
        license_type, _, _ = license_util.validate(
            settings.AMESA_LICENSE,
            settings.AMESA_ENV == TrainerMode.STAGING,
            settings.URL_AMESA_LICENSE_SERVER,
        )
        logger.info(f"license_validated {license_type}")

        # Create minimal config for NetworkMgr (defaults to local target)
        if config is None:
            config = {
                "target": {
                    "local": {
                        "address": "localhost:1337",
                    }
                }
            }

        self.config = TrainerConfig(**config)

        # Use provided network_mgr or create default InferenceNetworkMgr
        if network_mgr is not None:
            self.network_mgr = network_mgr
        else:
            self.network_mgr = InferenceNetworkMgr(self.config)

        # Populated by load_agent() and package() respectively.
        self.agent: Optional[Agent] = None
        self.skill_processor: Optional[BaseSkillProcessor] = None

    async def load_agent(self, agent: Union[Agent, str]):
        """
        Load an agent for inference.

        Args:
            agent: Agent object or path to agent JSON file
        """
        if isinstance(agent, str):
            self.agent = Agent.load(agent)
        else:
            self.agent = agent

        # Agents loaded from disk (or passed in uninitialized) must be
        # initialized before they can be packaged.
        if not self.agent.is_initialized:
            await self.agent.init()

        logger.info(f"Loaded agent: {self.agent.id}")

    async def package(
        self, agent: Optional[Agent] = None, skill: Optional[Union[str, Skill]] = None
    ) -> None:
        """
        Package an agent for inference, similar to Trainer.package().
        Creates an ONNX-based skill processor that can be used for inference.

        Args:
            agent: Agent object (if None, uses the loaded agent)
            skill: Specific skill to package (if None, uses top skill)

        Raises:
            ValueError: If no agent was provided and none has been loaded.
        """
        if agent is None:
            agent = self.agent

        if agent is None:
            raise ValueError("No agent provided and no agent loaded")

        if not agent.is_initialized:
            await agent.init()

        # Determine which skill to use: default to the agent's top skill,
        # otherwise resolve a skill name or use the Skill object directly.
        if skill is None:
            node = agent.get_top_skill()
        else:
            if isinstance(skill, str):
                node = agent.get_node_by_name(skill)
            else:
                node = skill

        # Create skill processor context
        context = SkillProcessorContext(
            agent=agent,
            skill=node,
            network_mgr=self.network_mgr,
            is_training=False,
            is_validating=False,
            for_skill_group=False,
        )

        # Use factory to create appropriate skill processor based on skill type.
        # This handles selectors, basic skills, and other skill types automatically.
        # Imported locally to keep module import time low and avoid import cycles
        # (the factory module imports processors that depend on this package).
        from amesa_inference.onnx_skill_processor_factory import (
            create_onnx_skill_processor,
        )

        self.skill_processor = await create_onnx_skill_processor(context)

        logger.info(
            f"Packaged agent for inference with skill: {node.get_name()} "
            f"(processor type: {type(self.skill_processor).__name__})"
        )

    async def execute(
        self,
        obs: Any,
        sim_action_mask: Optional[Any] = None,
        explore: bool = False,
        previous_action: Optional[Any] = None,
    ) -> Any:
        """
        Execute inference on an observation.

        Args:
            obs: Observation from the simulator
            sim_action_mask: Optional action mask from simulator
            explore: Whether to explore (not used in inference, kept for compatibility)
            previous_action: Previous action (for sequential skills)

        Returns:
            Action to take

        Raises:
            RuntimeError: If package() has not been called yet.
        """
        if self.skill_processor is None:
            raise RuntimeError(
                "No skill processor available. Call package() first."
            )

        # NOTE(review): this delegates to the processor's private _execute();
        # presumably there is no public execute API on BaseSkillProcessor — confirm.
        return await self.skill_processor._execute(
            obs,
            sim_action_mask=sim_action_mask,
            explore=explore,
            previous_action=previous_action,
        )

    async def close(self):
        """Close the inference engine and cleanup resources."""
        if self.network_mgr is not None:
            self.network_mgr.stop_all()
            self.network_mgr.stop_watchdogs()

        logger.info("Inference engine closed")
|
|
189
|
+
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# Copyright (C) Amesa, Inc - All Rights Reserved
|
|
2
|
+
# Unauthorized copying of this file, via any medium is strictly prohibited
|
|
3
|
+
# Proprietary and confidential
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
InferenceNetworkMgr - NetworkMgr subclass for async-first usage.
|
|
7
|
+
|
|
8
|
+
This module provides a NetworkMgr designed for use with InferenceEngine,
|
|
9
|
+
which operates in async contexts. Unlike the base NetworkMgr which uses
|
|
10
|
+
async_to_sync() (which may cause event loop conflicts in async contexts), this
|
|
11
|
+
class returns coroutines that callers await directly.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import inspect
|
|
15
|
+
from typing import Callable
|
|
16
|
+
|
|
17
|
+
from amesa_core.networking.network_mgr import NetworkMgr
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class InferenceNetworkMgr(NetworkMgr):
    """
    NetworkMgr subclass designed for async contexts.

    This class overrides _call_method() to return coroutines directly,
    which callers await. This avoids the event loop conflict entirely.

    Usage:
        mgr = InferenceNetworkMgr(config)
        result = await mgr.call_remote_perceptor_mgr("method", ...)
    """

    def _call_method(self, attr: Callable, *args, **kwargs):
        """
        Override base _call_method to return coroutines for async methods.

        Calling a coroutine function without awaiting it simply returns the
        coroutine object, so async and sync callables can be invoked the same
        way: sync methods return their result immediately, async methods
        return a coroutine the caller must await. The previous implementation
        branched on ``inspect.iscoroutinefunction(attr)``, but both branches
        executed the identical call, so the branch has been collapsed.
        """
        return attr(*args, **kwargs)
|
|
44
|
+
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
# Copyright (C) Amesa, Inc - All Rights Reserved
|
|
2
|
+
# Unauthorized copying of this file, via any medium is strictly prohibited
|
|
3
|
+
# Proprietary and confidential
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
ONNX helper utilities for amesa_inference.
|
|
7
|
+
|
|
8
|
+
This module contains utilities for working with ONNX models during inference,
|
|
9
|
+
including ensuring models have the correct input structure.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import os
|
|
13
|
+
from typing import Optional
|
|
14
|
+
|
|
15
|
+
import amesa_core.utils.logger as logger_util
|
|
16
|
+
import onnx
|
|
17
|
+
|
|
18
|
+
logger = logger_util.get_logger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def ensure_onnx_model_has_action_mask(
    onnx_path: str, action_space=None, skill=None
) -> bool:
    """
    Ensure an ONNX model has an action_mask input. If it doesn't, attempt to add it.

    This function is used after downloading ONNX models to ensure they have the
    action_mask input required for inference, especially for selectors.

    The action_mask shape is resolved in priority order:
    1. ``action_space`` (its mask space, ``n``, or ``shape``)
    2. ``skill``'s action space
    3. the model's ``action_dist_inputs`` output shape
    4. a default of ``(1,)``

    Args:
        onnx_path: Path to the ONNX model file
        action_space: Optional action space to determine action_mask shape
        skill: Optional skill object to get action space from

    Returns:
        True if the model was modified, False if it already had action_mask

    Raises:
        Exception: Re-raises any failure while loading, modifying, validating,
            or saving the model (after logging it).
    """
    try:
        onnx_model = onnx.load(onnx_path)

        # Check if action_mask input already exists
        input_names = [graph_input.name for graph_input in onnx_model.graph.input]
        has_action_mask = any("action_mask" in name.lower() for name in input_names)

        if has_action_mask:
            logger.debug(f"ONNX model at {onnx_path} already has action_mask input")
            return False

        # Determine action_mask shape
        action_mask_shape = None
        if action_space is not None:
            try:
                action_mask_space = action_space.get_action_mask_space()
                if hasattr(action_mask_space, 'shape'):
                    action_mask_shape = action_mask_space.shape
                elif hasattr(action_space, 'n'):
                    # Discrete-style space: mask is one entry per action.
                    action_mask_shape = (action_space.n,)
                elif hasattr(action_space, 'shape'):
                    action_mask_shape = action_space.shape
            except Exception as e:
                logger.warning(f"Could not get action_mask shape from action_space: {e}")

        # If we have a skill, try to get action space from it
        if action_mask_shape is None and skill is not None:
            try:
                skill_action_space = skill.get_action_space()
                if skill_action_space is not None:
                    action_mask_space = skill_action_space.get_action_mask_space()
                    if hasattr(action_mask_space, 'shape'):
                        action_mask_shape = action_mask_space.shape
                    elif hasattr(skill_action_space, 'n'):
                        action_mask_shape = (skill_action_space.n,)
            except Exception as e:
                logger.warning(f"Could not get action_mask shape from skill: {e}")

        # Try to infer from action_dist_inputs output shape
        if action_mask_shape is None:
            try:
                output_names = [output.name for output in onnx_model.graph.output]
                if "action_dist_inputs" in output_names:
                    # Find the output and get its shape
                    for output in onnx_model.graph.output:
                        if output.name == "action_dist_inputs":
                            # Dynamic dims report dim_value 0 (or unset); treat as 1.
                            shape = [
                                dim.dim_value if dim.dim_value > 0 else 1
                                for dim in output.type.tensor_type.shape.dim
                            ]
                            if len(shape) >= 2:
                                # Shape is [batch, action_size], we want action_size
                                action_mask_shape = (shape[-1],)
                            elif len(shape) == 1:
                                action_mask_shape = (shape[0],)
                            break
            except Exception as e:
                logger.warning(f"Could not infer action_mask shape from model outputs: {e}")

        # Fallback to shape (1,)
        if action_mask_shape is None:
            logger.warning("Could not determine action_mask shape, using default (1,)")
            action_mask_shape = (1,)

        # Get observation input to determine batch dimension.
        # (Renamed loop variable to avoid shadowing the builtin ``input``.)
        obs_input = None
        for graph_input in onnx_model.graph.input:
            if "observation" in graph_input.name.lower() or graph_input.name == "observation":
                obs_input = graph_input
                break

        # No input matched by name: fall back to the first graph input.
        if obs_input is None and len(onnx_model.graph.input) > 0:
            obs_input = onnx_model.graph.input[0]

        batch_dim = None
        if obs_input is not None:
            try:
                # Get batch dimension from observation input (usually first dimension)
                obs_shape = obs_input.type.tensor_type.shape.dim
                if len(obs_shape) > 0:
                    batch_dim = obs_shape[0]
            except Exception:
                pass

        # Create action_mask input.
        # Use dynamic batch dimension if available, otherwise use 1.
        if batch_dim is not None:
            action_mask_dims = [batch_dim]
        else:
            action_mask_dims = [onnx.TensorShapeProto.Dimension(dim_value=1)]

        for dim_size in action_mask_shape:
            action_mask_dims.append(onnx.TensorShapeProto.Dimension(dim_value=dim_size))

        action_mask_type = onnx.TypeProto()
        action_mask_type.tensor_type.elem_type = onnx.TensorProto.DOUBLE
        action_mask_type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto(dim=action_mask_dims))

        action_mask_value_info = onnx.ValueInfoProto()
        action_mask_value_info.name = "action_mask"
        action_mask_value_info.type.CopyFrom(action_mask_type)

        # Add action_mask input to the graph
        onnx_model.graph.input.append(action_mask_value_info)

        # Validate and save (overwrites the model in place)
        onnx.checker.check_model(onnx_model)
        onnx.save(onnx_model, onnx_path)

        logger.info(f"Added action_mask input with shape {action_mask_shape} to ONNX model at {onnx_path}")
        return True

    except Exception as e:
        logger.error(f"Failed to ensure action_mask input in ONNX model at {onnx_path}: {e}")
        raise
|