gliner2-onnx 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gliner2_onnx-0.1.0/.gitignore +10 -0
- gliner2_onnx-0.1.0/LICENSE +21 -0
- gliner2_onnx-0.1.0/PKG-INFO +163 -0
- gliner2_onnx-0.1.0/README.md +125 -0
- gliner2_onnx-0.1.0/gliner2_onnx/__init__.py +23 -0
- gliner2_onnx-0.1.0/gliner2_onnx/_version.py +34 -0
- gliner2_onnx-0.1.0/gliner2_onnx/constants.py +38 -0
- gliner2_onnx-0.1.0/gliner2_onnx/exceptions.py +13 -0
- gliner2_onnx-0.1.0/gliner2_onnx/runtime.py +499 -0
- gliner2_onnx-0.1.0/gliner2_onnx/types.py +40 -0
- gliner2_onnx-0.1.0/pyproject.toml +159 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 @lmoe
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: gliner2-onnx
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: GLiNER2 ONNX runtime for NER and classification without PyTorch
|
|
5
|
+
Project-URL: Homepage, https://github.com/lmoe/gliner2-onnx
|
|
6
|
+
Project-URL: Repository, https://github.com/lmoe/gliner2-onnx
|
|
7
|
+
Project-URL: Issues, https://github.com/lmoe/gliner2-onnx/issues
|
|
8
|
+
Author: lmoe
|
|
9
|
+
License-Expression: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Keywords: gliner,named-entity-recognition,ner,nlp,onnx,zero-shot
|
|
12
|
+
Classifier: Development Status :: 4 - Beta
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
21
|
+
Classifier: Topic :: Text Processing :: Linguistic
|
|
22
|
+
Requires-Python: >=3.10
|
|
23
|
+
Requires-Dist: huggingface-hub>=0.23.0
|
|
24
|
+
Requires-Dist: numpy>=1.26.0
|
|
25
|
+
Requires-Dist: onnxruntime>=1.18.0
|
|
26
|
+
Requires-Dist: transformers>=4.40.0
|
|
27
|
+
Provides-Extra: export
|
|
28
|
+
Requires-Dist: gliner2>=1.2.4; extra == 'export'
|
|
29
|
+
Requires-Dist: onnx<1.18,>=1.14.0; extra == 'export'
|
|
30
|
+
Requires-Dist: onnxconverter-common>=1.14.0; extra == 'export'
|
|
31
|
+
Requires-Dist: onnxscript>=0.6.0; extra == 'export'
|
|
32
|
+
Requires-Dist: requests>=2.32.5; extra == 'export'
|
|
33
|
+
Requires-Dist: torch>=2.0.0; extra == 'export'
|
|
34
|
+
Requires-Dist: urllib3>=2.6.3; extra == 'export'
|
|
35
|
+
Provides-Extra: test
|
|
36
|
+
Requires-Dist: pytest; extra == 'test'
|
|
37
|
+
Description-Content-Type: text/markdown
|
|
38
|
+
|
|
39
|
+
# gliner2-onnx
|
|
40
|
+
|
|
41
|
+
GLiNER2 ONNX runtime for Python. Runs GLiNER2 models without PyTorch.
|
|
42
|
+
|
|
43
|
+
This library is experimental. The API may change between versions.
|
|
44
|
+
|
|
45
|
+
## Features
|
|
46
|
+
|
|
47
|
+
- Zero-shot NER and text classification
|
|
48
|
+
- Runs with ONNX Runtime (no PyTorch dependency)
|
|
49
|
+
- FP32 and FP16 precision support
|
|
50
|
+
- GPU acceleration via CUDA
|
|
51
|
+
|
|
52
|
+
Only NER and text classification are implemented; other GLiNER2 features, such as JSON export, are not supported.
|
|
53
|
+
|
|
54
|
+
## Installation
|
|
55
|
+
|
|
56
|
+
```bash
|
|
57
|
+
pip install gliner2-onnx
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
## NER
|
|
61
|
+
|
|
62
|
+
```python
|
|
63
|
+
from gliner2_onnx import GLiNER2ONNXRuntime
|
|
64
|
+
|
|
65
|
+
runtime = GLiNER2ONNXRuntime.from_pretrained("lmoe/gliner2-large-v1-onnx")
|
|
66
|
+
|
|
67
|
+
entities = runtime.extract_entities(
|
|
68
|
+
"John works at Google in Seattle",
|
|
69
|
+
["person", "organization", "location"]
|
|
70
|
+
)
|
|
71
|
+
# [
|
|
72
|
+
# Entity(text='John', label='person', start=0, end=4, score=0.98),
|
|
73
|
+
# Entity(text='Google', label='organization', start=14, end=20, score=0.97),
|
|
74
|
+
# Entity(text='Seattle', label='location', start=24, end=31, score=0.96)
|
|
75
|
+
# ]
|
|
76
|
+
```
|
|
77
|
+
|
|
78
|
+
## Classification
|
|
79
|
+
|
|
80
|
+
```python
|
|
81
|
+
from gliner2_onnx import GLiNER2ONNXRuntime
|
|
82
|
+
|
|
83
|
+
runtime = GLiNER2ONNXRuntime.from_pretrained("lmoe/gliner2-large-v1-onnx")
|
|
84
|
+
|
|
85
|
+
# Single-label classification
|
|
86
|
+
result = runtime.classify(
|
|
87
|
+
"Buy milk from the store",
|
|
88
|
+
["shopping", "work", "entertainment"]
|
|
89
|
+
)
|
|
90
|
+
# {'shopping': 0.95}
|
|
91
|
+
|
|
92
|
+
# Multi-label classification
|
|
93
|
+
result = runtime.classify(
|
|
94
|
+
"Buy milk and finish the report",
|
|
95
|
+
["shopping", "work", "entertainment"],
|
|
96
|
+
threshold=0.3,
|
|
97
|
+
multi_label=True
|
|
98
|
+
)
|
|
99
|
+
# {'shopping': 0.85, 'work': 0.72}
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
## CUDA
|
|
103
|
+
|
|
104
|
+
To use CUDA for GPU acceleration:
|
|
105
|
+
|
|
106
|
+
```python
|
|
107
|
+
runtime = GLiNER2ONNXRuntime.from_pretrained(
|
|
108
|
+
"lmoe/gliner2-large-v1-onnx",
|
|
109
|
+
providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
|
|
110
|
+
)
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
## Precision
|
|
114
|
+
|
|
115
|
+
Both FP32 and FP16 models are supported. Only the requested precision is downloaded.
|
|
116
|
+
|
|
117
|
+
```python
|
|
118
|
+
runtime = GLiNER2ONNXRuntime.from_pretrained(
|
|
119
|
+
"lmoe/gliner2-large-v1-onnx",
|
|
120
|
+
precision="fp16"
|
|
121
|
+
)
|
|
122
|
+
```
|
|
123
|
+
|
|
124
|
+
## Models
|
|
125
|
+
|
|
126
|
+
Pre-exported ONNX models:
|
|
127
|
+
|
|
128
|
+
| Model | HuggingFace |
|
|
129
|
+
|-------|-------------|
|
|
130
|
+
| gliner2-large-v1 | [lmoe/gliner2-large-v1-onnx](https://huggingface.co/lmoe/gliner2-large-v1-onnx) |
|
|
131
|
+
| gliner2-multi-v1 | [lmoe/gliner2-multi-v1-onnx](https://huggingface.co/lmoe/gliner2-multi-v1-onnx) |
|
|
132
|
+
|
|
133
|
+
Note: `gliner2-base-v1` is not supported (uses a different architecture).
|
|
134
|
+
|
|
135
|
+
## Exporting Models
|
|
136
|
+
|
|
137
|
+
To export your own models, clone the repository and use make:
|
|
138
|
+
|
|
139
|
+
```bash
|
|
140
|
+
git clone https://github.com/lmoe/gliner2-onnx
|
|
141
|
+
cd gliner2-onnx
|
|
142
|
+
|
|
143
|
+
# FP32 only
|
|
144
|
+
make onnx-export MODEL=fastino/gliner2-large-v1
|
|
145
|
+
|
|
146
|
+
# FP32 + FP16
|
|
147
|
+
make onnx-export MODEL=fastino/gliner2-large-v1 QUANTIZE=fp16
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
Output is saved to `model_out/<model-name>/`.
|
|
151
|
+
|
|
152
|
+
## JavaScript/TypeScript
|
|
153
|
+
|
|
154
|
+
For Node.js, see [@lmoe/gliner-onnx.js](https://github.com/lmoe/gliner-onnx.js).
|
|
155
|
+
|
|
156
|
+
## Credits
|
|
157
|
+
|
|
158
|
+
- [fastino-ai/GLiNER2](https://github.com/fastino-ai/GLiNER2) - Original GLiNER2 implementation
|
|
159
|
+
- [fastino/gliner2-large-v1](https://huggingface.co/fastino/gliner2-large-v1) - Pre-trained models
|
|
160
|
+
|
|
161
|
+
## License
|
|
162
|
+
|
|
163
|
+
MIT
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# gliner2-onnx
|
|
2
|
+
|
|
3
|
+
GLiNER2 ONNX runtime for Python. Runs GLiNER2 models without PyTorch.
|
|
4
|
+
|
|
5
|
+
This library is experimental. The API may change between versions.
|
|
6
|
+
|
|
7
|
+
## Features
|
|
8
|
+
|
|
9
|
+
- Zero-shot NER and text classification
|
|
10
|
+
- Runs with ONNX Runtime (no PyTorch dependency)
|
|
11
|
+
- FP32 and FP16 precision support
|
|
12
|
+
- GPU acceleration via CUDA
|
|
13
|
+
|
|
14
|
+
Only NER and text classification are implemented; other GLiNER2 features, such as JSON export, are not supported.
|
|
15
|
+
|
|
16
|
+
## Installation
|
|
17
|
+
|
|
18
|
+
```bash
|
|
19
|
+
pip install gliner2-onnx
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
## NER
|
|
23
|
+
|
|
24
|
+
```python
|
|
25
|
+
from gliner2_onnx import GLiNER2ONNXRuntime
|
|
26
|
+
|
|
27
|
+
runtime = GLiNER2ONNXRuntime.from_pretrained("lmoe/gliner2-large-v1-onnx")
|
|
28
|
+
|
|
29
|
+
entities = runtime.extract_entities(
|
|
30
|
+
"John works at Google in Seattle",
|
|
31
|
+
["person", "organization", "location"]
|
|
32
|
+
)
|
|
33
|
+
# [
|
|
34
|
+
# Entity(text='John', label='person', start=0, end=4, score=0.98),
|
|
35
|
+
# Entity(text='Google', label='organization', start=14, end=20, score=0.97),
|
|
36
|
+
# Entity(text='Seattle', label='location', start=24, end=31, score=0.96)
|
|
37
|
+
# ]
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
## Classification
|
|
41
|
+
|
|
42
|
+
```python
|
|
43
|
+
from gliner2_onnx import GLiNER2ONNXRuntime
|
|
44
|
+
|
|
45
|
+
runtime = GLiNER2ONNXRuntime.from_pretrained("lmoe/gliner2-large-v1-onnx")
|
|
46
|
+
|
|
47
|
+
# Single-label classification
|
|
48
|
+
result = runtime.classify(
|
|
49
|
+
"Buy milk from the store",
|
|
50
|
+
["shopping", "work", "entertainment"]
|
|
51
|
+
)
|
|
52
|
+
# {'shopping': 0.95}
|
|
53
|
+
|
|
54
|
+
# Multi-label classification
|
|
55
|
+
result = runtime.classify(
|
|
56
|
+
"Buy milk and finish the report",
|
|
57
|
+
["shopping", "work", "entertainment"],
|
|
58
|
+
threshold=0.3,
|
|
59
|
+
multi_label=True
|
|
60
|
+
)
|
|
61
|
+
# {'shopping': 0.85, 'work': 0.72}
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
## CUDA
|
|
65
|
+
|
|
66
|
+
To use CUDA for GPU acceleration:
|
|
67
|
+
|
|
68
|
+
```python
|
|
69
|
+
runtime = GLiNER2ONNXRuntime.from_pretrained(
|
|
70
|
+
"lmoe/gliner2-large-v1-onnx",
|
|
71
|
+
providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
|
|
72
|
+
)
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
## Precision
|
|
76
|
+
|
|
77
|
+
Both FP32 and FP16 models are supported. Only the requested precision is downloaded.
|
|
78
|
+
|
|
79
|
+
```python
|
|
80
|
+
runtime = GLiNER2ONNXRuntime.from_pretrained(
|
|
81
|
+
"lmoe/gliner2-large-v1-onnx",
|
|
82
|
+
precision="fp16"
|
|
83
|
+
)
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
## Models
|
|
87
|
+
|
|
88
|
+
Pre-exported ONNX models:
|
|
89
|
+
|
|
90
|
+
| Model | HuggingFace |
|
|
91
|
+
|-------|-------------|
|
|
92
|
+
| gliner2-large-v1 | [lmoe/gliner2-large-v1-onnx](https://huggingface.co/lmoe/gliner2-large-v1-onnx) |
|
|
93
|
+
| gliner2-multi-v1 | [lmoe/gliner2-multi-v1-onnx](https://huggingface.co/lmoe/gliner2-multi-v1-onnx) |
|
|
94
|
+
|
|
95
|
+
Note: `gliner2-base-v1` is not supported (uses a different architecture).
|
|
96
|
+
|
|
97
|
+
## Exporting Models
|
|
98
|
+
|
|
99
|
+
To export your own models, clone the repository and use make:
|
|
100
|
+
|
|
101
|
+
```bash
|
|
102
|
+
git clone https://github.com/lmoe/gliner2-onnx
|
|
103
|
+
cd gliner2-onnx
|
|
104
|
+
|
|
105
|
+
# FP32 only
|
|
106
|
+
make onnx-export MODEL=fastino/gliner2-large-v1
|
|
107
|
+
|
|
108
|
+
# FP32 + FP16
|
|
109
|
+
make onnx-export MODEL=fastino/gliner2-large-v1 QUANTIZE=fp16
|
|
110
|
+
```
|
|
111
|
+
|
|
112
|
+
Output is saved to `model_out/<model-name>/`.
|
|
113
|
+
|
|
114
|
+
## JavaScript/TypeScript
|
|
115
|
+
|
|
116
|
+
For Node.js, see [@lmoe/gliner-onnx.js](https://github.com/lmoe/gliner-onnx.js).
|
|
117
|
+
|
|
118
|
+
## Credits
|
|
119
|
+
|
|
120
|
+
- [fastino-ai/GLiNER2](https://github.com/fastino-ai/GLiNER2) - Original GLiNER2 implementation
|
|
121
|
+
- [fastino/gliner2-large-v1](https://huggingface.co/fastino/gliner2-large-v1) - Pre-trained models
|
|
122
|
+
|
|
123
|
+
## License
|
|
124
|
+
|
|
125
|
+
MIT
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""GLiNER2 ONNX Runtime - NER and classification without PyTorch."""
|
|
2
|
+
|
|
3
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
4
|
+
|
|
5
|
+
from .constants import Precision
|
|
6
|
+
from .exceptions import ConfigurationError, GLiNER2Error, ModelNotFoundError
|
|
7
|
+
from .runtime import GLiNER2ONNXRuntime
|
|
8
|
+
from .types import Entity
|
|
9
|
+
|
|
10
|
+
def _installed_version() -> str:
    """Return the version of the installed ``gliner2-onnx`` distribution.

    Falls back to a development placeholder when the package metadata is not
    available (e.g. when importing from a source checkout that was never
    installed).
    """
    try:
        return version("gliner2-onnx")
    except PackageNotFoundError:
        return "0.0.0.dev0"


__version__ = _installed_version()
|
|
14
|
+
|
|
15
|
+
# Names re-exported as the package's public API (kept in sorted order).
__all__ = [
    "ConfigurationError",  # raised for missing/invalid config files
    "Entity",  # NER result type
    "GLiNER2Error",  # base of the package's custom exceptions
    "GLiNER2ONNXRuntime",  # main runtime class
    "ModelNotFoundError",  # raised for missing model directory/files
    "Precision",  # allowed precision literals
    "__version__",
]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# file generated by setuptools-scm
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
|
|
4
|
+
__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# ``TYPE_CHECKING`` is defined locally as False, so the ``typing`` import below
# only ever runs under a static type checker; at runtime the aliases are
# plain ``object`` placeholders.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
    COMMIT_ID = Union[str, None]
else:
    VERSION_TUPLE = COMMIT_ID = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

# Each value is exposed under both a plain and a dunder name.
version = __version__ = "0.1.0"
version_tuple = __version_tuple__ = (0, 1, 0)

commit_id = __commit_id__ = None
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from typing import Final, Literal
|
|
3
|
+
|
|
4
|
+
# Configuration files expected inside a model directory.
CONFIG_FILE: Final[str] = "config.json"
GLINER2_CONFIG_FILE: Final[str] = "gliner2_config.json"

# Model precisions a caller may request.
Precision = Literal["fp32", "fp16"]

# Marker tokens used in the GLiNER2 schema prompt.
# (S105 is the flake8-bandit "hard-coded password" check, which these
# *TOKEN* names would otherwise trigger.)
TOKEN_P: Final[str] = "[P]"  # noqa: S105
TOKEN_L: Final[str] = "[L]"  # noqa: S105
TOKEN_E: Final[str] = "[E]"  # noqa: S105
TOKEN_SEP_TEXT: Final[str] = "[SEP_TEXT]"  # noqa: S105

# Tokens that must be present in the config's ``special_tokens`` mapping.
REQUIRED_SPECIAL_TOKENS: Final[tuple[str, ...]] = (TOKEN_P, TOKEN_L, TOKEN_E, TOKEN_SEP_TEXT)

# Schema prompt punctuation and task names.
SCHEMA_OPEN: Final[str] = "("
SCHEMA_CLOSE: Final[str] = ")"
NER_TASK_NAME: Final[str] = "entities"
CLASSIFICATION_TASK_NAME: Final[str] = "category"

# Tensor names used when feeding/reading the exported ONNX graphs.
ONNX_INPUT_IDS: Final[str] = "input_ids"
ONNX_ATTENTION_MASK: Final[str] = "attention_mask"
ONNX_HIDDEN_STATE: Final[str] = "hidden_state"
ONNX_HIDDEN_STATES: Final[str] = "hidden_states"
ONNX_SPAN_START_IDX: Final[str] = "span_start_idx"
ONNX_SPAN_END_IDX: Final[str] = "span_end_idx"
ONNX_LABEL_EMBEDDINGS: Final[str] = "label_embeddings"

# Word splitter intended to mirror GLiNER2's WhitespaceTokenSplitter.
# Alternatives are tried in order: URLs, e-mail addresses, @mentions,
# words joined by hyphens/underscores, and finally any single
# non-whitespace character.
WORD_PATTERN: Final[re.Pattern[str]] = re.compile(
    r"""(?:https?://[^\s]+|www\.[^\s]+)
|[a-z0-9._%+-]+@[a-z0-9.-]+\.[a-z]{2,}
|@[a-z0-9_]+
|\w+(?:[-_]\w+)*
|\S""",
    re.VERBOSE | re.IGNORECASE,
)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Custom exceptions for GLiNER2 ONNX runtime."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class GLiNER2Error(Exception):
    """Common base class for the custom exceptions of the GLiNER2 ONNX runtime."""


class ModelNotFoundError(GLiNER2Error):
    """A model directory or ONNX model file needed by the runtime does not exist."""


class ConfigurationError(GLiNER2Error):
    """A configuration file is absent, cannot be parsed, or has invalid fields."""
|
|
@@ -0,0 +1,499 @@
|
|
|
1
|
+
"""GLiNER2 ONNX Runtime - NER and classification without PyTorch."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import onnxruntime as ort
|
|
8
|
+
from transformers import AutoTokenizer
|
|
9
|
+
|
|
10
|
+
from .constants import (
|
|
11
|
+
CLASSIFICATION_TASK_NAME,
|
|
12
|
+
CONFIG_FILE,
|
|
13
|
+
GLINER2_CONFIG_FILE,
|
|
14
|
+
NER_TASK_NAME,
|
|
15
|
+
ONNX_ATTENTION_MASK,
|
|
16
|
+
ONNX_HIDDEN_STATE,
|
|
17
|
+
ONNX_HIDDEN_STATES,
|
|
18
|
+
ONNX_INPUT_IDS,
|
|
19
|
+
ONNX_LABEL_EMBEDDINGS,
|
|
20
|
+
ONNX_SPAN_END_IDX,
|
|
21
|
+
ONNX_SPAN_START_IDX,
|
|
22
|
+
REQUIRED_SPECIAL_TOKENS,
|
|
23
|
+
SCHEMA_CLOSE,
|
|
24
|
+
SCHEMA_OPEN,
|
|
25
|
+
TOKEN_E,
|
|
26
|
+
TOKEN_L,
|
|
27
|
+
TOKEN_P,
|
|
28
|
+
TOKEN_SEP_TEXT,
|
|
29
|
+
WORD_PATTERN,
|
|
30
|
+
Precision,
|
|
31
|
+
)
|
|
32
|
+
from .exceptions import ConfigurationError, ModelNotFoundError
|
|
33
|
+
from .types import Entity, GLiNER2Config, OnnxModelFiles
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _validate_precision(onnx_files: dict[str, OnnxModelFiles], precision: str) -> OnnxModelFiles:
    """Return the ONNX file mapping for ``precision`` after checking it is usable.

    Args:
        onnx_files: The ``onnx_files`` section of the GLiNER2 config, keyed by
            precision name (e.g. ``"fp32"``, ``"fp16"``).
        precision: The precision requested by the caller.

    Returns:
        The file mapping for ``precision``. It is guaranteed to contain string
        paths for ``encoder``, ``classifier``, ``span_rep`` and ``count_embed``.

    Raises:
        ConfigurationError: If ``onnx_files`` is empty or not a mapping, the
            requested precision is not listed, or its entry is malformed.
    """
    if not isinstance(onnx_files, dict) or not onnx_files:
        raise ConfigurationError(f"No onnx_files found in {GLINER2_CONFIG_FILE}")

    available = list(onnx_files.keys())
    if precision not in available:
        raise ConfigurationError(f"Precision '{precision}' not available. Available: {available}")

    files = onnx_files[precision]
    # Callers index these keys directly; report a malformed entry here as a
    # ConfigurationError instead of a bare KeyError/TypeError further down.
    if not isinstance(files, dict):
        raise ConfigurationError(f"{GLINER2_CONFIG_FILE} onnx_files['{precision}'] must be an object")
    required = ("encoder", "classifier", "span_rep", "count_embed")
    missing = [name for name in required if not isinstance(files.get(name), str)]
    if missing:
        raise ConfigurationError(
            f"{GLINER2_CONFIG_FILE} onnx_files['{precision}'] missing or invalid model files: {missing}"
        )
    return files
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class GLiNER2ONNXRuntime:
|
|
46
|
+
"""
|
|
47
|
+
ONNX-based runtime for GLiNER2 classification and NER.
|
|
48
|
+
|
|
49
|
+
Example:
|
|
50
|
+
>>> runtime = GLiNER2ONNXRuntime.from_pretrained("lmoe/gliner2-large-v1-onnx")
|
|
51
|
+
>>> entities = runtime.extract_entities("John works at Google", ["person", "org"])
|
|
52
|
+
>>> result = runtime.classify("Buy milk", ["shopping", "work"])
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
@classmethod
def from_pretrained(
    cls,
    model_id: str,
    *,
    precision: Precision = "fp32",
    providers: list[str] | None = None,
    revision: str | None = None,
    cache_dir: str | Path | None = None,
) -> "GLiNER2ONNXRuntime":
    """
    Load a GLiNER2 ONNX model from HuggingFace Hub.

    The GLiNER2 config is fetched first. It is used to download only the ONNX
    files for the requested precision, plus all ``*.json`` files in the repo.

    Args:
        model_id: HuggingFace model ID (e.g., "lmoe/gliner2-large-v1-onnx")
        precision: Model precision ("fp32" or "fp16").
            Only downloads the requested precision variant.
            Available precisions are defined in the model's config.
        providers: ONNX execution providers (e.g., ["CUDAExecutionProvider"]).
            Defaults to ["CPUExecutionProvider"].
            Use onnxruntime.get_available_providers() to see available options.
        revision: Model revision (branch, tag, or commit hash)
        cache_dir: Directory to cache downloaded models. When None, the
            huggingface_hub default cache location is used.

    Returns:
        GLiNER2ONNXRuntime instance

    Raises:
        ConfigurationError: If the downloaded GLiNER2 config is not valid
            UTF-8 JSON, is not a JSON object, or does not list the requested
            precision.
    """
    from huggingface_hub import hf_hub_download, snapshot_download

    config_path = hf_hub_download(
        repo_id=model_id,
        filename=GLINER2_CONFIG_FILE,
        revision=revision,
        cache_dir=cache_dir,
    )

    try:
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with Path(config_path).open(encoding="utf-8") as f:
            config = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise ConfigurationError(f"Invalid {GLINER2_CONFIG_FILE}: {e}") from e
    if not isinstance(config, dict):
        raise ConfigurationError(f"{GLINER2_CONFIG_FILE} must contain a JSON object")

    onnx_files = _validate_precision(config.get("onnx_files", {}), precision)
    onnx_patterns: list[str] = [
        onnx_files["encoder"],
        onnx_files["classifier"],
        onnx_files["span_rep"],
        onnx_files["count_embed"],
    ]
    # Models exported with external weights keep them in a sibling "<file>.data".
    onnx_data_patterns: list[str] = [f"{p}.data" for p in onnx_patterns]

    allow_patterns: list[str] = [
        "*.json",
        *onnx_patterns,
        *onnx_data_patterns,
    ]

    model_path = snapshot_download(
        repo_id=model_id,
        revision=revision,
        allow_patterns=allow_patterns,
        cache_dir=cache_dir,
    )

    return cls(model_path, precision=precision, providers=providers)
|
|
114
|
+
|
|
115
|
+
def __init__(
    self,
    model_path: str | Path,
    precision: Precision = "fp32",
    providers: list[str] | None = None,
):
    """
    Create a runtime from a model directory that already exists locally.

    Args:
        model_path: Directory holding the ONNX models, the config files and
            the tokenizer files.
        precision: Which exported precision to load ("fp32" or "fp16").
        providers: ONNX Runtime execution providers, tried in order
            (e.g., ["CUDAExecutionProvider"]). When None,
            ["CPUExecutionProvider"] is used. See
            onnxruntime.get_available_providers() for what is installed.

    Raises:
        ModelNotFoundError: If the directory or one of the ONNX files is missing
        ConfigurationError: If a config file is missing or invalid
    """
    self.model_path = Path(model_path)
    self._validate_model_path()
    self._validate_base_config()

    gliner2_config = self._load_gliner2_config()
    self.max_width = gliner2_config["max_width"]
    self.special_tokens = gliner2_config["special_tokens"]
    self.precision = precision

    selected_files = _validate_precision(gliner2_config["onnx_files"], precision)
    # Only None means "use the default"; an explicit list is passed through as-is.
    session_providers = ["CPUExecutionProvider"] if providers is None else providers
    self._load_onnx_models(selected_files, session_providers)

    self.tokenizer = AutoTokenizer.from_pretrained(str(self.model_path))
|
|
152
|
+
|
|
153
|
+
def _validate_model_path(self) -> None:
    """Ensure ``self.model_path`` names an existing directory.

    Raises:
        ModelNotFoundError: If the path is missing or is not a directory.
    """
    if self.model_path.is_dir():
        return
    raise ModelNotFoundError(f"Model directory not found: {self.model_path}")
|
|
157
|
+
|
|
158
|
+
def _validate_base_config(self) -> None:
    """Check that the base ``config.json`` exists and parses as JSON.

    The parsed content is discarded; only the file's presence and syntax are
    checked.

    Raises:
        ConfigurationError: If the file is missing, is not a regular file, is
            not valid UTF-8, or is not valid JSON.
    """
    config_path = self.model_path / CONFIG_FILE
    # is_file() rather than exists(): a directory with this name would
    # otherwise escape as IsADirectoryError instead of ConfigurationError.
    if not config_path.is_file():
        raise ConfigurationError(f"{CONFIG_FILE} not found in {self.model_path}")

    try:
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with config_path.open(encoding="utf-8") as f:
            json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise ConfigurationError(f"Invalid {CONFIG_FILE}: {e}") from e
|
|
169
|
+
|
|
170
|
+
def _load_gliner2_config(self) -> GLiNER2Config:
    """Load and validate the GLiNER2-specific config from the model directory.

    Returns:
        A GLiNER2Config with ``max_width`` (a positive int), ``special_tokens``
        (token string -> integer id, containing every required special token)
        and the raw ``onnx_files`` mapping.

    Raises:
        ConfigurationError: If the file is missing, is not valid UTF-8 JSON,
            is not a JSON object, or a required field is missing or invalid.
    """
    config_path = self.model_path / GLINER2_CONFIG_FILE
    if not config_path.is_file():
        raise ConfigurationError(f"{GLINER2_CONFIG_FILE} not found in {self.model_path}")

    try:
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with config_path.open(encoding="utf-8") as f:
            raw_config: object = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise ConfigurationError(f"Invalid {GLINER2_CONFIG_FILE}: {e}") from e

    # A top-level array or scalar would otherwise fail below with AttributeError.
    if not isinstance(raw_config, dict):
        raise ConfigurationError(f"{GLINER2_CONFIG_FILE} must contain a JSON object")

    max_width = raw_config.get("max_width")
    # bool is a subclass of int, so reject it explicitly. A width below 1
    # would make span generation yield no candidate spans at all.
    if not isinstance(max_width, int) or isinstance(max_width, bool) or max_width < 1:
        raise ConfigurationError(f"{GLINER2_CONFIG_FILE} missing or invalid max_width")

    special_tokens_raw = raw_config.get("special_tokens")
    if not isinstance(special_tokens_raw, dict):
        raise ConfigurationError(f"{GLINER2_CONFIG_FILE} missing or invalid special_tokens")

    try:
        special_tokens = {str(k): int(v) for k, v in special_tokens_raw.items()}
    except (TypeError, ValueError) as e:
        raise ConfigurationError(f"{GLINER2_CONFIG_FILE} special_tokens values must be integers: {e}") from e

    missing = [t for t in REQUIRED_SPECIAL_TOKENS if t not in special_tokens]
    if missing:
        raise ConfigurationError(f"{GLINER2_CONFIG_FILE} missing special tokens: {missing}")

    onnx_files_raw = raw_config.get("onnx_files")
    if not isinstance(onnx_files_raw, dict):
        raise ConfigurationError(f"{GLINER2_CONFIG_FILE} missing or invalid onnx_files")

    return GLiNER2Config(
        max_width=max_width,
        special_tokens=special_tokens,
        onnx_files=onnx_files_raw,
    )
|
|
208
|
+
|
|
209
|
+
def _load_onnx_models(self, onnx_files: OnnxModelFiles, providers: list[str]) -> None:
    """Open one inference session per sub-model named in ``onnx_files``.

    Paths in ``onnx_files`` are relative to ``self.model_path``. Sessions are
    stored on ``self.encoder``, ``self.classifier``, ``self.span_rep`` and
    ``self.count_embed``, created in that order.
    """
    root = self.model_path
    load = self._load_model

    self.encoder = load(root / onnx_files["encoder"], providers)
    self.classifier = load(root / onnx_files["classifier"], providers)
    self.span_rep = load(root / onnx_files["span_rep"], providers)
    self.count_embed = load(root / onnx_files["count_embed"], providers)
|
|
215
|
+
|
|
216
|
+
def _load_model(self, path: Path, providers: list[str]) -> ort.InferenceSession:
    """Create an ONNX Runtime session for the model file at ``path``.

    Raises:
        ModelNotFoundError: If ``path`` does not exist.
    """
    if path.exists():
        return ort.InferenceSession(str(path), providers=providers)
    raise ModelNotFoundError(f"Model not found: {path}")
|
|
221
|
+
|
|
222
|
+
def classify(
    self,
    text: str,
    labels: list[str],
    threshold: float = 0.5,
    *,
    multi_label: bool = False,
) -> dict[str, float]:
    """
    Score ``text`` against a set of candidate labels.

    Args:
        text: Text to classify
        labels: Candidate labels
        threshold: Minimum probability a label needs in order to be returned.
            Only applied when ``multi_label`` is True.
        multi_label: If True, score each label independently (sigmoid) and
            return every label at or above ``threshold``. If False,
            normalise across labels (softmax) and return only the top label.

    Returns:
        Dict mapping the selected label(s) to their probability.

    Raises:
        ValueError: If ``text`` is empty/whitespace-only or ``labels`` is empty.
    """
    if not text or not text.strip():
        raise ValueError("Text cannot be empty")
    if not labels:
        raise ValueError("Labels cannot be empty")

    input_ids, attention_mask, label_positions = self._build_classification_input(text, labels)
    encoded = self._encode(input_ids, attention_mask)
    # One hidden vector per label marker token in the schema prefix.
    label_vectors = encoded[0, label_positions, :]
    classifier_out = self.classifier.run(None, {ONNX_HIDDEN_STATE: label_vectors})
    logits = classifier_out[0].flatten()

    probs = self._sigmoid(logits) if multi_label else self._softmax(logits)
    scores = {label: float(p) for label, p in zip(labels, probs, strict=True)}

    if multi_label:
        return {label: score for label, score in scores.items() if score >= threshold}

    # max() keeps the first label in insertion order when scores tie.
    top_label = max(scores, key=scores.__getitem__)
    return {top_label: scores[top_label]}
|
|
262
|
+
|
|
263
|
+
def _build_schema_prefix(
    self,
    task_name: str,
    labels: list[str],
    label_token_key: str,
) -> tuple[list[int], list[int]]:
    """Tokenize the schema header ``( [P] task ( [L/E] label1 [L/E] label2 ... ) ) [SEP_TEXT]``.

    Args:
        task_name: Task word placed after the ``[P]`` marker.
        labels: Labels, each preceded by the marker named by ``label_token_key``.
        label_token_key: Special-token key used as the per-label marker.

    Returns:
        ``(token_ids, label_positions)``, where ``label_positions[i]`` is the
        index in ``token_ids`` of the marker token that precedes ``labels[i]``.
    """
    p_id = self.special_tokens[TOKEN_P]
    marker_id = self.special_tokens[label_token_key]
    sep_id = self.special_tokens[TOKEN_SEP_TEXT]

    def encode(piece: str) -> list[int]:
        return self.tokenizer.encode(piece, add_special_tokens=False)

    open_ids = encode(SCHEMA_OPEN)
    tokens: list[int] = [*open_ids, p_id, *encode(task_name), *open_ids]

    label_positions: list[int] = []
    for label in labels:
        label_positions.append(len(tokens))
        tokens += [marker_id, *encode(label)]

    close_ids = encode(SCHEMA_CLOSE)
    tokens += [*close_ids, *close_ids, sep_id]
    return tokens, label_positions
|
|
291
|
+
|
|
292
|
+
def _build_classification_input(
    self,
    text: str,
    labels: list[str],
) -> tuple[np.ndarray, np.ndarray, list[int]]:
    """Assemble encoder inputs for the classification task.

    The schema prefix (one label marker per label) is followed by the
    lowercased text. The text is split with ``WORD_PATTERN`` and tokenized
    word by word.

    Returns:
        ``(input_ids, attention_mask, label_positions)``. The two arrays have
        shape ``(1, seq_len)`` and dtype int64, and the mask is all ones.
    """
    tokens, label_positions = self._build_schema_prefix(CLASSIFICATION_TASK_NAME, labels, TOKEN_L)

    word_token_ids = [
        token_id
        for word in WORD_PATTERN.finditer(text.lower())
        for token_id in self.tokenizer.encode(word.group(), add_special_tokens=False)
    ]

    input_ids = np.array([tokens + word_token_ids], dtype=np.int64)
    return input_ids, np.ones_like(input_ids), label_positions
|
|
307
|
+
|
|
308
|
+
def extract_entities(
    self,
    text: str,
    labels: list[str],
    threshold: float = 0.5,
) -> list[Entity]:
    """
    Find spans of ``text`` that match the given entity types.

    Args:
        text: Text to analyze
        labels: Entity types to look for (e.g., ["person", "organization"])
        threshold: Minimum confidence score for a span to be kept

    Returns:
        Entities built by ``_collect_entities`` and then filtered by
        ``_deduplicate_entities`` (both defined elsewhere in this class).

    Raises:
        ValueError: If ``text`` is empty/whitespace-only or ``labels`` is empty.
    """
    if not text or not text.strip():
        raise ValueError("Text cannot be empty")
    if not labels:
        raise ValueError("Labels cannot be empty")

    (
        input_ids,
        attention_mask,
        e_positions,
        word_offsets,
        text_start_idx,
        first_token_positions,
    ) = self._build_ner_input(text, labels)

    sequence = self._encode(input_ids, attention_mask)[0]
    label_embeddings = sequence[e_positions, :]
    # Hidden states of the text part only (everything after the schema prefix).
    text_hidden = sequence[text_start_idx:, :]

    if not word_offsets:
        return []

    word_span_start, word_span_end = self._generate_spans(len(word_offsets))

    # Spans are defined over words; translate each bound to the index of that
    # word's first subword token within ``text_hidden``.
    word_to_token = np.asarray(first_token_positions, dtype=np.int64)
    token_span_start = word_to_token[word_span_start]
    token_span_end = word_to_token[word_span_end]

    span_rep = self._get_span_rep(
        text_hidden[None, :, :],
        token_span_start[None, :],
        token_span_end[None, :],
    )[0]

    scores = self._compute_span_label_scores(span_rep, label_embeddings)
    entities = self._collect_entities(scores, word_span_start, word_span_end, word_offsets, labels, text, threshold)
    return self._deduplicate_entities(entities)
|
|
356
|
+
|
|
357
|
+
def _build_ner_input(
    self,
    text: str,
    labels: list[str],
) -> tuple[np.ndarray, np.ndarray, list[int], list[tuple[int, int]], int, list[int]]:
    """Build encoder input for the NER task with word-level span support.

    The sequence is the schema prefix produced by ``_build_schema_prefix``
    (task name ``NER_TASK_NAME``, labels, ``TOKEN_E`` markers), followed by
    the sub-word tokens of every word that ``WORD_PATTERN`` finds in the
    lower-cased text.

    Args:
        text: Raw input text.
        labels: Entity labels placed in the schema prefix.

    Returns:
        ``(input_ids, attention_mask, e_positions, word_offsets,
        text_start_idx, first_token_positions)`` where

        - ``input_ids`` / ``attention_mask`` are ``int64`` arrays of shape
          ``(1, seq_len)`` (the mask is all ones);
        - ``e_positions`` are the positions returned by
          ``_build_schema_prefix``;
        - ``word_offsets[i]`` is the ``(start, end)`` character range of word
          ``i`` in the ORIGINAL ``text`` (end exclusive);
        - ``text_start_idx`` is the index of the first text token;
        - ``first_token_positions[i]`` is the position of word ``i``'s first
          sub-token, counted from ``text_start_idx``.
    """
    tokens, e_positions = self._build_schema_prefix(NER_TASK_NAME, labels, TOKEN_E)
    text_start_idx = len(tokens)

    lowered = text.lower()
    # str.lower() is not length-preserving for every character (e.g. "İ"
    # lowers to "i" + U+0307). Offsets found in ``lowered`` are used to
    # slice the original ``text`` later, so when the lengths differ we map
    # each lowered position back to the original character it came from.
    orig_index: list[int] | None = None
    if len(lowered) != len(text):
        orig_index = []
        for char_pos, char in enumerate(text):
            orig_index.extend([char_pos] * len(char.lower()))

    word_offsets: list[tuple[int, int]] = []
    first_token_positions: list[int] = []
    token_idx = 0

    for match in WORD_PATTERN.finditer(lowered):
        start, end = match.span()
        if orig_index is not None:
            # Take the end from the last matched character so it stays exclusive.
            mapped_start = orig_index[start]
            end = orig_index[end - 1] + 1 if end > start else mapped_start
            start = mapped_start
        word_offsets.append((start, end))
        first_token_positions.append(token_idx)

        word_tokens = self.tokenizer.encode(match.group(), add_special_tokens=False)
        tokens.extend(word_tokens)
        token_idx += len(word_tokens)

    input_ids = np.array([tokens], dtype=np.int64)
    attention_mask = np.ones_like(input_ids)

    return input_ids, attention_mask, e_positions, word_offsets, text_start_idx, first_token_positions
|
|
382
|
+
|
|
383
|
+
def _generate_spans(self, seq_len: int) -> tuple[np.ndarray, np.ndarray]:
    """Enumerate candidate spans as parallel start/end index arrays.

    Every start position ``first`` in ``[0, seq_len)`` is paired with each
    inclusive end position ``last`` such that ``first <= last < seq_len``
    and the span covers at most ``self.max_width`` positions. Pairs are
    ordered by start, then by end.

    Returns:
        ``(starts, ends)`` as equally long ``int64`` arrays.
    """
    pairs = [
        (first, last)
        for first in range(seq_len)
        for last in range(first, min(first + self.max_width, seq_len))
    ]
    starts = np.array([first for first, _ in pairs], dtype=np.int64)
    ends = np.array([last for _, last in pairs], dtype=np.int64)
    return starts, ends
|
|
394
|
+
|
|
395
|
+
def _collect_entities(
    self,
    scores: np.ndarray,
    word_span_start: np.ndarray,
    word_span_end: np.ndarray,
    word_offsets: list[tuple[int, int]],
    labels: list[str],
    text: str,
    threshold: float,
) -> list[Entity]:
    """Turn every (span, label) score at or above ``threshold`` into an Entity.

    Args:
        scores: Score matrix indexed as ``scores[span, label]``.
        word_span_start: First word index of each span.
        word_span_end: Last word index (inclusive) of each span.
        word_offsets: ``(start, end)`` character range of each word in ``text``.
        labels: Label names, in the same order as the score columns.
        text: Text that ``word_offsets`` refer to.
        threshold: Minimum score for a (span, label) pair to be reported.

    Returns:
        Entities ordered by span index, then label index.
    """
    num_spans = word_span_start.shape[0]
    passing = scores[:num_spans, : len(labels)] >= threshold
    # np.nonzero yields indices in row-major order: span first, then label.
    span_ids, label_ids = np.nonzero(passing)

    found: list[Entity] = []
    for span_id, label_id in zip(span_ids, label_ids):
        char_start = word_offsets[word_span_start[span_id]][0]
        char_end = word_offsets[word_span_end[span_id]][1]
        found.append(
            Entity(
                text=text[char_start:char_end],
                label=labels[label_id],
                start=char_start,
                end=char_end,
                score=float(scores[span_id, label_id]),
            )
        )
    return found
|
|
428
|
+
|
|
429
|
+
def _deduplicate_entities(self, entities: list[Entity]) -> list[Entity]:
    """Greedy per-label non-overlap filter, best score first.

    Entities are visited in descending score order (ties keep their input
    order). A candidate is accepted unless its ``[start, end)`` range
    intersects an already accepted entity with the same label. Spans that
    merely touch (one ends where the other starts) do not overlap.

    Returns:
        The accepted entities, in descending score order.
    """
    accepted: list[Entity] = []
    accepted_by_label: dict[str, list[Entity]] = {}

    for candidate in sorted(entities, key=lambda ent: ent.score, reverse=True):
        same_label = accepted_by_label.setdefault(candidate.label, [])
        disjoint = all(candidate.end <= other.start or candidate.start >= other.end for other in same_label)
        if disjoint:
            same_label.append(candidate)
            accepted.append(candidate)

    return accepted
|
|
443
|
+
|
|
444
|
+
def _encode(self, input_ids: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Run the ONNX encoder session on one tokenised sequence batch.

    Args:
        input_ids: Token ids, fed under the ``ONNX_INPUT_IDS`` input name.
        attention_mask: Attention mask, fed under ``ONNX_ATTENTION_MASK``.

    Returns:
        The first output of ``self.encoder.run`` (``None`` is passed as the
        output-name list). Presumably the per-token hidden states — TODO
        confirm against the exported model.
    """
    feeds = {
        ONNX_INPUT_IDS: input_ids,
        ONNX_ATTENTION_MASK: attention_mask,
    }
    outputs = self.encoder.run(None, feeds)
    hidden: np.ndarray = outputs[0]
    return hidden
|
|
451
|
+
|
|
452
|
+
def _get_span_rep(
    self,
    hidden_states: np.ndarray,
    span_start: np.ndarray,
    span_end: np.ndarray,
) -> np.ndarray:
    """Compute span representations with the ``span_rep`` ONNX session.

    Args:
        hidden_states: Token hidden states; cast to ``float32`` before the call.
        span_start: Token index where each span starts.
        span_end: Token index where each span ends.

    Returns:
        The first output of ``self.span_rep.run``.
    """
    feeds = {
        ONNX_HIDDEN_STATES: hidden_states.astype(np.float32),
        ONNX_SPAN_START_IDX: span_start,
        ONNX_SPAN_END_IDX: span_end,
    }
    representations: np.ndarray = self.span_rep.run(None, feeds)[0]
    return representations
|
|
468
|
+
|
|
469
|
+
def _compute_span_label_scores(
    self,
    span_rep: np.ndarray,
    label_embeddings: np.ndarray,
) -> np.ndarray:
    """Score every span against every label.

    The label embeddings are first transformed by the ``count_embed`` ONNX
    session (input cast to ``float32``). Each span representation is then
    dot-multiplied with each transformed label, and the logits are mapped
    through a sigmoid.

    Args:
        span_rep: Span representations, indexed ``[span, hidden]``.
        label_embeddings: Raw label embeddings fed to ``count_embed``.

    Returns:
        Array indexed ``[span, label]`` with values in ``[0, 1]``.
    """
    feeds = {ONNX_LABEL_EMBEDDINGS: label_embeddings.astype(np.float32)}
    label_proj = self.count_embed.run(None, feeds)[0]

    logits = np.einsum("sh,lh->sl", span_rep, label_proj)
    return self._sigmoid(logits)
|
|
482
|
+
|
|
483
|
+
@staticmethod
|
|
484
|
+
def _sigmoid(x: np.ndarray) -> np.ndarray:
|
|
485
|
+
"""Numerically stable sigmoid."""
|
|
486
|
+
result = np.empty_like(x)
|
|
487
|
+
pos_mask = x >= 0
|
|
488
|
+
neg_mask = ~pos_mask
|
|
489
|
+
result[pos_mask] = 1 / (1 + np.exp(-x[pos_mask]))
|
|
490
|
+
exp_x = np.exp(x[neg_mask])
|
|
491
|
+
result[neg_mask] = exp_x / (1 + exp_x)
|
|
492
|
+
return result
|
|
493
|
+
|
|
494
|
+
@staticmethod
|
|
495
|
+
def _softmax(x: np.ndarray) -> np.ndarray:
|
|
496
|
+
"""Numerically stable softmax."""
|
|
497
|
+
x_max = float(np.max(x))
|
|
498
|
+
exp_x = np.exp(x - x_max)
|
|
499
|
+
return np.asarray(exp_x / np.sum(exp_x))
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Type definitions for GLiNER2 ONNX runtime."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import TypedDict
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class OnnxModelFiles(TypedDict, total=True):
    """File paths of the ONNX sub-models exported for one precision level.

    All four keys are required.
    """

    encoder: str  # path to the encoder model
    classifier: str  # path to the classifier model
    span_rep: str  # path to the span-representation model
    count_embed: str  # path to the count_embed (label transform) model
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class GLiNER2Config(TypedDict, total=True):
    """Schema of the GLiNER2 ONNX configuration mapping.

    Keys:
        max_width: Integer span-width setting (presumably the maximum number
            of words per candidate span — confirm against the loader).
        special_tokens: Mapping from special-token string to an integer
            (presumably its token id).
        onnx_files: Mapping from precision name to that precision's files.
    """

    max_width: int
    special_tokens: dict[str, int]
    onnx_files: dict[str, OnnxModelFiles]  # keyed by precision level
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass(eq=True, repr=True)
class Entity:
    """A named entity extracted from an input text.

    ``start`` and ``end`` are character offsets with ``end`` exclusive. For
    entities produced by the NER extraction, ``text`` equals
    ``source[start:end]``.

    Attributes:
        text: Surface form exactly as it appears in the source text.
        label: Entity type (one of the labels that were requested).
        start: Offset of the entity's first character.
        end: Offset one past the entity's last character.
        score: Confidence score between 0 and 1.
    """

    text: str
    label: str
    start: int
    end: int
    score: float
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling", "hatch-vcs"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "gliner2-onnx"
|
|
7
|
+
dynamic = ["version"]
|
|
8
|
+
description = "GLiNER2 ONNX runtime for NER and classification without PyTorch"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = "MIT"
|
|
11
|
+
requires-python = ">=3.10"
|
|
12
|
+
authors = [{ name = "lmoe" }]
|
|
13
|
+
keywords = ["ner", "named-entity-recognition", "gliner", "onnx", "nlp", "zero-shot"]
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Development Status :: 4 - Beta",
|
|
16
|
+
"Intended Audience :: Developers",
|
|
17
|
+
"Intended Audience :: Science/Research",
|
|
18
|
+
"License :: OSI Approved :: MIT License",
|
|
19
|
+
"Programming Language :: Python :: 3",
|
|
20
|
+
"Programming Language :: Python :: 3.10",
|
|
21
|
+
"Programming Language :: Python :: 3.11",
|
|
22
|
+
"Programming Language :: Python :: 3.12",
|
|
23
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
24
|
+
"Topic :: Text Processing :: Linguistic",
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
dependencies = [
|
|
28
|
+
"onnxruntime>=1.18.0",
|
|
29
|
+
"transformers>=4.40.0",
|
|
30
|
+
"numpy>=1.26.0",
|
|
31
|
+
"huggingface-hub>=0.23.0"
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
[project.urls]
|
|
35
|
+
Homepage = "https://github.com/lmoe/gliner2-onnx"
|
|
36
|
+
Repository = "https://github.com/lmoe/gliner2-onnx"
|
|
37
|
+
Issues = "https://github.com/lmoe/gliner2-onnx/issues"
|
|
38
|
+
|
|
39
|
+
[project.optional-dependencies]
|
|
40
|
+
# For exporting models (not needed for inference)
|
|
41
|
+
export = [
|
|
42
|
+
"gliner2>=1.2.4",
|
|
43
|
+
"torch>=2.0.0",
|
|
44
|
+
"onnx>=1.14.0,<1.18",
|
|
45
|
+
"onnxscript>=0.6.0",
|
|
46
|
+
"onnxconverter-common>=1.14.0",
|
|
47
|
+
"urllib3>=2.6.3",
|
|
48
|
+
"requests>=2.32.5"
|
|
49
|
+
]
|
|
50
|
+
test = ["pytest"]
|
|
51
|
+
|
|
52
|
+
[tool.hatch.build.targets.wheel]
|
|
53
|
+
packages = ["gliner2_onnx"]
|
|
54
|
+
|
|
55
|
+
[tool.hatch.version]
|
|
56
|
+
source = "vcs"
|
|
57
|
+
|
|
58
|
+
[tool.hatch.build.hooks.vcs]
|
|
59
|
+
version-file = "gliner2_onnx/_version.py"
|
|
60
|
+
|
|
61
|
+
[tool.hatch.build.targets.sdist]
|
|
62
|
+
include = [
|
|
63
|
+
"/gliner2_onnx/**/*.py",
|
|
64
|
+
"/README.md",
|
|
65
|
+
"/pyproject.toml",
|
|
66
|
+
]
|
|
67
|
+
exclude = [
|
|
68
|
+
"/.gitignore",
|
|
69
|
+
]
|
|
70
|
+
|
|
71
|
+
[dependency-groups]
|
|
72
|
+
dev = ["ruff", "mypy"]
|
|
73
|
+
|
|
74
|
+
[tool.ruff]
|
|
75
|
+
target-version = "py310"
|
|
76
|
+
line-length = 160
|
|
77
|
+
exclude = ["gliner2_onnx/_version.py"]
|
|
78
|
+
|
|
79
|
+
[tool.ruff.lint]
|
|
80
|
+
select = [
|
|
81
|
+
"E", # pycodestyle errors
|
|
82
|
+
"W", # pycodestyle warnings
|
|
83
|
+
"F", # pyflakes
|
|
84
|
+
"I", # isort
|
|
85
|
+
"B", # flake8-bugbear
|
|
86
|
+
"C4", # flake8-comprehensions
|
|
87
|
+
"UP", # pyupgrade
|
|
88
|
+
"ARG", # flake8-unused-arguments
|
|
89
|
+
"SIM", # flake8-simplify
|
|
90
|
+
"TCH", # flake8-type-checking
|
|
91
|
+
"PTH", # flake8-use-pathlib
|
|
92
|
+
"PL", # pylint
|
|
93
|
+
"RUF", # ruff-specific
|
|
94
|
+
"PERF", # perflint
|
|
95
|
+
"S", # bandit security
|
|
96
|
+
"A", # shadowing builtins
|
|
97
|
+
"T20", # print statements
|
|
98
|
+
"ERA", # commented-out code
|
|
99
|
+
"TRY", # exception handling
|
|
100
|
+
"RET", # return statements
|
|
101
|
+
"SLF", # private member access
|
|
102
|
+
"FBT", # boolean trap
|
|
103
|
+
"G", # logging format
|
|
104
|
+
"PIE", # misc checks
|
|
105
|
+
"FLY", # f-strings
|
|
106
|
+
]
|
|
107
|
+
ignore = [
|
|
108
|
+
"PLR0913", # too many arguments
|
|
109
|
+
"PLR2004", # magic value comparison
|
|
110
|
+
"PLR0911", # too many return statements
|
|
111
|
+
"PLR0915", # too many statements
|
|
112
|
+
"PLC0415", # import outside top-level
|
|
113
|
+
"S101", # assert used
|
|
114
|
+
"TRY003", # long exception messages
|
|
115
|
+
]
|
|
116
|
+
|
|
117
|
+
[tool.ruff.lint.isort]
|
|
118
|
+
known-first-party = ["gliner2_onnx"]
|
|
119
|
+
|
|
120
|
+
[tool.ruff.lint.per-file-ignores]
|
|
121
|
+
"tests/*" = ["T201"]
|
|
122
|
+
"tools/*" = ["T201"]
|
|
123
|
+
|
|
124
|
+
[tool.mypy]
|
|
125
|
+
python_version = "3.10"
|
|
126
|
+
strict = true
|
|
127
|
+
exclude = ["gliner2_onnx/_version\\.py$"]
|
|
128
|
+
warn_return_any = true
|
|
129
|
+
warn_unused_ignores = true
|
|
130
|
+
disallow_untyped_defs = true
|
|
131
|
+
disallow_incomplete_defs = true
|
|
132
|
+
disallow_any_explicit = true
|
|
133
|
+
disallow_any_generics = true
|
|
134
|
+
strict_equality = true
|
|
135
|
+
warn_unreachable = true
|
|
136
|
+
|
|
137
|
+
[[tool.mypy.overrides]]
|
|
138
|
+
module = [
|
|
139
|
+
"onnxruntime.*",
|
|
140
|
+
"onnx.*",
|
|
141
|
+
"onnxscript.*",
|
|
142
|
+
"onnxconverter_common.*",
|
|
143
|
+
"transformers.*",
|
|
144
|
+
"gliner2.*",
|
|
145
|
+
"torch.*",
|
|
146
|
+
"huggingface_hub.*",
|
|
147
|
+
]
|
|
148
|
+
ignore_missing_imports = true
|
|
149
|
+
|
|
150
|
+
[[tool.mypy.overrides]]
|
|
151
|
+
module = "gliner2_onnx.runtime"
|
|
152
|
+
# ONNX runtime returns untyped data, numpy arrays have complex typing
|
|
153
|
+
disallow_any_explicit = false
|
|
154
|
+
|
|
155
|
+
[[tool.mypy.overrides]]
|
|
156
|
+
module = ["tools.export_model", "tools.export_count_embed"]
|
|
157
|
+
# torch.nn.Module is typed as Any, causing cascading type errors
|
|
158
|
+
disallow_any_explicit = false
|
|
159
|
+
disable_error_code = ["misc", "no-any-return", "arg-type", "operator", "union-attr", "index"]
|