classifyre-cli 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. classifyre_cli-0.4.2.dist-info/METADATA +167 -0
  2. classifyre_cli-0.4.2.dist-info/RECORD +101 -0
  3. classifyre_cli-0.4.2.dist-info/WHEEL +4 -0
  4. classifyre_cli-0.4.2.dist-info/entry_points.txt +2 -0
  5. src/__init__.py +1 -0
  6. src/detectors/__init__.py +105 -0
  7. src/detectors/base.py +97 -0
  8. src/detectors/broken_links/__init__.py +3 -0
  9. src/detectors/broken_links/detector.py +280 -0
  10. src/detectors/config.py +59 -0
  11. src/detectors/content/__init__.py +0 -0
  12. src/detectors/custom/__init__.py +13 -0
  13. src/detectors/custom/detector.py +45 -0
  14. src/detectors/custom/runners/__init__.py +56 -0
  15. src/detectors/custom/runners/_base.py +177 -0
  16. src/detectors/custom/runners/_factory.py +51 -0
  17. src/detectors/custom/runners/_feature_extraction.py +138 -0
  18. src/detectors/custom/runners/_gliner2.py +324 -0
  19. src/detectors/custom/runners/_image_classification.py +98 -0
  20. src/detectors/custom/runners/_llm.py +22 -0
  21. src/detectors/custom/runners/_object_detection.py +107 -0
  22. src/detectors/custom/runners/_regex.py +147 -0
  23. src/detectors/custom/runners/_text_classification.py +109 -0
  24. src/detectors/custom/trainer.py +293 -0
  25. src/detectors/dependencies.py +109 -0
  26. src/detectors/pii/__init__.py +0 -0
  27. src/detectors/pii/detector.py +883 -0
  28. src/detectors/secrets/__init__.py +0 -0
  29. src/detectors/secrets/detector.py +399 -0
  30. src/detectors/threat/__init__.py +0 -0
  31. src/detectors/threat/code_security_detector.py +206 -0
  32. src/detectors/threat/yara_detector.py +177 -0
  33. src/main.py +608 -0
  34. src/models/generated_detectors.py +1296 -0
  35. src/models/generated_input.py +2732 -0
  36. src/models/generated_single_asset_scan_results.py +240 -0
  37. src/outputs/__init__.py +3 -0
  38. src/outputs/base.py +69 -0
  39. src/outputs/console.py +62 -0
  40. src/outputs/factory.py +156 -0
  41. src/outputs/file.py +83 -0
  42. src/outputs/rest.py +258 -0
  43. src/pipeline/__init__.py +7 -0
  44. src/pipeline/content_provider.py +26 -0
  45. src/pipeline/detector_pipeline.py +742 -0
  46. src/pipeline/parsed_content_provider.py +59 -0
  47. src/sandbox/__init__.py +5 -0
  48. src/sandbox/runner.py +145 -0
  49. src/sources/__init__.py +95 -0
  50. src/sources/atlassian_common.py +389 -0
  51. src/sources/azure_blob_storage/__init__.py +3 -0
  52. src/sources/azure_blob_storage/source.py +130 -0
  53. src/sources/base.py +296 -0
  54. src/sources/confluence/__init__.py +3 -0
  55. src/sources/confluence/source.py +733 -0
  56. src/sources/databricks/__init__.py +3 -0
  57. src/sources/databricks/source.py +1279 -0
  58. src/sources/dependencies.py +81 -0
  59. src/sources/google_cloud_storage/__init__.py +3 -0
  60. src/sources/google_cloud_storage/source.py +114 -0
  61. src/sources/hive/__init__.py +3 -0
  62. src/sources/hive/source.py +709 -0
  63. src/sources/jira/__init__.py +3 -0
  64. src/sources/jira/source.py +605 -0
  65. src/sources/mongodb/__init__.py +3 -0
  66. src/sources/mongodb/source.py +550 -0
  67. src/sources/mssql/__init__.py +3 -0
  68. src/sources/mssql/source.py +1034 -0
  69. src/sources/mysql/__init__.py +3 -0
  70. src/sources/mysql/source.py +797 -0
  71. src/sources/neo4j/__init__.py +0 -0
  72. src/sources/neo4j/source.py +523 -0
  73. src/sources/object_storage/base.py +679 -0
  74. src/sources/oracle/__init__.py +3 -0
  75. src/sources/oracle/source.py +982 -0
  76. src/sources/postgresql/__init__.py +3 -0
  77. src/sources/postgresql/source.py +774 -0
  78. src/sources/powerbi/__init__.py +3 -0
  79. src/sources/powerbi/source.py +774 -0
  80. src/sources/recipe_normalizer.py +179 -0
  81. src/sources/s3_compatible_storage/README.md +66 -0
  82. src/sources/s3_compatible_storage/__init__.py +3 -0
  83. src/sources/s3_compatible_storage/source.py +150 -0
  84. src/sources/servicedesk/__init__.py +3 -0
  85. src/sources/servicedesk/source.py +620 -0
  86. src/sources/slack/__init__.py +3 -0
  87. src/sources/slack/source.py +534 -0
  88. src/sources/snowflake/__init__.py +3 -0
  89. src/sources/snowflake/source.py +912 -0
  90. src/sources/tableau/__init__.py +3 -0
  91. src/sources/tableau/source.py +799 -0
  92. src/sources/tabular_utils.py +165 -0
  93. src/sources/wordpress/__init__.py +3 -0
  94. src/sources/wordpress/source.py +590 -0
  95. src/telemetry.py +96 -0
  96. src/utils/__init__.py +1 -0
  97. src/utils/content_extraction.py +108 -0
  98. src/utils/file_parser.py +777 -0
  99. src/utils/hashing.py +82 -0
  100. src/utils/uv_sync.py +79 -0
  101. src/utils/validation.py +56 -0
@@ -0,0 +1,147 @@
1
+ """Regex pipeline runner."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import re
7
+ import time
8
+ from datetime import UTC, datetime
9
+ from typing import Any
10
+
11
+ from ....models.generated_detectors import (
12
+ PipelineResult,
13
+ RegexPatternDefinition,
14
+ RegexPipelineSchema,
15
+ )
16
+ from ...dependencies import MissingDependencyError, require_module
17
+ from ._base import BaseRunner
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def _load_regex_engine() -> tuple[Any, bool]:
    """Select the regex backend used by ``RegexRunner``.

    google-re2 is preferred; when ``require_module`` reports it as missing,
    the stdlib ``re`` module is used instead.

    Returns:
        An ``(engine_module, using_re2)`` pair.
    """
    try:
        engine = require_module("re2", "regex", ["regex"])
    except MissingDependencyError:
        logger.info(
            "google-re2 not available, using stdlib re (install with: uv sync --group regex)"
        )
        return re, False
    logger.info("Using google-re2 engine for regex patterns")
    return engine, True
33
+
34
+
35
class RegexRunner(BaseRunner):
    """Regex pipeline — uses google-re2 when available, falls back to stdlib re.

    Every named pattern in the schema is compiled once in ``__init__``;
    patterns that fail to compile are logged and skipped. ``run`` reports each
    match of each compiled pattern as an entity span keyed by pattern name.
    """

    def __init__(
        self, schema: RegexPipelineSchema, detector_key: str = "", detector_name: str = ""
    ) -> None:
        self._schema = schema
        self._detector_key = detector_key
        self._detector_name = detector_name
        self._engine, self._using_re2 = _load_regex_engine()
        # Values hold either stdlib ``re.Pattern`` objects or google-re2 compiled
        # regexps; ``run`` relies on both exposing ``finditer`` and match objects
        # with ``group``/``start``/``end``/``groups``/``lastindex``.
        self._compiled: dict[str, tuple[Any, RegexPatternDefinition]] = {}
        self._compile_patterns()

    def _compile_patterns(self) -> None:
        """Compile every schema pattern, logging and skipping invalid ones."""
        patterns = self._schema.patterns or {}
        for name, defn in patterns.items():
            try:
                compiled = self._compile_one(defn)
            except Exception as exc:
                logger.warning(
                    "Invalid regex pattern '%s' in detector '%s': %s",
                    name,
                    self._detector_key,
                    exc,
                )
                continue
            self._compiled[name] = (compiled, defn)

    def _compile_one(self, defn: RegexPatternDefinition) -> Any:
        """Compile one pattern definition with the active engine.

        ``case_sensitive`` defaults to True. Legacy integer ``flags`` may turn on
        case-insensitivity (``re.IGNORECASE``) or dot-matches-newline
        (``re.DOTALL``). RE2-only options (``longest_match``, ``max_mem``) are
        logged at debug level and ignored by the stdlib fallback.
        """
        case_sensitive = defn.case_sensitive if defn.case_sensitive is not None else True
        dot_nl = defn.dot_nl or False
        literal = defn.literal or False
        longest_match = defn.longest_match or False
        max_mem = defn.max_mem

        # Legacy stdlib-style bit flags. Only integer masks are meaningful;
        # anything else is ignored (with a warning) on BOTH engines so a pattern
        # does not compile under RE2 yet get dropped under stdlib re.
        raw_flags = defn.flags or 0
        if isinstance(raw_flags, int):
            legacy_flags = raw_flags
        else:
            logger.warning(
                "Ignoring non-integer regex flags %r for pattern %r", raw_flags, defn.pattern
            )
            legacy_flags = 0
        if legacy_flags & re.IGNORECASE:
            case_sensitive = False
        if legacy_flags & re.DOTALL:
            dot_nl = True

        if self._using_re2:
            options = self._engine.Options()
            options.case_sensitive = case_sensitive
            options.dot_nl = dot_nl
            options.literal = literal
            options.longest_match = longest_match
            if max_mem is not None:
                options.max_mem = max_mem
            return self._engine.compile(defn.pattern, options=options)

        flags = legacy_flags
        if not case_sensitive:
            flags |= re.IGNORECASE
        if dot_nl:
            flags |= re.DOTALL
        if literal:
            return re.compile(re.escape(defn.pattern), flags)
        if longest_match:
            logger.debug("longest_match is a RE2-only feature, ignored with stdlib re")
        if max_mem is not None:
            logger.debug("max_mem is a RE2-only feature, ignored with stdlib re")
        return re.compile(defn.pattern, flags)

    def run(self, text: str) -> PipelineResult:
        """Apply every compiled pattern to *text* and collect match spans.

        The configured capture group (``defn.group``, default 0) supplies each
        span's ``value``/``start``/``end``. If that group does not exist in the
        pattern, a warning is logged once per pattern per call and the whole
        match (group 0) is used instead. Matches where the group exists but did
        not participate (e.g. an optional group) are skipped: they have no text
        and their start/end would be -1.
        """
        started = time.monotonic()  # seconds; converted to ms below
        entities: dict[str, list[dict[str, object]]] = {}

        for name, (rx, defn) in self._compiled.items():
            group_idx = defn.group or 0
            # Use the enum's underlying value when severity is an Enum member;
            # str() of a plain Enum member would give "ClassName.member".
            severity = (
                None
                if defn.severity is None
                else str(getattr(defn.severity, "value", defn.severity))
            )
            spans: list[dict[str, object]] = []
            for match in rx.finditer(text):
                try:
                    value = match.group(group_idx)
                except IndexError:
                    logger.warning(
                        "Capture group %d does not exist in pattern '%s', using group 0",
                        group_idx,
                        name,
                    )
                    # Switch to group 0 for this and all remaining matches of the
                    # pattern, so the warning is not repeated per match.
                    group_idx = 0
                    value = match.group(0)

                if value is None:
                    continue

                span: dict[str, object] = {
                    "value": value,
                    "confidence": 1.0,
                    "start": match.start(group_idx),
                    "end": match.end(group_idx),
                }
                if severity is not None:
                    span["severity"] = severity
                if match.lastindex:
                    span["groups"] = match.groups()

                spans.append(span)
            if spans:
                entities[name] = spans

        latency_ms = round((time.monotonic() - started) * 1000)
        engine_tag = "RE2" if self._using_re2 else "stdlib-re"
        return PipelineResult(
            entities=entities,
            classification={},
            metadata={
                "runner": "REGEX",
                "engine": engine_tag,
                "latency_ms": latency_ms,
                "timestamp": datetime.now(UTC).isoformat(),
            },
        )
@@ -0,0 +1,109 @@
1
+ """Text classification pipeline runner."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import Any
7
+
8
+ from ....models.generated_detectors import Severity, TextClassificationPipelineSchema
9
+ from ....models.generated_single_asset_scan_results import DetectionResult
10
+ from ...dependencies import ensure_torch, require_module
11
+ from ._base import _TEXT_CONTENT_TYPES, BaseRunner, _resolve_pipeline_severity
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ def _chunk_text(text: str, chunk_size: int | None, chunk_overlap: int) -> list[str]:
17
+ """Split text into chunks. Returns [text] when chunk_size is not set."""
18
+ if not chunk_size:
19
+ return [text]
20
+ step = max(1, chunk_size - chunk_overlap)
21
+ return [text[i : i + chunk_size] for i in range(0, len(text), step)]
22
+
23
+
24
class TextClassificationRunner(BaseRunner):
    """Text classification via a single HuggingFace text-classification pipeline.

    Text is optionally split into (overlapping) chunks. Each label keeps the
    highest score seen across chunks, and labels scoring at or above the
    confidence threshold become findings.
    """

    def __init__(
        self,
        schema: TextClassificationPipelineSchema,
        detector_key: str = "",
        detector_name: str = "",
    ) -> None:
        self._schema = schema
        self._detector_key = detector_key
        self._detector_name = detector_name
        ensure_torch("text_classification", ["custom", "detectors"])
        transformers = require_module(
            "transformers", "text_classification", ["custom", "detectors"]
        )
        pipeline_kwargs: dict[str, Any] = {
            "model": schema.model,
            "device": schema.device or "cpu",
        }
        if schema.model_revision:
            pipeline_kwargs["revision"] = schema.model_revision
        if schema.top_k is not None:
            pipeline_kwargs["top_k"] = schema.top_k
        if schema.function_to_apply is not None:
            # Pass the member's underlying value when given an Enum: str() of a
            # plain Enum member is "ClassName.member", not the bare option name.
            fn_to_apply = schema.function_to_apply
            pipeline_kwargs["function_to_apply"] = str(getattr(fn_to_apply, "value", fn_to_apply))
        self._pipe: Any = transformers.pipeline("text-classification", **pipeline_kwargs)

    def run(self, text: str) -> None:  # type: ignore[override]  # pragma: no cover
        raise NotImplementedError("TextClassificationRunner uses detect() directly")

    def detect(self, content: str | bytes, content_type: str) -> list[DetectionResult]:
        """Classify *content* and return one finding per label above threshold.

        Returns an empty list for bytes, for content types outside
        ``_TEXT_CONTENT_TYPES``, or for blank text. Inference errors are logged;
        scores gathered from chunks processed before the error are still used.
        Results are sorted by confidence, highest first.
        """
        if isinstance(content, bytes):
            return []
        if content_type not in _TEXT_CONTENT_TYPES:
            return []
        text = content.strip()
        if not text:
            return []

        schema = self._schema
        # Schema values may be wrapped in a model exposing ``.root``; unwrap them.
        chunk_size: int | None = getattr(schema.chunk_size, "root", schema.chunk_size)
        chunk_overlap: int = getattr(schema.chunk_overlap, "root", schema.chunk_overlap) or 0
        max_length: int | None = getattr(schema.max_length, "root", schema.max_length)
        threshold = schema.confidence_threshold if schema.confidence_threshold is not None else 0.7
        default_severity = schema.severity or Severity.info

        # Tokenizer options are identical for every chunk; build them once.
        call_kwargs: dict[str, Any] = {"truncation": True}
        if max_length is not None:
            call_kwargs["max_length"] = max_length

        best_scores: dict[str, float] = {}
        try:
            for chunk in _chunk_text(text, chunk_size, chunk_overlap):
                raw = self._pipe(chunk, **call_kwargs) or []
                # Accept either a flat list of prediction dicts or a nested list
                # whose first element holds the predictions for this input.
                preds: list[dict[str, Any]] = raw[0] if raw and isinstance(raw[0], list) else raw
                for pred in preds:
                    label: str = pred.get("label", "unknown")
                    score: float = float(pred.get("score", 0.0))
                    if score > best_scores.get(label, 0.0):
                        best_scores[label] = score
        except Exception as exc:
            logger.error(
                "text_classification error (model=%s): %s", schema.model, exc, exc_info=True
            )

        results: list[DetectionResult] = []
        for label, score in best_scores.items():
            if score < threshold:
                continue
            severity = _resolve_pipeline_severity(label, schema.severity_map, default_severity)
            results.append(
                self._make_result(
                    finding_type=f"classification:{label}",
                    category="CONTENT",
                    severity=severity,
                    confidence=score,
                    matched_content=text[:512],
                    location=None,
                    metadata={"model": schema.model, "predicted_label": label, "score": score},
                )
            )
        results.sort(key=lambda r: r.confidence, reverse=True)
        return results

    def get_supported_content_types(self) -> list[str]:
        """Return the text content types this runner accepts."""
        return list(_TEXT_CONTENT_TYPES)
@@ -0,0 +1,293 @@
1
+ """GLiNER2 pipeline fine-tuning trainer.
2
+
3
+ Handles two training modes in a single pass:
4
+ - NER fine-tuning via the base GLiNER model (entity examples with span values)
5
+ - Zero-shot classification fine-tuning via SetFit (text + label pairs)
6
+
7
+ Artifacts are written to an output directory structured as:
8
+ <output_dir>/
9
+ gliner2/ -- fine-tuned GLiNER2 model weights (HuggingFace format)
10
+ setfit/<task>/ -- one SetFit model per classification task
11
+ manifest.json -- training metadata for the runner
12
+
13
+ The trainer produces a JSON result dict to stdout, which the API reads to update
14
+ the training run record.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import json
20
+ import logging
21
+ import time
22
+ from dataclasses import dataclass, field
23
+ from pathlib import Path
24
+ from typing import Any
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ # Minimum examples needed before we attempt fine-tuning (not just annotation storage)
29
+ _MIN_NER_EXAMPLES = 5
30
+ _MIN_SETFIT_PER_CLASS = 2
31
+
32
+
33
+ @dataclass
34
+ class TrainingExample:
35
+ label: str
36
+ text: str
37
+ value: str | None = None # specific entity span text (NER only)
38
+ accepted: bool = True
39
+ source: str | None = None
40
+
41
+
42
@dataclass
class TrainingResult:
    """Outcome of a training run, serialisable through ``to_dict``."""

    status: str
    trained_examples: int
    positive_examples: int
    negative_examples: int
    model_artifact_path: str
    metrics: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return every field as a plain dict; ``metrics`` is shared, not copied."""
        return dict(
            status=self.status,
            trained_examples=self.trained_examples,
            positive_examples=self.positive_examples,
            negative_examples=self.negative_examples,
            model_artifact_path=self.model_artifact_path,
            metrics=self.metrics,
        )
60
+
61
+
62
class GLiNER2Trainer:
    """Orchestrates NER + classification fine-tuning for a GLiNER2 pipeline.

    Raw example dicts without a ``label`` or ``text`` are dropped. A training
    stage that cannot run (too few examples, missing optional packages, or a
    training error) is reported in the metrics as
    ``{"skipped": True, "reason": ...}`` instead of raising.
    """

    # Passes over the NER data; used by both the Trainer API and the manual loop.
    _NER_EPOCHS = 3

    def __init__(
        self,
        pipeline_schema: dict[str, Any],
        examples_raw: list[dict[str, Any]],
        output_dir: Path,
    ) -> None:
        """Store the pipeline schema and parse raw examples.

        Args:
            pipeline_schema: pipeline definition; ``train`` reads its
                ``entities``, ``classification``, ``model.name`` and ``type`` keys.
            examples_raw: dicts with ``label``/``text`` and optional
                ``value``/``accepted``/``source`` keys.
            output_dir: directory that receives all training artifacts.
        """
        self._schema = pipeline_schema
        self._output_dir = output_dir
        self._examples: list[TrainingExample] = [
            TrainingExample(
                label=str(ex.get("label", "")),
                text=str(ex.get("text", "")),
                value=ex.get("value") or None,
                accepted=bool(ex.get("accepted", True)),
                source=ex.get("source") or None,
            )
            for ex in examples_raw
            if ex.get("label") and ex.get("text")
        ]

    def train(self) -> TrainingResult:
        """Run NER and classification fine-tuning, then write ``manifest.json``.

        Returns:
            A ``SUCCEEDED`` ``TrainingResult``. Per-stage outcomes, including
            skipped stages, are reported in its ``metrics``.
        """
        t0 = time.monotonic()
        self._output_dir.mkdir(parents=True, exist_ok=True)

        positive = [e for e in self._examples if e.accepted]
        negative = [e for e in self._examples if not e.accepted]
        metrics: dict[str, Any] = {}

        entities: dict[str, Any] = self._schema.get("entities") or {}
        classification: dict[str, Any] = self._schema.get("classification") or {}
        base_model: str = (self._schema.get("model") or {}).get("name") or "fastino/gliner2-base-v1"

        # ── NER fine-tuning ────────────────────────────────────────────────────
        entity_labels = set(entities.keys())
        ner_examples = [e for e in positive if e.label in entity_labels and e.value]
        if ner_examples:
            metrics["ner"] = self._train_ner(ner_examples, base_model)
        else:
            metrics["ner"] = {"skipped": True, "reason": "No span-annotated NER examples"}

        # ── Classification fine-tuning (SetFit) ────────────────────────────────
        if classification:
            metrics["classification"] = self._train_classification(positive, classification)
        else:
            metrics["classification"] = {
                "skipped": True,
                "reason": "No classification tasks defined",
            }

        # Record the duration before writing the manifest so it appears there too.
        metrics["duration_s"] = round(time.monotonic() - t0, 2)

        # Write manifest so the runner knows what's available
        manifest: dict[str, Any] = {
            "schema_type": self._schema.get("type", "GLINER2"),
            "base_model": base_model,
            "trained_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "metrics": metrics,
        }
        (self._output_dir / "manifest.json").write_text(json.dumps(manifest, indent=2))

        return TrainingResult(
            status="SUCCEEDED",
            trained_examples=len(positive),
            positive_examples=len(positive),
            negative_examples=len(negative),
            model_artifact_path=str(self._output_dir),
            metrics=metrics,
        )

    # ── NER ────────────────────────────────────────────────────────────────────

    def _train_ner(self, examples: list[TrainingExample], base_model: str) -> dict[str, Any]:
        """Fine-tune a GLiNER model on span-annotated examples.

        Each span is located by the first occurrence of ``value`` in ``text``;
        examples whose value cannot be found are dropped. The model is saved
        under ``<output_dir>/gliner2``.

        Returns:
            Training metrics, or a ``skipped`` dict with a reason.
        """
        if len(examples) < _MIN_NER_EXAMPLES:
            return {
                "skipped": True,
                "reason": f"Need ≥{_MIN_NER_EXAMPLES} span-annotated examples (got {len(examples)})",
            }

        # Build GLiNER-format span annotations (character offsets)
        train_data: list[dict[str, Any]] = []
        for ex in examples:
            if not ex.value:
                continue
            start = ex.text.find(ex.value)
            if start < 0:
                continue
            train_data.append(
                {
                    "text": ex.text,
                    "ner": [{"start": start, "end": start + len(ex.value), "label": ex.label}],
                }
            )

        if len(train_data) < _MIN_NER_EXAMPLES:
            return {
                "skipped": True,
                "reason": f"Too few locatable spans after search (got {len(train_data)})",
            }

        gliner_out = self._output_dir / "gliner2"
        gliner_out.mkdir(parents=True, exist_ok=True)

        # Only the import itself maps to "not installed"; ImportErrors raised
        # later (model loading, training) are reported as training failures.
        try:
            from gliner import GLiNER  # type: ignore[import-untyped]
        except ImportError as e:
            return {"skipped": True, "reason": f"gliner not installed: {e}"}

        try:
            model = GLiNER.from_pretrained(base_model)
            trainer_api = self._load_gliner_trainer_api()
            if trainer_api is None:
                logger.warning("gliner.training not available — using manual fine-tuning loop")
                self._manual_ner_train(model, train_data)
            else:
                trainer_cls, args_cls = trainer_api
                args = args_cls(
                    output_dir=str(gliner_out),
                    num_train_epochs=self._NER_EPOCHS,
                    per_device_train_batch_size=min(4, len(train_data)),
                    warmup_ratio=0.1,
                    # No intermediate checkpoints; the final model is saved below.
                    save_strategy="no",
                )
                trainer_cls(model=model, args=args, train_dataset=train_data).train()
            model.save_pretrained(str(gliner_out))
        except Exception as e:
            logger.warning("NER fine-tuning failed: %s", e, exc_info=True)
            return {"skipped": True, "reason": str(e)}

        logger.info("NER model saved to %s", gliner_out)
        return {"examples": len(train_data), "epochs": self._NER_EPOCHS, "saved_to": "gliner2/"}

    @staticmethod
    def _load_gliner_trainer_api() -> tuple[Any, Any] | None:
        """Return ``(Trainer, TrainingArguments)`` from ``gliner.training``, or None if absent."""
        try:
            from gliner.training import (  # type: ignore[import-untyped]
                Trainer as GLiNERTrainer,
                TrainingArguments as GLiNERArgs,
            )
        except ImportError:
            return None
        return GLiNERTrainer, GLiNERArgs

    def _manual_ner_train(self, model: Any, train_data: list[dict[str, Any]]) -> None:
        """Minimal AdamW training loop used when the Trainer API is unavailable.

        NOTE(review): assumes the model exposes ``compute_loss(item)`` returning
        a loss tensor or None — TODO confirm against the installed gliner version.
        """
        import torch  # type: ignore[import-untyped]

        optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
        model.train()
        for _epoch in range(self._NER_EPOCHS):
            for item in train_data:
                optimizer.zero_grad()
                loss = model.compute_loss(item)
                if loss is not None:
                    loss.backward()
                    optimizer.step()

    # ── Classification (SetFit) ────────────────────────────────────────────────

    def _train_classification(
        self,
        positive: list[TrainingExample],
        classification: dict[str, Any],
    ) -> dict[str, Any]:
        """Train one SetFit model per classification task.

        A task is skipped unless it defines labels and every one of its labels
        has at least ``_MIN_SETFIT_PER_CLASS`` accepted examples.
        """
        from collections import Counter

        task_results: dict[str, Any] = {}
        for task_name, task_defn in classification.items():
            labels: list[str] = task_defn.get("labels") or []
            task_examples = [e for e in positive if e.label in labels]
            per_label = Counter(e.label for e in task_examples)
            short = [lbl for lbl in labels if per_label[lbl] < _MIN_SETFIT_PER_CLASS]
            if not labels or short:
                task_results[task_name] = {
                    "skipped": True,
                    "reason": (
                        f"Need ≥{_MIN_SETFIT_PER_CLASS} examples per label "
                        f"(got {len(task_examples)} across {len(labels)} labels; "
                        f"too few for: {', '.join(short) or 'no labels defined'})"
                    ),
                }
                continue
            try:
                task_results[task_name] = self._train_setfit_task(task_name, task_examples, labels)
            except Exception as e:
                logger.warning(
                    "SetFit training failed for task '%s': %s", task_name, e, exc_info=True
                )
                task_results[task_name] = {"skipped": True, "reason": str(e)}
        return task_results

    def _train_setfit_task(
        self,
        task_name: str,
        examples: list[TrainingExample],
        labels: list[str],
    ) -> dict[str, Any]:
        """Train and save a SetFit model for one task under ``setfit/<task_name>``.

        Also writes ``labels.json`` with the label order used for the integer
        training targets.
        """
        try:
            from datasets import Dataset  # type: ignore[import-untyped]
            from setfit import (  # type: ignore[import-untyped]
                SetFitModel,
                Trainer,
                TrainingArguments,
            )
        except ImportError as e:
            return {"skipped": True, "reason": f"setfit/datasets not installed: {e}"}

        label2id = {label: i for i, label in enumerate(labels)}
        dataset = Dataset.from_dict(
            {
                "text": [ex.text for ex in examples],
                "label": [label2id.get(ex.label, 0) for ex in examples],
            }
        )

        model = SetFitModel.from_pretrained(
            "sentence-transformers/paraphrase-MiniLM-L6-v2",
            labels=labels,
        )

        save_path = self._output_dir / "setfit" / task_name
        save_path.mkdir(parents=True, exist_ok=True)

        args = TrainingArguments(
            output_dir=str(save_path),
            num_epochs=1,
            batch_size=min(8, len(examples)),
        )
        trainer = Trainer(model=model, args=args, train_dataset=dataset)
        trainer.train()
        model.save_pretrained(str(save_path))

        # Persist label ordering so the runner can decode predictions
        (save_path / "labels.json").write_text(json.dumps(labels))

        logger.info("SetFit model for task '%s' saved to %s", task_name, save_path)
        return {
            "examples": len(examples),
            "labels": labels,
            "saved_to": f"setfit/{task_name}/",
        }
@@ -0,0 +1,109 @@
1
+ """Helpers for optional detector dependencies."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import importlib
6
+ import logging
7
+ from types import ModuleType
8
+
9
+ from src.utils.uv_sync import auto_install_enabled, sync_group
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class MissingDependencyError(RuntimeError):
15
+ """Raised when an optional detector dependency is unavailable or invalid."""
16
+
17
+ def __init__(
18
+ self,
19
+ detector_name: str,
20
+ dependencies: list[str],
21
+ uv_groups: list[str],
22
+ detail: str | None = None,
23
+ ) -> None:
24
+ self.detector_name = detector_name
25
+ self.dependencies = dependencies
26
+ self.uv_groups = uv_groups
27
+ self.detail = detail
28
+
29
+ deps = ", ".join(dependencies)
30
+ group_hint = " or ".join(f"`uv sync --group {group}`" for group in uv_groups)
31
+ message = (
32
+ f"{detector_name} detector requires optional dependencies ({deps}). "
33
+ f"Install with {group_hint}."
34
+ )
35
+ if detail:
36
+ message = f"{message} {detail}"
37
+
38
+ super().__init__(message)
39
+
40
+
41
+ def _ordered_groups(groups: list[str]) -> list[str]:
42
+ unique = list(dict.fromkeys(groups))
43
+ return sorted(unique, key=lambda group: (group == "detectors", group))
44
+
45
+
46
+ def require_module(
47
+ module_name: str,
48
+ detector_name: str,
49
+ uv_groups: list[str],
50
+ detail: str | None = None,
51
+ ) -> ModuleType:
52
+ """Import a module or raise a MissingDependencyError with uv guidance."""
53
+ try:
54
+ return importlib.import_module(module_name)
55
+ except Exception as exc: # pragma: no cover - exercised indirectly in integration setups
56
+ detail_messages: list[str] = [f"Original error: {exc}"]
57
+
58
+ if auto_install_enabled() and uv_groups:
59
+ for group in _ordered_groups(uv_groups):
60
+ success, install_detail = sync_group(group)
61
+ if install_detail:
62
+ detail_messages.append(install_detail)
63
+ if not success:
64
+ continue
65
+
66
+ try:
67
+ importlib.invalidate_caches()
68
+ return importlib.import_module(module_name)
69
+ except Exception as retry_exc: # pragma: no cover
70
+ detail_messages.append(
71
+ f"Module '{module_name}' still unavailable after installing '{group}': {retry_exc}"
72
+ )
73
+
74
+ base_detail = detail or "Optional dependency import failed"
75
+ error_detail = (
76
+ f"{base_detail}. {'; '.join(detail_messages)}" if detail_messages else base_detail
77
+ )
78
+ raise MissingDependencyError(
79
+ detector_name=detector_name,
80
+ dependencies=[module_name.split(".", maxsplit=1)[0]],
81
+ uv_groups=uv_groups,
82
+ detail=error_detail,
83
+ ) from exc
84
+
85
+
86
def ensure_torch(detector_name: str, uv_groups: list[str]) -> ModuleType:
    """Import PyTorch and sanity-check that it is a genuine, complete install.

    Raises:
        MissingDependencyError: if ``torch`` cannot be imported, or if the
            imported module lacks ``no_grad`` (likely shadowed) or ``_utils``
            (likely a broken install). The checks run in that order.
    """
    torch_mod = require_module("torch", detector_name, uv_groups)
    sanity_checks = (
        (
            "no_grad",
            "Detected a module named 'torch' but it is missing `no_grad`. "
            "Ensure PyTorch is installed via uv and no local `torch.py` shadows it.",
        ),
        (
            "_utils",
            "Detected an incomplete/broken PyTorch install (`torch._utils` missing). "
            "Reinstall torch via uv for this environment.",
        ),
    )
    for attribute, problem in sanity_checks:
        if not hasattr(torch_mod, attribute):
            raise MissingDependencyError(
                detector_name=detector_name,
                dependencies=["torch"],
                uv_groups=uv_groups,
                detail=problem,
            )
    return torch_mod
File without changes