isage-tooluse-sias 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_tooluse_sias-0.1.0.dist-info/METADATA +186 -0
- isage_tooluse_sias-0.1.0.dist-info/RECORD +10 -0
- isage_tooluse_sias-0.1.0.dist-info/WHEEL +5 -0
- isage_tooluse_sias-0.1.0.dist-info/licenses/LICENSE +21 -0
- isage_tooluse_sias-0.1.0.dist-info/top_level.txt +1 -0
- sage_sias/__init__.py +30 -0
- sage_sias/continual_learner.py +184 -0
- sage_sias/coreset_selector.py +302 -0
- sage_sias/setup.py +6 -0
- sage_sias/types.py +94 -0
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: isage-tooluse-sias
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: SAGE Tool Use SIAS - Sample-Importance-Aware Selection for tool selection and agent training
|
|
5
|
+
Author-email: IntelliStream Team <shuhao_zhang@hust.edu.cn>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/intellistream/sage-sias
|
|
8
|
+
Project-URL: Repository, https://github.com/intellistream/sage-sias
|
|
9
|
+
Project-URL: Documentation, https://github.com/intellistream/sage-sias#readme
|
|
10
|
+
Project-URL: Bug Tracker, https://github.com/intellistream/sage-sias/issues
|
|
11
|
+
Keywords: sias,tool-selection,tool-use,sample-selection,importance-aware,continual-learning,coreset,active-learning,agent-training
|
|
12
|
+
Classifier: Development Status :: 4 - Beta
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
21
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
22
|
+
Requires-Python: >=3.10
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
License-File: LICENSE
|
|
25
|
+
Requires-Dist: numpy>=1.20.0
|
|
26
|
+
Requires-Dist: typing-extensions>=4.0.0
|
|
27
|
+
Provides-Extra: dev
|
|
28
|
+
Requires-Dist: pytest>=7.0.0; extra == "dev"
|
|
29
|
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
|
30
|
+
Requires-Dist: ruff>=0.8.4; extra == "dev"
|
|
31
|
+
Requires-Dist: mypy>=1.0.0; extra == "dev"
|
|
32
|
+
Provides-Extra: torch
|
|
33
|
+
Requires-Dist: torch>=2.0.0; extra == "torch"
|
|
34
|
+
Provides-Extra: all
|
|
35
|
+
Requires-Dist: isage-sias[dev,torch]; extra == "all"
|
|
36
|
+
Dynamic: license-file
|
|
37
|
+
|
|
38
|
+
# SAGE Tool Use SIAS (Sample-Importance-Aware Selection)
|
|
39
|
+
|
|
40
|
+
**Tool selection algorithm using sample-importance-aware selection for agent training and tool curation**
|
|
41
|
+
|
|
42
|
+
[](https://badge.fury.io/py/isage-tooluse-sias)
|
|
43
|
+
[](https://www.python.org/downloads/)
|
|
44
|
+
[](https://opensource.org/licenses/MIT)
|
|
45
|
+
|
|
46
|
+
## 🎯 Overview
|
|
47
|
+
|
|
48
|
+
`sage-tooluse-sias` provides Sample-Importance-Aware Selection algorithms specifically designed for:
|
|
49
|
+
|
|
50
|
+
- **Tool Selection**: Select important tools for agent use
|
|
51
|
+
- **Agent Training**: Select important trajectories for fine-tuning
|
|
52
|
+
- **Continual Learning**: Efficient sample selection for continual/lifelong learning
|
|
53
|
+
- **Tool/Trajectory Curation**: Curate representative samples for agent development
|
|
54
|
+
|
|
55
|
+
## 📦 Installation
|
|
56
|
+
|
|
57
|
+
```bash
|
|
58
|
+
# Basic installation
|
|
59
|
+
pip install isage-tooluse-sias
|
|
60
|
+
|
|
61
|
+
# With PyTorch support
|
|
62
|
+
pip install isage-tooluse-sias[torch]
|
|
63
|
+
|
|
64
|
+
# Development installation
|
|
65
|
+
pip install isage-tooluse-sias[dev]
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
## 🚀 Quick Start
|
|
69
|
+
|
|
70
|
+
### Continual Learning
|
|
71
|
+
|
|
72
|
+
```python
|
|
73
|
+
from sage_sias import ContinualLearner
|
|
74
|
+
|
|
75
|
+
# Create continual learner
|
|
76
|
+
learner = ContinualLearner(
|
|
77
|
+
buffer_size=1000,
|
|
78
|
+
selection_strategy="importance"
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
# Add samples
|
|
82
|
+
for data, label in stream:
|
|
83
|
+
learner.add_sample(data, label)
|
|
84
|
+
|
|
85
|
+
# Get selected samples
|
|
86
|
+
important_samples = learner.get_buffer()
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
### Coreset Selection
|
|
90
|
+
|
|
91
|
+
```python
|
|
92
|
+
from sage_sias import CoresetSelector
|
|
93
|
+
|
|
94
|
+
# Create coreset selector
|
|
95
|
+
selector = CoresetSelector(
|
|
96
|
+
target_size=100,
|
|
97
|
+
method="kmeans++"
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
# Select representative samples
|
|
101
|
+
coreset = selector.select(dataset, features)
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
## 📚 Key Components
|
|
105
|
+
|
|
106
|
+
### 1. **Continual Learner** (`continual_learner.py`)
|
|
107
|
+
|
|
108
|
+
Manages sample selection for continual learning:
|
|
109
|
+
- Buffer management with importance-based eviction
|
|
110
|
+
- Multiple selection strategies (random, importance, diversity)
|
|
111
|
+
- Support for experience replay
|
|
112
|
+
|
|
113
|
+
### 2. **Coreset Selector** (`coreset_selector.py`)
|
|
114
|
+
|
|
115
|
+
Selects representative subsets:
|
|
116
|
+
- K-means++ based selection
|
|
117
|
+
- Diversity-aware sampling
|
|
118
|
+
- Importance scoring
|
|
119
|
+
- Support for large-scale datasets
|
|
120
|
+
|
|
121
|
+
### 3. **Types** (`types.py`)
|
|
122
|
+
|
|
123
|
+
Common data types and protocols:
|
|
124
|
+
- Sample representation
|
|
125
|
+
- Importance scoring interfaces
|
|
126
|
+
- Selection strategies
|
|
127
|
+
|
|
128
|
+
## 🔧 Architecture
|
|
129
|
+
|
|
130
|
+
```
|
|
131
|
+
sage_sias/
|
|
132
|
+
├── continual_learner.py # Continual learning with buffer management
|
|
133
|
+
├── coreset_selector.py # Coreset selection algorithms
|
|
134
|
+
├── types.py # Common types and protocols
|
|
135
|
+
└── __init__.py # Public API exports
|
|
136
|
+
```
|
|
137
|
+
|
|
138
|
+
## 🎓 Use Cases
|
|
139
|
+
|
|
140
|
+
1. **Agent Training**: Select important trajectories for fine-tuning
|
|
141
|
+
2. **Data Pruning**: Reduce dataset size while maintaining performance
|
|
142
|
+
3. **Active Learning**: Query most informative samples
|
|
143
|
+
4. **Memory Management**: Maintain representative samples in limited buffers
|
|
144
|
+
5. **Transfer Learning**: Select relevant samples for adaptation
|
|
145
|
+
|
|
146
|
+
## 🔗 Integration with SAGE
|
|
147
|
+
|
|
148
|
+
This package is part of the SAGE ecosystem but can be used independently:
|
|
149
|
+
|
|
150
|
+
```python
|
|
151
|
+
# Standalone usage
|
|
152
|
+
from sage_sias import ContinualLearner, CoresetSelector
|
|
153
|
+
|
|
154
|
+
# With SAGE agentic (optional)
|
|
155
|
+
from sage_agentic import AgentTrainer
|
|
156
|
+
from sage_sias import CoresetSelector
|
|
157
|
+
|
|
158
|
+
trainer = AgentTrainer()
|
|
159
|
+
selector = CoresetSelector(target_size=100)
|
|
160
|
+
important_trajectories = selector.select(all_trajectories)
|
|
161
|
+
trainer.train(important_trajectories)
|
|
162
|
+
```
|
|
163
|
+
|
|
164
|
+
## 📖 Documentation
|
|
165
|
+
|
|
166
|
+
- **Repository**: https://github.com/intellistream/sage-tooluse-sias
|
|
167
|
+
- **SAGE Documentation**: https://intellistream.github.io/SAGE-Pub/
|
|
168
|
+
- **Issues**: https://github.com/intellistream/sage-tooluse-sias/issues
|
|
169
|
+
|
|
170
|
+
## 🤝 Contributing
|
|
171
|
+
|
|
172
|
+
Contributions are welcome! Please feel free to submit a Pull Request.
|
|
173
|
+
|
|
174
|
+
## 📄 License
|
|
175
|
+
|
|
176
|
+
MIT License - see [LICENSE](LICENSE) file for details.
|
|
177
|
+
|
|
178
|
+
## 🙏 Acknowledgments
|
|
179
|
+
|
|
180
|
+
Originally part of the [SAGE](https://github.com/intellistream/SAGE) framework, now maintained as an independent package for broader community use.
|
|
181
|
+
|
|
182
|
+
## 📧 Contact
|
|
183
|
+
|
|
184
|
+
- **Team**: IntelliStream Team
|
|
185
|
+
- **Email**: shuhao_zhang@hust.edu.cn
|
|
186
|
+
- **GitHub**: https://github.com/intellistream
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
isage_tooluse_sias-0.1.0.dist-info/licenses/LICENSE,sha256=FviS_SBgJ3MCvV517X_20AfytowsiItMOu_l3GJWFnI,1080
|
|
2
|
+
sage_sias/__init__.py,sha256=WG9xoaAvrYdyFFvC1g34mVL_IBapR2ix0VN9aazs--o,962
|
|
3
|
+
sage_sias/continual_learner.py,sha256=LA_o2k81VO287jLvt7HmFvAhNtpJv-Zz7xjXqi0YPAo,6240
|
|
4
|
+
sage_sias/coreset_selector.py,sha256=L8JJfUvFwQl0GQOSWVZ2bl9uxSB7q-WCuoQpiawU9TU,10756
|
|
5
|
+
sage_sias/setup.py,sha256=GEYqJMCRQHES0U_92Uo9gUY9H-XnqiHX1UJ28fDVDJk,122
|
|
6
|
+
sage_sias/types.py,sha256=iAlEU2Obxp5F0iTxpwGXLoOA3LyiOKqY-Oqm133tfQw,2418
|
|
7
|
+
isage_tooluse_sias-0.1.0.dist-info/METADATA,sha256=um7pu6_rSVvQ2WTxE4x-o8DfGmO0QPrSq1fa-lgsNV0,5897
|
|
8
|
+
isage_tooluse_sias-0.1.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
9
|
+
isage_tooluse_sias-0.1.0.dist-info/top_level.txt,sha256=OCb1-BhPYZaVCErrA9_R5Eln0UgS6kq2KiEER_c87vY,10
|
|
10
|
+
isage_tooluse_sias-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025-2026 IntelliStream Team
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
sage_sias
|
sage_sias/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""SIAS - Streaming Importance-Aware Agent System.
|
|
2
|
+
|
|
3
|
+
Core components for agent tool selection and sample importance:
|
|
4
|
+
- CoresetSelector: Importance-aware sample selection
|
|
5
|
+
- OnlineContinualLearner: Experience replay with importance weighting
|
|
6
|
+
- SelectionSummary: Summary statistics for selection operations
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
from sage_agentic.sias import CoresetSelector, OnlineContinualLearner
|
|
10
|
+
|
|
11
|
+
selector = CoresetSelector(strategy="hybrid")
|
|
12
|
+
selected = selector.select(samples, target_size=1000)
|
|
13
|
+
|
|
14
|
+
learner = OnlineContinualLearner(buffer_size=2048, replay_ratio=0.25)
|
|
15
|
+
batch = learner.update_buffer(new_samples)
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from .coreset_selector import CoresetSelector, SelectionSummary
|
|
19
|
+
from .continual_learner import OnlineContinualLearner
|
|
20
|
+
from .types import ImportanceScore, SampleWithImportance
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"CoresetSelector",
|
|
24
|
+
"OnlineContinualLearner",
|
|
25
|
+
"SelectionSummary",
|
|
26
|
+
"ImportanceScore",
|
|
27
|
+
"SampleWithImportance",
|
|
28
|
+
]
|
|
29
|
+
|
|
30
|
+
__version__ = "0.1.0"
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Online Continual Learning with Experience Replay
|
|
3
|
+
|
|
4
|
+
Implements an experience replay buffer for online/incremental training that
|
|
5
|
+
prevents catastrophic forgetting. The buffer is managed using coreset selection
|
|
6
|
+
to retain the most valuable samples.
|
|
7
|
+
|
|
8
|
+
This is a core component of SIAS (Streaming Importance-Aware Agent System).
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import random
|
|
14
|
+
from typing import Iterable, Optional, Sequence
|
|
15
|
+
|
|
16
|
+
from .coreset_selector import CoresetSelector, SampleT, SelectionSummary
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class OnlineContinualLearner:
    """
    Replay-buffer manager for online continual learning.

    New samples are folded into a bounded buffer; when the buffer would
    overflow, a CoresetSelector prunes it back down to ``buffer_size`` using
    any importance metrics accumulated so far.  Each call to
    :meth:`update_buffer` returns the new samples plus a random replay draw
    from the buffer, mitigating catastrophic forgetting during incremental
    training.

    Attributes:
        buffer_size: Maximum number of samples retained in the buffer.
        replay_ratio: Fraction of the incoming batch size drawn as replay.
        selector: CoresetSelector used to prune the buffer on overflow.

    Example:
        >>> learner = OnlineContinualLearner(buffer_size=2048, replay_ratio=0.25)
        >>> for new_batch in data_stream:
        ...     training_batch = learner.update_buffer(new_batch)
        ...     train_step(training_batch)
    """

    def __init__(
        self,
        buffer_size: int = 2048,
        replay_ratio: float = 0.3,
        selector: Optional[CoresetSelector] = None,
        random_seed: int = 17,
    ) -> None:
        """
        Initialize the learner.

        Args:
            buffer_size: Maximum samples to keep in the replay buffer.
            replay_ratio: Fraction of each batch size to draw from the buffer.
            selector: Buffer-pruning selector (defaults to a "hybrid"
                CoresetSelector).
            random_seed: Seed for the replay-sampling RNG (reproducibility).
        """
        self.buffer_size = buffer_size
        self.replay_ratio = replay_ratio
        self.selector = selector or CoresetSelector(strategy="hybrid")
        self._buffer: list[SampleT] = []
        self._metrics: dict[str, float] = {}
        self._rng = random.Random(random_seed)

    @property
    def buffer(self) -> list[SampleT]:
        """A shallow copy of the buffer contents (callers cannot mutate it)."""
        return list(self._buffer)

    @property
    def buffer_len(self) -> int:
        """Number of samples currently held in the buffer."""
        return len(self._buffer)

    def update_buffer(
        self,
        new_samples: Sequence[SampleT],
        metrics: Optional[dict[str, float]] = None,
    ) -> list[SampleT]:
        """
        Fold ``new_samples`` into the buffer and return a training batch.

        Steps: merge the new samples into the buffer, prune via coreset
        selection when capacity is exceeded, then return the new samples
        followed by a replay draw.

        Args:
            new_samples: Samples to add to the buffer.
            metrics: Optional sample_id -> importance score updates.

        Returns:
            ``new_samples`` plus replay samples; if ``new_samples`` is empty,
            a copy of the current buffer is returned instead.
        """
        if not new_samples:
            # Nothing to merge -- hand back a snapshot of the buffer.
            return list(self._buffer)

        if metrics:
            self._metrics.update(metrics)

        merged = [*self._buffer, *new_samples]

        if len(merged) > self.buffer_size:
            merged = self.selector.select(
                merged,
                target_size=self.buffer_size,
                metrics=self._metrics,
            )
            # Drop importance scores for samples that were just evicted.
            survivors = {self._get_sample_id(item) for item in merged}
            self._metrics = {
                sid: score for sid, score in self._metrics.items() if sid in survivors
            }

        self._buffer = merged
        return self._assemble_training_batch(new_samples)

    def _assemble_training_batch(
        self,
        new_samples: Sequence[SampleT],
    ) -> list[SampleT]:
        """Concatenate new samples with a replay draw that excludes them."""
        fresh_ids = {self._get_sample_id(item) for item in new_samples}
        return list(new_samples) + self.sample_replay(len(new_samples), exclude=fresh_ids)

    def sample_replay(
        self,
        new_batch_size: int,
        *,
        exclude: Optional[Iterable[str]] = None,
    ) -> list[SampleT]:
        """
        Draw replay samples from the buffer.

        Args:
            new_batch_size: Incoming batch size; the draw size is
                ``max(1, int(new_batch_size * replay_ratio))``, capped at the
                number of eligible samples.
            exclude: Sample IDs that must not appear in the draw.

        Returns:
            Replay samples; empty when the buffer is empty, ``replay_ratio``
            is non-positive, or every buffered sample is excluded.
        """
        if not self._buffer or self.replay_ratio <= 0:
            return []

        blocked = set(exclude or [])
        pool = [item for item in self._buffer if self._get_sample_id(item) not in blocked]
        if not pool:
            return []

        draw = min(max(1, int(new_batch_size * self.replay_ratio)), len(pool))
        return self._rng.sample(pool, draw)

    def buffer_snapshot(self) -> list[SampleT]:
        """Return a shallow copy of the current buffer."""
        return list(self._buffer)

    def buffer_summary(self) -> SelectionSummary:
        """Summarize the buffer state as a SelectionSummary."""
        size = len(self._buffer)
        return SelectionSummary(
            total_samples=size,
            selected_samples=size,
            strategy=f"buffer:{self.selector.strategy}",
        )

    def clear(self) -> None:
        """Discard all buffered samples and accumulated metrics."""
        self._buffer = []
        self._metrics = {}

    def update_metrics(self, metrics: dict[str, float]) -> None:
        """
        Merge importance scores into the metric store.

        Args:
            metrics: Mapping of sample_id to importance score.
        """
        self._metrics.update(metrics)

    def _get_sample_id(self, sample: SampleT) -> str:
        """Resolve a sample's ID; supports dicts and attribute-style objects."""
        if isinstance(sample, dict):
            return sample.get("sample_id", sample.get("dialog_id", str(id(sample))))
        return getattr(sample, "sample_id", getattr(sample, "dialog_id", str(id(sample))))
|
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Coreset Selection for Efficient Training
|
|
3
|
+
|
|
4
|
+
Implements lightweight coreset selection strategies that identify the most
|
|
5
|
+
valuable samples for training, reducing computational cost while maintaining
|
|
6
|
+
model quality.
|
|
7
|
+
|
|
8
|
+
Strategies:
|
|
9
|
+
- loss_topk: Select samples with highest loss (most informative)
|
|
10
|
+
- diversity: Select samples maximizing coverage of feature space
|
|
11
|
+
- hybrid: Combination of loss-based and diversity-based selection
|
|
12
|
+
- random: Uniform random sampling (baseline)
|
|
13
|
+
|
|
14
|
+
This is a core component of SIAS (Streaming Importance-Aware Agent System).
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import math
|
|
20
|
+
import random
|
|
21
|
+
import re
|
|
22
|
+
from collections import Counter
|
|
23
|
+
from dataclasses import dataclass
|
|
24
|
+
from typing import Any, Optional, Protocol, Sequence, runtime_checkable
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(slots=True)
class SelectionSummary:
    """Summary statistics for a selection operation.

    Constructed by CoresetSelector.summary() and by
    OnlineContinualLearner.buffer_summary().
    """

    # Size of the input pool the selection ran over.
    total_samples: int
    # Number of samples actually kept.
    selected_samples: int
    # Strategy label that produced the selection (e.g. "hybrid",
    # or "buffer:<strategy>" when emitted by the continual learner).
    strategy: str
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@runtime_checkable
class SampleProtocol(Protocol):
    """Structural protocol for samples usable with CoresetSelector.

    Any object exposing these read-only attributes satisfies the protocol;
    no inheritance is required (typing.Protocol is structural).
    NOTE: a structurally identical SampleProtocol is also declared in
    types.py; this copy keeps the module self-contained.
    """

    @property
    def sample_id(self) -> str:
        """Unique identifier for the sample."""
        ...

    @property
    def text(self) -> str:
        """Text content of the sample."""
        ...

    @property
    def metadata(self) -> dict[str, Any]:
        """Metadata dictionary."""
        ...
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Type alias for any sample that implements the protocol.
# Samples are duck-typed in practice: plain dicts carrying
# "sample_id"/"text"/"metadata" keys work too (see the _get_* helpers
# on CoresetSelector), which is why this is Any rather than SampleProtocol.
SampleT = Any  # Should implement SampleProtocol
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class CoresetSelector:
    """
    Lightweight coreset selection over text samples.

    Identifies a representative subset of a larger dataset so training can
    run on fewer samples while keeping the most valuable ones.

    Strategies (see ``STRATEGIES``):
        - ``loss_topk``: keep samples with the highest loss/importance score
        - ``diversity``: greedy farthest-point selection over token features
        - ``hybrid``: 60% loss-based followed by 40% diversity-based
        - ``random``: uniform random sampling (baseline)

    Attributes:
        strategy: Selection strategy name.
        metric_key: Metadata key used for loss-based scoring.
        diversity_temperature: Reserved knob for diversity scoring; currently
            not read by any selection code.

    Example:
        >>> selector = CoresetSelector(strategy="hybrid")
        >>> selected = selector.select(samples, target_size=1000)
        >>> print(f"Selected {len(selected)} from {len(samples)} samples")
    """

    STRATEGIES = ("loss_topk", "diversity", "hybrid", "random")

    def __init__(
        self,
        strategy: str = "loss_topk",
        metric_key: str = "loss",
        diversity_temperature: float = 0.7,
        random_seed: int = 13,
    ) -> None:
        """
        Initialize CoresetSelector.

        Args:
            strategy: Selection strategy; must be one of ``STRATEGIES``.
            metric_key: Metadata key for loss-based selection.
            diversity_temperature: Temperature for diversity scoring
                (currently unused; kept for interface stability).
            random_seed: Random seed for reproducibility.

        Raises:
            ValueError: If ``strategy`` is not a known strategy.
        """
        if strategy not in self.STRATEGIES:
            raise ValueError(f"Unknown strategy: {strategy}. Choose from {self.STRATEGIES}")

        self.strategy = strategy
        self.metric_key = metric_key
        self.diversity_temperature = diversity_temperature
        self._rng = random.Random(random_seed)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def select(
        self,
        samples: Sequence[SampleT],
        *,
        target_size: Optional[int],
        metrics: Optional[dict[str, float]] = None,
    ) -> list[SampleT]:
        """
        Select a subset of samples using the configured strategy.

        Args:
            samples: Input samples to select from.
            target_size: Number of samples to keep.  ``None``, non-positive,
                or >= ``len(samples)`` means "keep everything".
            metrics: Optional external sample_id -> score mapping; takes
                precedence over per-sample metadata for loss-based scoring.

        Returns:
            List of selected samples (a new list; the input is not mutated).
        """
        if target_size is None or target_size <= 0 or target_size >= len(samples):
            return list(samples)

        if self.strategy == "loss_topk":
            return self._select_loss(samples, target_size, metrics)
        if self.strategy == "diversity":
            return self._select_diversity(samples, target_size)
        if self.strategy == "hybrid":
            return self._select_hybrid(samples, target_size, metrics)
        return self._select_random(samples, target_size)

    def summary(self, original_size: int, selected_size: int) -> SelectionSummary:
        """Create a SelectionSummary describing one selection run."""
        return SelectionSummary(
            total_samples=original_size,
            selected_samples=selected_size,
            strategy=self.strategy,
        )

    # ------------------------------------------------------------------
    # Selection Strategies
    # ------------------------------------------------------------------
    def _select_loss(
        self,
        samples: Sequence[SampleT],
        target_size: int,
        metrics: Optional[dict[str, float]],
    ) -> list[SampleT]:
        """Keep the ``target_size`` samples with the highest score.

        A sample's score comes from ``metrics`` when present, else from its
        ``metadata[self.metric_key]``; samples with neither score as 0.0.
        The sort is stable, so ties keep input order.
        """

        def score(sample: SampleT) -> float:
            sample_id = self._get_sample_id(sample)
            if metrics and sample_id in metrics:
                return metrics[sample_id]
            meta_val = self._get_metadata(sample).get(self.metric_key)
            if isinstance(meta_val, (int, float)):
                return float(meta_val)
            return 0.0

        ranked = sorted(samples, key=score, reverse=True)
        return list(ranked[:target_size])

    def _select_random(
        self,
        samples: Sequence[SampleT],
        target_size: int,
    ) -> list[SampleT]:
        """Uniform random sampling without replacement (seeded RNG)."""
        return self._rng.sample(list(samples), target_size)

    def _select_hybrid(
        self,
        samples: Sequence[SampleT],
        target_size: int,
        metrics: Optional[dict[str, float]],
    ) -> list[SampleT]:
        """Hybrid selection: ~60% loss-based + ~40% diversity-based."""
        loss_portion = int(target_size * 0.6)
        div_portion = target_size - loss_portion

        # First take the high-loss samples (at least one).
        top_loss = self._select_loss(samples, loss_portion or 1, metrics)
        top_loss_ids = {self._get_sample_id(s) for s in top_loss}

        # Then fill the remainder with diverse samples from what is left.
        remaining = [s for s in samples if self._get_sample_id(s) not in top_loss_ids]
        if not remaining:
            return top_loss

        diversity = self._select_diversity(remaining, max(div_portion, 1))
        # Truncate in case the two portions over-shoot target_size.
        return (top_loss + diversity)[:target_size]

    def _select_diversity(
        self,
        samples: Sequence[SampleT],
        target_size: int,
    ) -> list[SampleT]:
        """Greedy farthest-point selection over token-frequency features.

        Seeds with the sample whose normalized token-frequency vector has the
        largest L2 norm (the most "peaked" distribution), then repeatedly
        adds the candidate farthest from the already-selected set.
        """
        if not samples:
            return []

        # Feature vectors keyed by sample_id.
        # NOTE(review): duplicate sample_ids would collide in this dict;
        # callers are expected to supply unique IDs.
        features = {
            self._get_sample_id(sample): self._text_features(self._get_text(sample))
            for sample in samples
        }

        selected: list[SampleT] = []
        candidates = list(samples)

        # Seed: highest feature-vector L2 norm.
        scores = {
            self._get_sample_id(sample): self._feature_norm(features[self._get_sample_id(sample)])
            for sample in samples
        }
        first = max(candidates, key=lambda s: scores.get(self._get_sample_id(s), 0.0))
        selected.append(first)
        candidates = [s for s in candidates if self._get_sample_id(s) != self._get_sample_id(first)]

        # Iteratively add the candidate farthest from the selected set.
        while candidates and len(selected) < target_size:
            best_candidate = max(
                candidates,
                key=lambda sample: self._min_distance(sample, selected, features),
            )
            selected.append(best_candidate)
            candidates = [
                s
                for s in candidates
                if self._get_sample_id(s) != self._get_sample_id(best_candidate)
            ]

        return selected

    # ------------------------------------------------------------------
    # Feature Extraction Helpers
    # ------------------------------------------------------------------
    def _text_features(self, text: str) -> Counter:
        """Extract L1-normalized token-frequency features from text.

        Tokens are lowercase alphanumeric/underscore runs of length > 2;
        frequencies sum to 1 for non-empty token sets.
        """
        tokens = re.findall(r"[a-zA-Z0-9_]+", text.lower())
        filtered = [token for token in tokens if len(token) > 2]
        counts = Counter(filtered)
        total = sum(counts.values()) or 1.0
        for key in counts:
            counts[key] /= total
        return counts

    def _feature_norm(self, features: Counter) -> float:
        """Compute the L2 norm of a sparse feature vector."""
        return math.sqrt(sum(value * value for value in features.values()))

    def _cosine_similarity(self, left: Counter, right: Counter) -> float:
        """Compute the cosine similarity between two sparse feature vectors.

        FIX: the previous implementation returned the raw dot product of the
        L1-normalized frequency vectors without dividing by their L2 norms.
        That is not cosine similarity: identical texts scored < 1, so a
        sample had a positive "distance" to its own duplicate, corrupting
        the diversity selection.  True cosine is in [0, 1] here because all
        frequencies are non-negative.
        """
        keys = left.keys() & right.keys()
        if not keys:
            return 0.0
        dot = sum(left[key] * right[key] for key in keys)
        denom = self._feature_norm(left) * self._feature_norm(right)
        if denom == 0.0:
            return 0.0
        return dot / denom

    def _min_distance(
        self,
        candidate: SampleT,
        selected: Sequence[SampleT],
        features: dict[str, Counter],
    ) -> float:
        """Distance from candidate to the selected set (1 - max similarity)."""
        cand_feat = features[self._get_sample_id(candidate)]
        if not selected:
            return 1.0
        sims = [
            self._cosine_similarity(cand_feat, features[self._get_sample_id(item)])
            for item in selected
        ]
        similarity = max(sims) if sims else 0.0
        return 1.0 - similarity

    # ------------------------------------------------------------------
    # Sample Access Helpers (support both dict and object access)
    # ------------------------------------------------------------------
    def _get_sample_id(self, sample: SampleT) -> str:
        """Get sample_id from sample (supports dict or object)."""
        if isinstance(sample, dict):
            return sample.get("sample_id", sample.get("dialog_id", str(id(sample))))
        return getattr(sample, "sample_id", getattr(sample, "dialog_id", str(id(sample))))

    def _get_text(self, sample: SampleT) -> str:
        """Get text from sample (supports dict or object)."""
        if isinstance(sample, dict):
            return sample.get("text", "")
        return getattr(sample, "text", "")

    def _get_metadata(self, sample: SampleT) -> dict[str, Any]:
        """Get metadata from sample (supports dict or object)."""
        if isinstance(sample, dict):
            return sample.get("metadata", {})
        return getattr(sample, "metadata", {})
|
sage_sias/setup.py
ADDED
sage_sias/types.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SIAS Core Data Types
|
|
3
|
+
|
|
4
|
+
Defines the core data structures used across SIAS components.
|
|
5
|
+
These are designed to be independent of specific data sources.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import Any, Protocol, runtime_checkable
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(slots=True)
|
|
15
|
+
class SIASSample:
|
|
16
|
+
"""
|
|
17
|
+
Generic sample container for SIAS algorithms.
|
|
18
|
+
|
|
19
|
+
This is a lightweight data class that can wrap samples from various sources.
|
|
20
|
+
The only required fields are sample_id and text; everything else is optional.
|
|
21
|
+
|
|
22
|
+
Attributes:
|
|
23
|
+
sample_id: Unique identifier for this sample
|
|
24
|
+
text: The text content (or serialized representation)
|
|
25
|
+
metadata: Arbitrary metadata dictionary
|
|
26
|
+
importance_score: SSIS-computed importance score (set during training)
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
sample_id: str
|
|
30
|
+
text: str
|
|
31
|
+
metadata: dict[str, Any] = field(default_factory=dict)
|
|
32
|
+
importance_score: float = 0.0
|
|
33
|
+
|
|
34
|
+
def __hash__(self) -> int:
|
|
35
|
+
return hash(self.sample_id)
|
|
36
|
+
|
|
37
|
+
def __eq__(self, other: object) -> bool:
|
|
38
|
+
if isinstance(other, SIASSample):
|
|
39
|
+
return self.sample_id == other.sample_id
|
|
40
|
+
return False
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@runtime_checkable
class SampleProtocol(Protocol):
    """
    Structural protocol for samples that can be used with SIAS algorithms.

    Any class exposing these attributes works with CoresetSelector and
    OnlineContinualLearner without modification; no inheritance from this
    class is required (typing.Protocol is structural, and @runtime_checkable
    allows isinstance() checks).
    """

    @property
    def sample_id(self) -> str:
        """Unique identifier for the sample."""
        ...

    @property
    def text(self) -> str:
        """Text content of the sample."""
        ...

    @property
    def metadata(self) -> dict[str, Any]:
        """Metadata dictionary."""
        ...
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# Backward compatibility alias.
# Existing code that imported ``Sample`` keeps working, and source-specific
# sample types (e.g. ProcessedDialog) can be used with SIAS as long as they
# implement the SampleProtocol defined above.
Sample = SIASSample
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def wrap_sample(
    sample_id: str,
    text: str,
    metadata: dict[str, Any] | None = None,
    **kwargs: Any,
) -> SIASSample:
    """
    Factory function to create a SIASSample.

    Args:
        sample_id: Unique identifier
        text: Text content
        metadata: Optional metadata dict (copied; the caller's dict is
            never mutated)
        **kwargs: Additional metadata fields, merged on top of ``metadata``

    Returns:
        A new SIASSample instance
    """
    # FIX: previously ``meta = metadata or {}`` aliased the caller's dict,
    # so ``meta.update(kwargs)`` mutated the argument in place.  Copy it.
    meta = dict(metadata) if metadata else {}
    meta.update(kwargs)
    return SIASSample(sample_id=sample_id, text=text, metadata=meta)
|