weirdo 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,299 @@
1
+ """Configuration system for scorers.
2
+
3
+ Provides dataclass-based configuration with preset support.
4
+ """
5
+
6
+ import json
7
+ from dataclasses import dataclass, field, asdict
8
+ from typing import Any, Dict, List, Optional, Union
9
+
10
+
11
@dataclass
class ScorerConfig:
    """Configuration for creating a scorer instance.

    Supports loading from dict, YAML, or JSON, and provides
    preset configurations for common use cases.

    Attributes
    ----------
    scorer : str
        Name of the scorer to use (e.g., 'mlp').
    reference : str
        Name of the reference to use (e.g., 'swissprot').
    k : int
        K-mer size. Passed as ``k`` to both the scorer and the reference
        unless overridden by a ``'k'`` entry in ``scorer_params`` /
        ``reference_params``.
    scorer_params : dict
        Additional parameters for the scorer.
    reference_params : dict
        Additional parameters for the reference.
    training_params : dict
        Optional training parameters for trainable scorers.

    Example
    -------
    >>> config = ScorerConfig.from_preset('default')
    >>> scorer = config.build()
    >>> scorer.train(peptides, labels, target_categories=['human', 'viruses'])
    >>> scores = scorer.score(['MTMDKSEL'])
    """

    scorer: str = 'mlp'
    reference: str = 'swissprot'
    k: int = 8
    scorer_params: Dict[str, Any] = field(default_factory=dict)
    reference_params: Dict[str, Any] = field(default_factory=dict)
    training_params: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary.

        Returns
        -------
        config_dict : dict
            Configuration as a dictionary (a recursive copy produced by
            ``dataclasses.asdict``).
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Optional[Dict[str, Any]]) -> 'ScorerConfig':
        """Create config from dictionary.

        Missing keys fall back to the same defaults as the dataclass fields.
        A parameter section whose value is ``None`` (e.g. an empty
        ``scorer_params:`` entry in YAML) is treated as an empty dict. Each
        parameter section is shallow-copied, so mutating the resulting
        config does not modify ``data``.

        Parameters
        ----------
        data : dict or None
            Configuration dictionary. ``None`` (e.g. the result of loading an
            empty YAML file) yields the default configuration.

        Returns
        -------
        config : ScorerConfig
            Configuration instance.
        """
        if data is None:
            data = {}

        def _section(key: str) -> Dict[str, Any]:
            # Copy to avoid aliasing the caller's dict; None -> empty dict so
            # build() can safely unpack it with ** later.
            return dict(data.get(key) or {})

        return cls(
            scorer=data.get('scorer', 'mlp'),
            reference=data.get('reference', 'swissprot'),
            k=data.get('k', 8),
            scorer_params=_section('scorer_params'),
            reference_params=_section('reference_params'),
            training_params=_section('training_params'),
        )

    @classmethod
    def from_yaml(cls, path: str) -> 'ScorerConfig':
        """Load config from YAML file.

        Parameters
        ----------
        path : str
            Path to YAML configuration file (read as UTF-8).

        Returns
        -------
        config : ScorerConfig
            Configuration instance. An empty file yields the defaults.

        Raises
        ------
        ImportError
            If PyYAML is not installed.
        """
        try:
            import yaml
        except ImportError as err:
            raise ImportError("PyYAML is required for YAML config files") from err

        with open(path, 'r', encoding='utf-8') as f:
            # safe_load: never constructs arbitrary Python objects from the file.
            data = yaml.safe_load(f)
        return cls.from_dict(data)

    @classmethod
    def from_json(cls, path: str) -> 'ScorerConfig':
        """Load config from JSON file.

        Parameters
        ----------
        path : str
            Path to JSON configuration file (read as UTF-8).

        Returns
        -------
        config : ScorerConfig
            Configuration instance.
        """
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.from_dict(data)

    def to_json(self, path: str, indent: int = 2) -> None:
        """Save config to JSON file.

        Note that tuples (e.g. ``hidden_layer_sizes``) are written as JSON
        arrays and therefore come back as lists from ``from_json``.

        Parameters
        ----------
        path : str
            Path to save JSON configuration (written as UTF-8).
        indent : int, default=2
            JSON indentation level.
        """
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=indent)

    @classmethod
    def from_preset(cls, name: str) -> 'ScorerConfig':
        """Create config from a named preset.

        Parameters
        ----------
        name : str
            Preset name (e.g., 'default', 'fast').

        Returns
        -------
        config : ScorerConfig
            Configuration instance (a copy of the preset).

        Raises
        ------
        KeyError
            If preset name is not found.
        """
        return get_preset(name)

    def build(
        self,
        auto_load: bool = True,
        train_data: Optional[List[str]] = None,
        train_labels: Optional[Any] = None,
        target_categories: Optional[List[str]] = None,
        **train_overrides: Any
    ):
        """Build scorer from this configuration.

        Parameters
        ----------
        auto_load : bool, default=True
            If True, call ``load()`` on the reference before fitting. Only
            used for non-trainable scorers; trainable scorers are returned
            before any reference is created.
        train_data : list of str, optional
            Training peptides for trainable scorers.
        train_labels : array-like, optional
            Training labels for trainable scorers. Training happens only when
            both ``train_data`` and ``train_labels`` are given and the scorer
            is a ``TrainableScorer``; otherwise both are ignored.
        target_categories : list of str, optional
            Category names for multi-label training.
        **train_overrides : dict
            Keyword arguments merged over ``training_params`` (overrides win)
            and forwarded to ``scorer.train``.

        Returns
        -------
        scorer : BaseScorer
            Configured scorer. Trainable scorers are returned untrained
            unless train_data and train_labels are provided; non-trainable
            scorers are fitted on the configured reference.
        """
        from .registry import create_scorer, create_reference
        from .trainable import TrainableScorer

        # Create the scorer; explicit scorer_params may override k.
        scorer_params = {'k': self.k, **self.scorer_params}
        scorer = create_scorer(self.scorer, **scorer_params)
        if isinstance(scorer, TrainableScorer):
            if train_data is not None and train_labels is not None:
                train_kwargs = {**self.training_params, **train_overrides}
                scorer.train(
                    peptides=train_data,
                    labels=train_labels,
                    target_categories=target_categories,
                    **train_kwargs,
                )
            return scorer
        # Non-trainable scorers require a reference to fit on.
        ref_params = {'k': self.k, **self.reference_params}
        reference = create_reference(self.reference, **ref_params)
        if auto_load:
            reference.load()

        scorer.fit(reference)
        return scorer
209
+
210
+
211
# Preset configurations, keyed by the name accepted by get_preset().
# 'fast' differs from 'default' by a smaller network (128, 64), no dipeptide
# features, and fewer training epochs (50 vs 200).
PRESETS: Dict[str, ScorerConfig] = dict(
    default=ScorerConfig(
        scorer='mlp',
        reference='swissprot',
        k=8,
        scorer_params=dict(
            hidden_layer_sizes=(256, 128, 64),
            activation='relu',
            alpha=0.0001,
            early_stopping=True,
            use_dipeptides=True,
        ),
        reference_params=dict(
            # None: presumably "no category filter" (all SwissProt
            # categories) — confirm against the reference implementation.
            categories=None,
        ),
        training_params=dict(
            epochs=200,
            learning_rate=0.001,
            verbose=True,
        ),
    ),
    fast=ScorerConfig(
        scorer='mlp',
        reference='swissprot',
        k=8,
        scorer_params=dict(
            hidden_layer_sizes=(128, 64),
            activation='relu',
            alpha=0.0001,
            early_stopping=True,
            use_dipeptides=False,
        ),
        reference_params=dict(
            categories=None,
        ),
        training_params=dict(
            epochs=50,
            learning_rate=0.001,
            verbose=True,
        ),
    ),
)
254
+
255
+
256
def get_preset(name: str) -> ScorerConfig:
    """Get a preset configuration by name.

    Parameters
    ----------
    name : str
        Preset name (one of ``list_presets()``).

    Returns
    -------
    config : ScorerConfig
        An independent deep copy of the preset. Mutating it, including
        nested values inside its parameter dicts, never affects ``PRESETS``.

    Raises
    ------
    KeyError
        If preset name is not found.
    """
    from copy import deepcopy

    try:
        preset = PRESETS[name]
    except KeyError:
        raise KeyError(
            f"Unknown preset '{name}'. Available: {list_presets()}"
        ) from None
    # Deep copy: a shallow dict copy would still share nested mutable values
    # (e.g. a list inside scorer_params) with the module-level preset.
    return deepcopy(preset)
289
+
290
+
291
def list_presets() -> List[str]:
    """Return the names of every registered preset.

    Returns
    -------
    names : list of str
        Preset names, sorted alphabetically.
    """
    names = list(PRESETS)
    names.sort()
    return names