amesa-inference-dev 0.20.5.dev13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. amesa_inference_dev-0.20.5.dev13/PKG-INFO +119 -0
  2. amesa_inference_dev-0.20.5.dev13/README.md +102 -0
  3. amesa_inference_dev-0.20.5.dev13/amesa_inference/__init__.py +20 -0
  4. amesa_inference_dev-0.20.5.dev13/amesa_inference/inference_engine.py +189 -0
  5. amesa_inference_dev-0.20.5.dev13/amesa_inference/inference_network_mgr.py +44 -0
  6. amesa_inference_dev-0.20.5.dev13/amesa_inference/onnx_helper.py +152 -0
  7. amesa_inference_dev-0.20.5.dev13/amesa_inference/onnx_inference.py +198 -0
  8. amesa_inference_dev-0.20.5.dev13/amesa_inference/onnx_selector_processor.py +388 -0
  9. amesa_inference_dev-0.20.5.dev13/amesa_inference/onnx_skill_processor.py +209 -0
  10. amesa_inference_dev-0.20.5.dev13/amesa_inference/onnx_skill_processor_factory.py +146 -0
  11. amesa_inference_dev-0.20.5.dev13/amesa_inference/skill_processor_base.py +551 -0
  12. amesa_inference_dev-0.20.5.dev13/amesa_inference_dev.egg-info/PKG-INFO +119 -0
  13. amesa_inference_dev-0.20.5.dev13/amesa_inference_dev.egg-info/SOURCES.txt +29 -0
  14. amesa_inference_dev-0.20.5.dev13/amesa_inference_dev.egg-info/dependency_links.txt +1 -0
  15. amesa_inference_dev-0.20.5.dev13/amesa_inference_dev.egg-info/requires.txt +6 -0
  16. amesa_inference_dev-0.20.5.dev13/amesa_inference_dev.egg-info/top_level.txt +1 -0
  17. amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/__init__.c +5063 -0
  18. amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/inference_engine.c +14582 -0
  19. amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/inference_network_mgr.c +8440 -0
  20. amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/onnx_helper.c +13729 -0
  21. amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/onnx_inference.c +12742 -0
  22. amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/onnx_selector_processor.c +19732 -0
  23. amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/onnx_skill_processor.c +15402 -0
  24. amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/onnx_skill_processor_factory.c +12074 -0
  25. amesa_inference_dev-0.20.5.dev13/dist/build_cython/amesa_inference/skill_processor_base.c +26156 -0
  26. amesa_inference_dev-0.20.5.dev13/pyproject.toml +118 -0
  27. amesa_inference_dev-0.20.5.dev13/setup.cfg +4 -0
  28. amesa_inference_dev-0.20.5.dev13/setup.py +164 -0
  29. amesa_inference_dev-0.20.5.dev13/tests/test_inference_engine.py +531 -0
  30. amesa_inference_dev-0.20.5.dev13/tests/test_onnx_inference.py +95 -0
  31. amesa_inference_dev-0.20.5.dev13/tests/test_selector_agent_inference.py +150 -0
@@ -0,0 +1,119 @@
1
+ Metadata-Version: 2.4
2
+ Name: amesa-inference-dev
3
+ Version: 0.20.5.dev13
4
+ Summary: Agent inference package using ONNX models without Ray or PyTorch dependencies
5
+ Author-email: Hunter Park <hunter@amesa.com>
6
+ Requires-Python: >=3.10, <3.13
7
+ Description-Content-Type: text/markdown
8
+ Requires-Dist: amesa-api-dev>=0.20.5.dev13
9
+ Requires-Dist: amesa-core-dev>=0.20.5.dev13
10
+ Requires-Dist: numpy<2.0.0
11
+ Requires-Dist: onnxruntime
12
+ Requires-Dist: onnx
13
+ Requires-Dist: ruff
14
+ Dynamic: description
15
+ Dynamic: description-content-type
16
+ Dynamic: requires-python
17
+
18
+ # Amesa Inference
19
+
20
+ A lightweight inference package for running Amesa agents using ONNX models without Ray or PyTorch dependencies.
21
+
22
+ ## Overview
23
+
24
+ `amesa_inference` provides a standalone inference engine for running trained Amesa agents. It uses ONNX Runtime for model inference, making it suitable for deployment scenarios where you want to avoid heavy dependencies like Ray and PyTorch.
25
+
26
+ ## Features
27
+
28
+ - **ONNX-based inference**: Uses ONNX Runtime for efficient model inference
29
+ - **No Ray or PyTorch dependencies**: Lightweight package suitable for production deployment
30
+ - **Network management**: Supports both local and remote objects (skills, perceptors, controllers)
31
+ - **Compatible API**: Similar interface to `Trainer.package()` for easy migration
32
+
33
+ ## Installation
34
+
35
+ ```bash
36
+ pip install amesa-inference
37
+ ```
38
+
39
+ ## Usage
40
+
41
+ ### Basic Inference
42
+
43
+ ```python
44
+ from amesa_inference import InferenceEngine
45
+ from amesa_core import Agent
46
+
47
+ # Create inference engine (only license needed for license validation)
48
+ engine = InferenceEngine(license="your-license-key")
49
+
50
+ # Load agent
51
+ agent = Agent.load("path/to/agent")
52
+ await engine.load_agent(agent)
53
+
54
+ # Package agent for inference (similar to Trainer.package())
55
+ await engine.package()
56
+
57
+ # Run inference
58
+ observation = {...} # Your observation from the simulator
59
+ action = engine.execute(observation)
60
+ ```
61
+
62
+ ### With Remote Objects
63
+
64
+ The inference engine supports remote skills, perceptors, and controllers, just like the Trainer:
65
+
66
+ ```python
67
+ from amesa_inference import InferenceEngine
68
+
69
+ # Optional: provide custom config for NetworkMgr (e.g., for remote targets)
70
+ config = {
71
+ "target": {
72
+ "local": {
73
+ "address": "localhost:1337",
74
+ },
75
+ },
76
+ }
77
+
78
+ engine = InferenceEngine(license="your-license-key", config=config)
79
+ await engine.load_agent("path/to/agent")
80
+ await engine.package()
81
+
82
+ # The skill processor will automatically handle remote objects
83
+ action = engine.execute(observation)
84
+ ```
85
+
86
+ ### Cleanup
87
+
88
+ ```python
89
+ # Clean up resources
90
+ await engine.close()
91
+ ```
92
+
93
+ ## Architecture
94
+
95
+ ### Components
96
+
97
+ 1. **InferenceEngine**: Main entry point for inference operations
98
+ 2. **NetworkMgr**: Manages network connections (non-Ray version)
99
+ 3. **ONNXInferenceEngine**: Handles ONNX model loading and inference
100
+ 4. **ONNXSkillProcessor**: Processes skills using ONNX models instead of PyTorch
101
+
102
+ ### Differences from Trainer
103
+
104
+ - Uses ONNX Runtime instead of PyTorch for model inference
105
+ - NetworkMgr is not a Ray actor (runs in the same process)
106
+ - No Ray initialization required
107
+ - Lighter weight, suitable for production deployment
108
+
109
+ ## Requirements
110
+
111
+ - Python >= 3.10
112
+ - amesa-core-dev
113
+ - amesa-api-dev
114
+ - onnxruntime
115
+ - numpy
116
+
117
+ ## License
118
+
119
+ Proprietary and confidential - Copyright (C) Amesa, Inc
@@ -0,0 +1,102 @@
1
+ # Amesa Inference
2
+
3
+ A lightweight inference package for running Amesa agents using ONNX models without Ray or PyTorch dependencies.
4
+
5
+ ## Overview
6
+
7
+ `amesa_inference` provides a standalone inference engine for running trained Amesa agents. It uses ONNX Runtime for model inference, making it suitable for deployment scenarios where you want to avoid heavy dependencies like Ray and PyTorch.
8
+
9
+ ## Features
10
+
11
+ - **ONNX-based inference**: Uses ONNX Runtime for efficient model inference
12
+ - **No Ray or PyTorch dependencies**: Lightweight package suitable for production deployment
13
+ - **Network management**: Supports both local and remote objects (skills, perceptors, controllers)
14
+ - **Compatible API**: Similar interface to `Trainer.package()` for easy migration
15
+
16
+ ## Installation
17
+
18
+ ```bash
19
+ pip install amesa-inference
20
+ ```
21
+
22
+ ## Usage
23
+
24
+ ### Basic Inference
25
+
26
+ ```python
27
+ from amesa_inference import InferenceEngine
28
+ from amesa_core import Agent
29
+
30
+ # Create inference engine (only license needed for license validation)
31
+ engine = InferenceEngine(license="your-license-key")
32
+
33
+ # Load agent
34
+ agent = Agent.load("path/to/agent")
35
+ await engine.load_agent(agent)
36
+
37
+ # Package agent for inference (similar to Trainer.package())
38
+ await engine.package()
39
+
40
+ # Run inference
41
+ observation = {...} # Your observation from the simulator
42
+ action = engine.execute(observation)
43
+ ```
44
+
45
+ ### With Remote Objects
46
+
47
+ The inference engine supports remote skills, perceptors, and controllers, just like the Trainer:
48
+
49
+ ```python
50
+ from amesa_inference import InferenceEngine
51
+
52
+ # Optional: provide custom config for NetworkMgr (e.g., for remote targets)
53
+ config = {
54
+ "target": {
55
+ "local": {
56
+ "address": "localhost:1337",
57
+ },
58
+ },
59
+ }
60
+
61
+ engine = InferenceEngine(license="your-license-key", config=config)
62
+ await engine.load_agent("path/to/agent")
63
+ await engine.package()
64
+
65
+ # The skill processor will automatically handle remote objects
66
+ action = engine.execute(observation)
67
+ ```
68
+
69
+ ### Cleanup
70
+
71
+ ```python
72
+ # Clean up resources
73
+ await engine.close()
74
+ ```
75
+
76
+ ## Architecture
77
+
78
+ ### Components
79
+
80
+ 1. **InferenceEngine**: Main entry point for inference operations
81
+ 2. **NetworkMgr**: Manages network connections (non-Ray version)
82
+ 3. **ONNXInferenceEngine**: Handles ONNX model loading and inference
83
+ 4. **ONNXSkillProcessor**: Processes skills using ONNX models instead of PyTorch
84
+
85
+ ### Differences from Trainer
86
+
87
+ - Uses ONNX Runtime instead of PyTorch for model inference
88
+ - NetworkMgr is not a Ray actor (runs in the same process)
89
+ - No Ray initialization required
90
+ - Lighter weight, suitable for production deployment
91
+
92
+ ## Requirements
93
+
94
+ - Python >= 3.10
95
+ - amesa-core-dev
96
+ - amesa-api-dev
97
+ - onnxruntime
98
+ - numpy
99
+
100
+ ## License
101
+
102
+ Proprietary and confidential - Copyright (C) Amesa, Inc
@@ -0,0 +1,20 @@
1
# Copyright (C) Amesa, Inc - All Rights Reserved
# Unauthorized copying of this file, via any medium is strictly prohibited
# Proprietary and confidential

"""Public API surface of the ``amesa_inference`` package.

Re-exports the inference engine, the ONNX-based skill processors, their
factory, and the core ``NetworkMgr`` so callers can import everything from
one place.
"""

from amesa_core.networking.network_mgr import NetworkMgr

from amesa_inference.inference_engine import InferenceEngine
from amesa_inference.onnx_skill_processor import ONNXSkillProcessor
from amesa_inference.onnx_selector_processor import ONNXSelectorProcessor
from amesa_inference.onnx_skill_processor_factory import create_onnx_skill_processor
from amesa_inference.skill_processor_base import BaseSkillProcessor

# Names considered public API of this package.
__all__ = [
    "InferenceEngine",
    "NetworkMgr",
    "BaseSkillProcessor",
    "ONNXSkillProcessor",
    "ONNXSelectorProcessor",
    "create_onnx_skill_processor",
]
@@ -0,0 +1,189 @@
1
+ # Copyright (C) Amesa, Inc - All Rights Reserved
2
+ # Unauthorized copying of this file, via any medium is strictly prohibited
3
+ # Proprietary and confidential
4
+
5
+ from typing import Any, Dict, Optional, Union
6
+
7
+ from amesa_core.settings import TrainerMode
8
+ from amesa_core.utils import license_util
9
+ import amesa_core.utils.logger as logger_util
10
+ from amesa_core.agent.agent import Agent
11
+ from amesa_core.agent.skill.skill import Skill
12
+ from amesa_core.config.trainer_config import TrainerConfig
13
+ from amesa_core.networking.config.skill_processor_context import (
14
+ SkillProcessorContext,
15
+ )
16
+ from amesa_core.networking.network_mgr import NetworkMgr
17
+ from amesa_inference.inference_network_mgr import InferenceNetworkMgr
18
+ from amesa_core.settings import settings
19
+
20
+ from amesa_inference.skill_processor_base import BaseSkillProcessor
21
+
22
+ logger = logger_util.get_logger(__name__)
23
+
24
+
25
class InferenceEngine:
    """
    Inference engine for running agent inference with and without remote objects.
    Uses ONNX models instead of PyTorch/Ray for inference.

    Supports all skill types including:
    - Basic skills (ONNXSkillProcessor)
    - Selector skills (ONNXSelectorProcessor) - automatically orchestrates child skills
    """

    def __init__(
        self,
        license: Optional[str] = None,
        config: Optional[Dict] = None,
        network_mgr: Optional[NetworkMgr] = None,
    ):
        """
        Initialize the inference engine.

        Args:
            license: License key for Amesa (required for license validation)
            config: Optional configuration dictionary for NetworkMgr (defaults to local target)
            network_mgr: Optional pre-configured NetworkMgr instance. If not provided,
                creates an InferenceNetworkMgr.
        """
        # Set license in settings so license_util.validate() picks it up.
        if license is not None:
            settings.AMESA_LICENSE = license

        # Only the license type is reported; module/data results are unused here.
        license_type, _license_module, _license_data = license_util.validate(
            settings.AMESA_LICENSE,
            settings.AMESA_ENV == TrainerMode.STAGING,
            settings.URL_AMESA_LICENSE_SERVER,
        )
        logger.info(f"license_validated {license_type}")

        # Create minimal config for NetworkMgr (defaults to local target).
        if config is None:
            config = {
                "target": {
                    "local": {
                        "address": "localhost:1337",
                    }
                }
            }

        self.config = TrainerConfig(**config)

        # Use provided network_mgr, else default to the async-friendly
        # InferenceNetworkMgr (returns coroutines rather than using async_to_sync).
        self.network_mgr = (
            network_mgr if network_mgr is not None else InferenceNetworkMgr(self.config)
        )

        # Populated by load_agent() and package() respectively.
        self.agent: Optional[Agent] = None
        self.skill_processor: Optional[BaseSkillProcessor] = None

    async def load_agent(self, agent: Union[Agent, str]):
        """
        Load an agent for inference.

        Args:
            agent: Agent object or path to agent JSON file
        """
        self.agent = Agent.load(agent) if isinstance(agent, str) else agent

        if not self.agent.is_initialized:
            await self.agent.init()

        logger.info(f"Loaded agent: {self.agent.id}")

    async def package(
        self, agent: Optional[Agent] = None, skill: Optional[Union[str, Skill]] = None
    ) -> None:
        """
        Package an agent for inference, similar to Trainer.package().
        Creates an ONNX-based skill processor that can be used for inference.

        Args:
            agent: Agent object (if None, uses the loaded agent)
            skill: Specific skill to package (if None, uses top skill)

        Raises:
            ValueError: If no agent was provided and none has been loaded.
        """
        if agent is None:
            agent = self.agent

        if agent is None:
            raise ValueError("No agent provided and no agent loaded")

        if not agent.is_initialized:
            await agent.init()

        # Resolve the skill node to package: default to the top skill,
        # look string names up on the agent, accept Skill objects as-is.
        if skill is None:
            node = agent.get_top_skill()
        elif isinstance(skill, str):
            node = agent.get_node_by_name(skill)
        else:
            node = skill

        # Context handed to the processor factory; inference only, no training.
        context = SkillProcessorContext(
            agent=agent,
            skill=node,
            network_mgr=self.network_mgr,
            is_training=False,
            is_validating=False,
            for_skill_group=False,
        )

        # Imported lazily — presumably to avoid a circular import at module
        # load time (TODO confirm). The factory selects the appropriate
        # processor (selector, basic skill, ...) based on the skill type.
        from amesa_inference.onnx_skill_processor_factory import (
            create_onnx_skill_processor,
        )

        self.skill_processor = await create_onnx_skill_processor(context)

        logger.info(
            f"Packaged agent for inference with skill: {node.get_name()} "
            f"(processor type: {type(self.skill_processor).__name__})"
        )

    async def execute(
        self,
        obs: Any,
        sim_action_mask: Optional[Any] = None,
        explore: bool = False,
        previous_action: Optional[Any] = None,
    ) -> Any:
        """
        Execute inference on an observation.

        Args:
            obs: Observation from the simulator
            sim_action_mask: Optional action mask from simulator
            explore: Whether to explore (not used in inference, kept for compatibility)
            previous_action: Previous action (for sequential skills)

        Returns:
            Action to take

        Raises:
            RuntimeError: If package() has not been called yet.
        """
        if self.skill_processor is None:
            raise RuntimeError(
                "No skill processor available. Call package() first."
            )

        return await self.skill_processor._execute(
            obs,
            sim_action_mask=sim_action_mask,
            explore=explore,
            previous_action=previous_action,
        )

    async def close(self):
        """Close the inference engine and cleanup resources."""
        if self.network_mgr is not None:
            # Stop managed processes first, then their watchdogs.
            self.network_mgr.stop_all()
            self.network_mgr.stop_watchdogs()

        logger.info("Inference engine closed")
@@ -0,0 +1,44 @@
1
+ # Copyright (C) Amesa, Inc - All Rights Reserved
2
+ # Unauthorized copying of this file, via any medium is strictly prohibited
3
+ # Proprietary and confidential
4
+
5
+ """
6
+ InferenceNetworkMgr - NetworkMgr subclass for async-first usage.
7
+
8
+ This module provides a NetworkMgr designed for use with InferenceEngine,
9
+ which operates in async contexts. Unlike the base NetworkMgr which uses
10
+ async_to_sync() (which may cause event loop conflicts in async contexts), this
11
+ class returns coroutines that callers await directly.
12
+ """
13
+
14
+ import inspect
15
+ from typing import Callable
16
+
17
+ from amesa_core.networking.network_mgr import NetworkMgr
18
+
19
+
20
class InferenceNetworkMgr(NetworkMgr):
    """
    NetworkMgr subclass designed for async contexts.

    This class overrides _call_method() to return coroutines directly,
    which callers await. This avoids the event loop conflict entirely.

    Usage:
        mgr = InferenceNetworkMgr(config)
        result = await mgr.call_remote_perceptor_mgr("method", ...)
    """

    def _call_method(self, attr: Callable, *args, **kwargs):
        """
        Override base _call_method to return coroutines for async methods.

        Calling a coroutine function yields a coroutine object without
        running it, so sync and async attributes share one call path:
        callers await the result of async methods and use sync results
        directly. (A previous `inspect.iscoroutinefunction` branch here
        executed identical code in both arms, so no branching is needed.)
        """
        return attr(*args, **kwargs)
@@ -0,0 +1,152 @@
1
+ # Copyright (C) Amesa, Inc - All Rights Reserved
2
+ # Unauthorized copying of this file, via any medium is strictly prohibited
3
+ # Proprietary and confidential
4
+
5
+ """
6
+ ONNX helper utilities for amesa_inference.
7
+
8
+ This module contains utilities for working with ONNX models during inference,
9
+ including ensuring models have the correct input structure.
10
+ """
11
+
12
+ import os
13
+ from typing import Optional
14
+
15
+ import amesa_core.utils.logger as logger_util
16
+ import onnx
17
+
18
+ logger = logger_util.get_logger(__name__)
19
+
20
+
21
def ensure_onnx_model_has_action_mask(
    onnx_path: str, action_space=None, skill=None
) -> bool:
    """
    Ensure an ONNX model has an action_mask input. If it doesn't, attempt to add it.

    This function is used after downloading ONNX models to ensure they have the
    action_mask input required for inference, especially for selectors.

    The mask shape is resolved in priority order:
    1. from `action_space` (via get_action_mask_space(), `.n`, or `.shape`),
    2. from `skill.get_action_space()` the same way,
    3. inferred from the model's "action_dist_inputs" output shape,
    4. fallback of (1,).
    The modified model is validated with onnx.checker and saved back in place.

    Args:
        onnx_path: Path to the ONNX model file
        action_space: Optional action space to determine action_mask shape
        skill: Optional skill object to get action space from

    Returns:
        True if the model was modified, False if it already had action_mask

    Raises:
        Exception: Re-raises any failure while loading, modifying, validating,
            or saving the model (after logging it).
    """
    try:
        onnx_model = onnx.load(onnx_path)

        # Check if action_mask input already exists (case-insensitive substring
        # match, so e.g. "Action_Mask" or "obs_action_mask" also counts).
        input_names = [input.name for input in onnx_model.graph.input]
        has_action_mask = any("action_mask" in name.lower() for name in input_names)

        if has_action_mask:
            logger.debug(f"ONNX model at {onnx_path} already has action_mask input")
            return False

        # Determine action_mask shape
        action_mask_shape = None
        if action_space is not None:
            try:
                action_mask_space = action_space.get_action_mask_space()
                if hasattr(action_mask_space, 'shape'):
                    action_mask_shape = action_mask_space.shape
                elif hasattr(action_space, 'n'):
                    # Discrete-style space: mask has one slot per action.
                    action_mask_shape = (action_space.n,)
                elif hasattr(action_space, 'shape'):
                    action_mask_shape = action_space.shape
            except Exception as e:
                logger.warning(f"Could not get action_mask shape from action_space: {e}")

        # If we have a skill, try to get action space from it
        if action_mask_shape is None and skill is not None:
            try:
                skill_action_space = skill.get_action_space()
                if skill_action_space is not None:
                    action_mask_space = skill_action_space.get_action_mask_space()
                    if hasattr(action_mask_space, 'shape'):
                        action_mask_shape = action_mask_space.shape
                    elif hasattr(skill_action_space, 'n'):
                        action_mask_shape = (skill_action_space.n,)
            except Exception as e:
                logger.warning(f"Could not get action_mask shape from skill: {e}")

        # Try to infer from action_dist_inputs output shape
        if action_mask_shape is None:
            try:
                output_names = [output.name for output in onnx_model.graph.output]
                if "action_dist_inputs" in output_names:
                    # Find the output and get its shape
                    for output in onnx_model.graph.output:
                        if output.name == "action_dist_inputs":
                            # dim_value <= 0 means a dynamic/symbolic dim; treat as 1.
                            shape = [
                                dim.dim_value if dim.dim_value > 0 else 1
                                for dim in output.type.tensor_type.shape.dim
                            ]
                            if len(shape) >= 2:
                                # Shape is [batch, action_size], we want action_size
                                action_mask_shape = (shape[-1],)
                            elif len(shape) == 1:
                                action_mask_shape = (shape[0],)
                            break
            except Exception as e:
                logger.warning(f"Could not infer action_mask shape from model outputs: {e}")

        # Fallback to shape (1,)
        if action_mask_shape is None:
            logger.warning("Could not determine action_mask shape, using default (1,)")
            action_mask_shape = (1,)

        # Get observation input to determine batch dimension
        obs_input = None
        for input in onnx_model.graph.input:
            if "observation" in input.name.lower() or input.name == "observation":
                obs_input = input
                break

        # No input named like "observation": fall back to the first graph input.
        if obs_input is None and len(onnx_model.graph.input) > 0:
            obs_input = onnx_model.graph.input[0]

        batch_dim = None
        if obs_input is not None:
            try:
                # Get batch dimension from observation input (usually first dimension)
                obs_shape = obs_input.type.tensor_type.shape.dim
                if len(obs_shape) > 0:
                    batch_dim = obs_shape[0]
            except Exception:
                pass

        # Create action_mask input
        # Use dynamic batch dimension if available, otherwise use 1
        if batch_dim is not None:
            # NOTE(review): batch_dim is the observation's Dimension proto; the
            # TensorShapeProto(dim=...) constructor below copies it into the
            # new shape, so the original graph input is not aliased.
            action_mask_dims = [batch_dim]
        else:
            action_mask_dims = [onnx.TensorShapeProto.Dimension(dim_value=1)]

        for dim_size in action_mask_shape:
            action_mask_dims.append(onnx.TensorShapeProto.Dimension(dim_value=dim_size))

        # Build the type proto for the new input (float64 element type).
        action_mask_type = onnx.TypeProto()
        action_mask_type.tensor_type.elem_type = onnx.TensorProto.DOUBLE
        action_mask_type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto(dim=action_mask_dims))

        action_mask_value_info = onnx.ValueInfoProto()
        action_mask_value_info.name = "action_mask"
        action_mask_value_info.type.CopyFrom(action_mask_type)

        # Add action_mask input to the graph
        onnx_model.graph.input.append(action_mask_value_info)

        # Validate and save (overwrites the file at onnx_path in place)
        onnx.checker.check_model(onnx_model)
        onnx.save(onnx_model, onnx_path)

        logger.info(f"Added action_mask input with shape {action_mask_shape} to ONNX model at {onnx_path}")
        return True

    except Exception as e:
        logger.error(f"Failed to ensure action_mask input in ONNX model at {onnx_path}: {e}")
        raise