classifyre-cli 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- classifyre_cli-0.4.2.dist-info/METADATA +167 -0
- classifyre_cli-0.4.2.dist-info/RECORD +101 -0
- classifyre_cli-0.4.2.dist-info/WHEEL +4 -0
- classifyre_cli-0.4.2.dist-info/entry_points.txt +2 -0
- src/__init__.py +1 -0
- src/detectors/__init__.py +105 -0
- src/detectors/base.py +97 -0
- src/detectors/broken_links/__init__.py +3 -0
- src/detectors/broken_links/detector.py +280 -0
- src/detectors/config.py +59 -0
- src/detectors/content/__init__.py +0 -0
- src/detectors/custom/__init__.py +13 -0
- src/detectors/custom/detector.py +45 -0
- src/detectors/custom/runners/__init__.py +56 -0
- src/detectors/custom/runners/_base.py +177 -0
- src/detectors/custom/runners/_factory.py +51 -0
- src/detectors/custom/runners/_feature_extraction.py +138 -0
- src/detectors/custom/runners/_gliner2.py +324 -0
- src/detectors/custom/runners/_image_classification.py +98 -0
- src/detectors/custom/runners/_llm.py +22 -0
- src/detectors/custom/runners/_object_detection.py +107 -0
- src/detectors/custom/runners/_regex.py +147 -0
- src/detectors/custom/runners/_text_classification.py +109 -0
- src/detectors/custom/trainer.py +293 -0
- src/detectors/dependencies.py +109 -0
- src/detectors/pii/__init__.py +0 -0
- src/detectors/pii/detector.py +883 -0
- src/detectors/secrets/__init__.py +0 -0
- src/detectors/secrets/detector.py +399 -0
- src/detectors/threat/__init__.py +0 -0
- src/detectors/threat/code_security_detector.py +206 -0
- src/detectors/threat/yara_detector.py +177 -0
- src/main.py +608 -0
- src/models/generated_detectors.py +1296 -0
- src/models/generated_input.py +2732 -0
- src/models/generated_single_asset_scan_results.py +240 -0
- src/outputs/__init__.py +3 -0
- src/outputs/base.py +69 -0
- src/outputs/console.py +62 -0
- src/outputs/factory.py +156 -0
- src/outputs/file.py +83 -0
- src/outputs/rest.py +258 -0
- src/pipeline/__init__.py +7 -0
- src/pipeline/content_provider.py +26 -0
- src/pipeline/detector_pipeline.py +742 -0
- src/pipeline/parsed_content_provider.py +59 -0
- src/sandbox/__init__.py +5 -0
- src/sandbox/runner.py +145 -0
- src/sources/__init__.py +95 -0
- src/sources/atlassian_common.py +389 -0
- src/sources/azure_blob_storage/__init__.py +3 -0
- src/sources/azure_blob_storage/source.py +130 -0
- src/sources/base.py +296 -0
- src/sources/confluence/__init__.py +3 -0
- src/sources/confluence/source.py +733 -0
- src/sources/databricks/__init__.py +3 -0
- src/sources/databricks/source.py +1279 -0
- src/sources/dependencies.py +81 -0
- src/sources/google_cloud_storage/__init__.py +3 -0
- src/sources/google_cloud_storage/source.py +114 -0
- src/sources/hive/__init__.py +3 -0
- src/sources/hive/source.py +709 -0
- src/sources/jira/__init__.py +3 -0
- src/sources/jira/source.py +605 -0
- src/sources/mongodb/__init__.py +3 -0
- src/sources/mongodb/source.py +550 -0
- src/sources/mssql/__init__.py +3 -0
- src/sources/mssql/source.py +1034 -0
- src/sources/mysql/__init__.py +3 -0
- src/sources/mysql/source.py +797 -0
- src/sources/neo4j/__init__.py +0 -0
- src/sources/neo4j/source.py +523 -0
- src/sources/object_storage/base.py +679 -0
- src/sources/oracle/__init__.py +3 -0
- src/sources/oracle/source.py +982 -0
- src/sources/postgresql/__init__.py +3 -0
- src/sources/postgresql/source.py +774 -0
- src/sources/powerbi/__init__.py +3 -0
- src/sources/powerbi/source.py +774 -0
- src/sources/recipe_normalizer.py +179 -0
- src/sources/s3_compatible_storage/README.md +66 -0
- src/sources/s3_compatible_storage/__init__.py +3 -0
- src/sources/s3_compatible_storage/source.py +150 -0
- src/sources/servicedesk/__init__.py +3 -0
- src/sources/servicedesk/source.py +620 -0
- src/sources/slack/__init__.py +3 -0
- src/sources/slack/source.py +534 -0
- src/sources/snowflake/__init__.py +3 -0
- src/sources/snowflake/source.py +912 -0
- src/sources/tableau/__init__.py +3 -0
- src/sources/tableau/source.py +799 -0
- src/sources/tabular_utils.py +165 -0
- src/sources/wordpress/__init__.py +3 -0
- src/sources/wordpress/source.py +590 -0
- src/telemetry.py +96 -0
- src/utils/__init__.py +1 -0
- src/utils/content_extraction.py +108 -0
- src/utils/file_parser.py +777 -0
- src/utils/hashing.py +82 -0
- src/utils/uv_sync.py +79 -0
- src/utils/validation.py +56 -0
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""Regex pipeline runner."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import re
|
|
7
|
+
import time
|
|
8
|
+
from datetime import UTC, datetime
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from ....models.generated_detectors import (
|
|
12
|
+
PipelineResult,
|
|
13
|
+
RegexPatternDefinition,
|
|
14
|
+
RegexPipelineSchema,
|
|
15
|
+
)
|
|
16
|
+
from ...dependencies import MissingDependencyError, require_module
|
|
17
|
+
from ._base import BaseRunner
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _load_regex_engine() -> tuple[Any, bool]:
    """Select the regex engine module used to compile patterns.

    Returns:
        A ``(module, using_re2)`` pair: the module returned by
        ``require_module("re2", ...)`` and ``True`` when that import works,
        otherwise the stdlib ``re`` module and ``False``.
    """
    try:
        engine = require_module("re2", "regex", ["regex"])
    except MissingDependencyError:
        # RE2 is optional; the stdlib engine is the fallback.
        logger.info(
            "google-re2 not available, using stdlib re (install with: uv sync --group regex)"
        )
        return re, False
    else:
        logger.info("Using google-re2 engine for regex patterns")
        return engine, True
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class RegexRunner(BaseRunner):
    """Regex pipeline — uses google-re2 when available, falls back to stdlib re.

    Every pattern in the schema is compiled once in ``__init__``. A pattern
    that fails to compile is logged and left out, so one bad pattern does not
    disable the remaining patterns of the detector.
    """

    def __init__(
        self, schema: RegexPipelineSchema, detector_key: str = "", detector_name: str = ""
    ) -> None:
        """Store the schema, select the regex engine and compile the patterns.

        Args:
            schema: Pipeline schema whose ``patterns`` mapping is compiled.
            detector_key: Detector identifier, used in log messages.
            detector_name: Detector name; stored on the instance.
        """
        self._schema = schema
        self._detector_key = detector_key
        self._detector_name = detector_name
        self._engine, self._using_re2 = _load_regex_engine()
        self._compiled: dict[str, tuple[re.Pattern[str], RegexPatternDefinition]] = {}
        self._compile_patterns()

    def _compile_patterns(self) -> None:
        """Compile all schema patterns into ``self._compiled``.

        Compilation errors are logged as warnings and the offending pattern is
        skipped (deliberate best-effort: the other patterns stay usable).
        """
        patterns = self._schema.patterns or {}
        for name, defn in patterns.items():
            try:
                compiled = self._compile_one(defn)
            except Exception as exc:
                logger.warning(
                    "Invalid regex pattern '%s' in detector '%s': %s",
                    name,
                    self._detector_key,
                    exc,
                )
            else:
                self._compiled[name] = (compiled, defn)

    def _compile_one(self, defn: RegexPatternDefinition) -> re.Pattern[str]:
        """Compile a single pattern definition with the active engine.

        ``defn.flags`` is treated as a legacy stdlib ``re`` flag bitmask: its
        IGNORECASE / DOTALL bits override ``case_sensitive`` / ``dot_nl``.
        On the RE2 path only those two legacy bits are honoured; on the stdlib
        path the full bitmask is passed through to ``re.compile``.

        Args:
            defn: Pattern definition to compile.

        Returns:
            The compiled pattern object of the active engine.

        Raises:
            Exception: Whatever the engine raises for an invalid pattern or
                invalid options; the caller logs and skips the pattern.
        """
        case_sensitive = defn.case_sensitive if defn.case_sensitive is not None else True
        dot_nl = defn.dot_nl or False
        literal = defn.literal or False
        longest_match = defn.longest_match or False
        max_mem = defn.max_mem

        legacy_flags = defn.flags or 0
        if isinstance(legacy_flags, int):
            if legacy_flags & re.IGNORECASE:
                case_sensitive = False
            if legacy_flags & re.DOTALL:
                dot_nl = True

        if self._using_re2:
            options = self._engine.Options()
            options.case_sensitive = case_sensitive
            options.dot_nl = dot_nl
            options.literal = literal
            options.longest_match = longest_match
            if max_mem is not None:
                options.max_mem = max_mem
            return self._engine.compile(defn.pattern, options=options)

        flags = legacy_flags
        if not case_sensitive:
            flags |= re.IGNORECASE
        if dot_nl:
            flags |= re.DOTALL
        # Report ignored RE2-only options for literal patterns too (the
        # previous early return for ``literal`` skipped these notes).
        if longest_match:
            logger.debug("longest_match is a RE2-only feature, ignored with stdlib re")
        if max_mem is not None:
            logger.debug("max_mem is a RE2-only feature, ignored with stdlib re")
        source = re.escape(defn.pattern) if literal else defn.pattern
        return re.compile(source, flags)

    def run(self, text: str) -> PipelineResult:
        """Run every compiled pattern over ``text``.

        Args:
            text: Text to scan.

        Returns:
            A ``PipelineResult`` whose ``entities`` maps each pattern name
            that matched to a list of span dicts (``value``, ``confidence``,
            ``start``, ``end`` and, when available, ``severity`` and
            ``groups``). ``metadata`` records the runner, engine, latency in
            milliseconds and a UTC timestamp.
        """
        started = time.monotonic()
        entities: dict[str, list[dict[str, object]]] = {}

        for name, (rx, defn) in self._compiled.items():
            group_idx = defn.group or 0
            spans: list[dict[str, object]] = []
            for match in rx.finditer(text):
                actual_group = group_idx
                try:
                    value = match.group(group_idx)
                except IndexError:
                    logger.warning(
                        "Capture group %d does not exist in pattern '%s', using group 0",
                        group_idx,
                        name,
                    )
                    # The group is missing from the pattern itself, so every
                    # later match would fail the same way: switch to group 0
                    # for the rest of this pattern and warn only once.
                    group_idx = 0
                    actual_group = 0
                    value = match.group(0)
                else:
                    if value is None:
                        # The group exists but did not take part in this match
                        # (e.g. an optional group). start()/end() would be -1,
                        # so report the whole match instead of invalid offsets.
                        actual_group = 0
                        value = match.group(0)

                span: dict[str, object] = {
                    "value": value or "",
                    "confidence": 1.0,
                    "start": match.start(actual_group),
                    "end": match.end(actual_group),
                }
                if defn.severity is not None:
                    span["severity"] = str(defn.severity)
                if match.lastindex:
                    span["groups"] = match.groups()

                spans.append(span)
            if spans:
                entities[name] = spans

        # time.monotonic() is in seconds; convert the elapsed time to ms.
        latency_ms = round((time.monotonic() - started) * 1000)
        engine_tag = "RE2" if self._using_re2 else "stdlib-re"
        return PipelineResult(
            entities=entities,
            classification={},
            metadata={
                "runner": "REGEX",
                "engine": engine_tag,
                "latency_ms": latency_ms,
                "timestamp": datetime.now(UTC).isoformat(),
            },
        )
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""Text classification pipeline runner."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from ....models.generated_detectors import Severity, TextClassificationPipelineSchema
|
|
9
|
+
from ....models.generated_single_asset_scan_results import DetectionResult
|
|
10
|
+
from ...dependencies import ensure_torch, require_module
|
|
11
|
+
from ._base import _TEXT_CONTENT_TYPES, BaseRunner, _resolve_pipeline_severity
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _chunk_text(text: str, chunk_size: int | None, chunk_overlap: int) -> list[str]:
|
|
17
|
+
"""Split text into chunks. Returns [text] when chunk_size is not set."""
|
|
18
|
+
if not chunk_size:
|
|
19
|
+
return [text]
|
|
20
|
+
step = max(1, chunk_size - chunk_overlap)
|
|
21
|
+
return [text[i : i + chunk_size] for i in range(0, len(text), step)]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TextClassificationRunner(BaseRunner):
    """Runner that labels text with one HuggingFace ``text-classification`` pipeline."""

    def __init__(
        self,
        schema: TextClassificationPipelineSchema,
        detector_key: str = "",
        detector_name: str = "",
    ) -> None:
        """Check the optional dependencies and build the transformers pipeline.

        Args:
            schema: Pipeline schema (model, device, revision, top_k, ...).
            detector_key: Detector identifier; stored on the instance.
            detector_name: Detector name; stored on the instance.
        """
        self._schema = schema
        self._detector_key = detector_key
        self._detector_name = detector_name

        uv_groups = ["custom", "detectors"]
        ensure_torch("text_classification", uv_groups)
        transformers = require_module("transformers", "text_classification", uv_groups)

        pipeline_kwargs: dict[str, Any] = {}
        pipeline_kwargs["model"] = schema.model
        pipeline_kwargs["device"] = schema.device or "cpu"
        # Optional settings are forwarded only when the schema provides them.
        if schema.model_revision:
            pipeline_kwargs["revision"] = schema.model_revision
        if schema.top_k is not None:
            pipeline_kwargs["top_k"] = schema.top_k
        if schema.function_to_apply is not None:
            pipeline_kwargs["function_to_apply"] = str(schema.function_to_apply)
        self._pipe: Any = transformers.pipeline("text-classification", **pipeline_kwargs)

    def run(self, text: str) -> None:  # type: ignore[override]  # pragma: no cover
        """Not supported; this runner is driven through ``detect()``."""
        raise NotImplementedError("TextClassificationRunner uses detect() directly")

    def detect(self, content: str | bytes, content_type: str) -> list[DetectionResult]:
        """Classify ``content`` and return one result per label above threshold.

        The text is optionally chunked; for every label the highest score seen
        over all chunks is kept. If the pipeline raises, the error is logged
        and the scores gathered so far are still reported.

        Args:
            content: Content to classify; ``bytes`` is never classified.
            content_type: Content type; must be in ``_TEXT_CONTENT_TYPES``.

        Returns:
            Detection results sorted by confidence, highest first. Empty for
            bytes, unsupported content types and blank text.
        """
        if isinstance(content, bytes) or content_type not in _TEXT_CONTENT_TYPES:
            return []
        text = content.strip()
        if not text:
            return []

        schema = self._schema
        # Size fields may be wrapped values exposing ``.root``; unwrap if so.
        chunk_size: int | None = getattr(schema.chunk_size, "root", schema.chunk_size)
        chunk_overlap: int = getattr(schema.chunk_overlap, "root", schema.chunk_overlap) or 0
        max_length: int | None = getattr(schema.max_length, "root", schema.max_length)
        threshold = 0.7 if schema.confidence_threshold is None else schema.confidence_threshold
        default_severity = schema.severity or Severity.info

        call_kwargs: dict[str, Any] = {"truncation": True}
        if max_length is not None:
            call_kwargs["max_length"] = max_length

        best_scores: dict[str, float] = {}
        try:
            for chunk in _chunk_text(text, chunk_size, chunk_overlap):
                raw = self._pipe(chunk, **call_kwargs) or []
                # The pipeline may return a flat list of predictions or a
                # list wrapping one list of predictions; handle both.
                nested = bool(raw) and isinstance(raw[0], list)
                preds: list[dict[str, Any]] = raw[0] if nested else raw
                for pred in preds:
                    label: str = pred.get("label", "unknown")
                    score: float = float(pred.get("score", 0.0))
                    if score > best_scores.get(label, 0.0):
                        best_scores[label] = score
        except Exception as exc:
            logger.error(
                "text_classification error (model=%s): %s", schema.model, exc, exc_info=True
            )

        results: list[DetectionResult] = [
            self._make_result(
                finding_type=f"classification:{label}",
                category="CONTENT",
                severity=_resolve_pipeline_severity(label, schema.severity_map, default_severity),
                confidence=score,
                matched_content=text[:512],
                location=None,
                metadata={"model": schema.model, "predicted_label": label, "score": score},
            )
            for label, score in best_scores.items()
            if not score < threshold
        ]
        return sorted(results, key=lambda r: r.confidence, reverse=True)

    def get_supported_content_types(self) -> list[str]:
        """Return the content types this runner accepts."""
        return [*_TEXT_CONTENT_TYPES]
|
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
"""GLiNER2 pipeline fine-tuning trainer.
|
|
2
|
+
|
|
3
|
+
Handles two training modes in a single pass:
|
|
4
|
+
- NER fine-tuning via the base GLiNER model (entity examples with span values)
|
|
5
|
+
- Zero-shot classification fine-tuning via SetFit (text + label pairs)
|
|
6
|
+
|
|
7
|
+
Artifacts are written to an output directory structured as:
|
|
8
|
+
<output_dir>/
|
|
9
|
+
gliner2/ -- fine-tuned GLiNER2 model weights (HuggingFace format)
|
|
10
|
+
setfit/<task>/ -- one SetFit model per classification task
|
|
11
|
+
manifest.json -- training metadata for the runner
|
|
12
|
+
|
|
13
|
+
The trainer produces a JSON result dict to stdout, which the API reads to update
|
|
14
|
+
the training run record.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import json
|
|
20
|
+
import logging
|
|
21
|
+
import time
|
|
22
|
+
from dataclasses import dataclass, field
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
from typing import Any
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
# Minimum examples needed before we attempt fine-tuning (not just annotation storage)
|
|
29
|
+
_MIN_NER_EXAMPLES = 5
|
|
30
|
+
_MIN_SETFIT_PER_CLASS = 2
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class TrainingExample:
    """A single labelled training example handed to the trainer."""

    # Label the example belongs to.
    label: str
    # Full example text.
    text: str
    # Specific entity span text (NER only).
    value: str | None = field(default=None)
    # False marks the example as rejected (a negative example).
    accepted: bool = field(default=True)
    # Optional origin of the example.
    source: str | None = field(default=None)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class TrainingResult:
    """Outcome of a training run, convertible to a plain dict."""

    status: str
    trained_examples: int
    positive_examples: int
    negative_examples: int
    model_artifact_path: str
    metrics: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a dict, in declaration order.

        ``metrics`` is returned by reference, not copied.
        """
        keys = (
            "status",
            "trained_examples",
            "positive_examples",
            "negative_examples",
            "model_artifact_path",
            "metrics",
        )
        return {key: getattr(self, key) for key in keys}
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class GLiNER2Trainer:
    """Orchestrates NER + classification fine-tuning for a GLiNER2 pipeline.

    ``train()`` runs both stages, writes ``manifest.json`` into the output
    directory and returns a ``TrainingResult``. A stage that cannot run (too
    few examples, missing optional package, training error) is recorded in
    the metrics as ``{"skipped": True, "reason": ...}`` instead of raising.
    """

    # Epoch count shared by the GLiNER Trainer path, the manual fallback loop
    # and the reported NER metrics.
    _NER_EPOCHS = 3

    def __init__(
        self,
        pipeline_schema: dict[str, Any],
        examples_raw: list[dict[str, Any]],
        output_dir: Path,
    ) -> None:
        """Normalise the raw examples and remember schema and output location.

        Args:
            pipeline_schema: Pipeline schema as a plain dict (``entities``,
                ``classification``, ``model`` and ``type`` keys are read).
            examples_raw: Raw example dicts; entries without a truthy
                ``label`` or ``text`` are dropped.
            output_dir: Directory that receives the model artifacts and the
                manifest.
        """
        self._schema = pipeline_schema
        self._output_dir = output_dir
        self._examples: list[TrainingExample] = [
            TrainingExample(
                label=str(ex.get("label", "")),
                text=str(ex.get("text", "")),
                value=ex.get("value") or None,
                accepted=bool(ex.get("accepted", True)),
                source=ex.get("source") or None,
            )
            for ex in examples_raw
            if ex.get("label") and ex.get("text")
        ]

    def train(self) -> TrainingResult:
        """Run NER and classification fine-tuning and write the manifest.

        Returns:
            A ``TrainingResult`` with status ``"SUCCEEDED"``, the counts of
            accepted/rejected examples, the output directory and per-stage
            metrics (including ``duration_s``).
        """
        t0 = time.monotonic()
        self._output_dir.mkdir(parents=True, exist_ok=True)

        positive = [e for e in self._examples if e.accepted]
        negative = [e for e in self._examples if not e.accepted]
        metrics: dict[str, Any] = {}

        entities: dict[str, Any] = self._schema.get("entities") or {}
        classification: dict[str, Any] = self._schema.get("classification") or {}
        base_model: str = (self._schema.get("model") or {}).get("name") or "fastino/gliner2-base-v1"

        # ── NER fine-tuning ────────────────────────────────────────────────────
        entity_labels = set(entities.keys())
        ner_examples = [e for e in positive if e.label in entity_labels and e.value]
        if ner_examples:
            metrics["ner"] = self._train_ner(ner_examples, base_model)
        else:
            metrics["ner"] = {"skipped": True, "reason": "No span-annotated NER examples"}

        # ── Classification fine-tuning (SetFit) ────────────────────────────────
        if classification:
            metrics["classification"] = self._train_classification(positive, classification)
        else:
            metrics["classification"] = {
                "skipped": True,
                "reason": "No classification tasks defined",
            }

        # Record the duration before serialising, so manifest.json carries the
        # same metrics as the returned result.
        metrics["duration_s"] = round(time.monotonic() - t0, 2)

        # Write manifest so the runner knows what's available
        manifest: dict[str, Any] = {
            "schema_type": self._schema.get("type", "GLINER2"),
            "base_model": base_model,
            "trained_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "metrics": metrics,
        }
        (self._output_dir / "manifest.json").write_text(
            json.dumps(manifest, indent=2), encoding="utf-8"
        )

        return TrainingResult(
            status="SUCCEEDED",
            trained_examples=len(positive),
            positive_examples=len(positive),
            negative_examples=len(negative),
            model_artifact_path=str(self._output_dir),
            metrics=metrics,
        )

    # ── NER ────────────────────────────────────────────────────────────────────

    def _train_ner(self, examples: list[TrainingExample], base_model: str) -> dict[str, Any]:
        """Fine-tune the base GLiNER model on span-annotated examples.

        Args:
            examples: Accepted examples carrying an entity ``value``.
            base_model: Model name passed to ``GLiNER.from_pretrained``.

        Returns:
            Metrics for the stage: example count, epochs and save location on
            success, or ``{"skipped": True, "reason": ...}``.
        """
        if len(examples) < _MIN_NER_EXAMPLES:
            return {
                "skipped": True,
                "reason": f"Need ≥{_MIN_NER_EXAMPLES} span-annotated examples (got {len(examples)})",
            }

        # Build GLiNER-format span annotations
        # NOTE(review): spans are character offsets of the FIRST occurrence of
        # ``value`` in ``text`` — confirm this matches the annotation format
        # expected by the installed gliner version.
        train_data: list[dict[str, Any]] = []
        for ex in examples:
            if not ex.value:
                continue
            start = ex.text.find(ex.value)
            if start < 0:
                # The span text does not occur in the example text.
                continue
            train_data.append(
                {
                    "text": ex.text,
                    "ner": [{"start": start, "end": start + len(ex.value), "label": ex.label}],
                }
            )

        if len(train_data) < _MIN_NER_EXAMPLES:
            return {
                "skipped": True,
                "reason": f"Too few locatable spans after search (got {len(train_data)})",
            }

        gliner_out = self._output_dir / "gliner2"
        gliner_out.mkdir(parents=True, exist_ok=True)

        try:
            from gliner import GLiNER  # type: ignore[import-untyped]

            model = GLiNER.from_pretrained(base_model)

            # Attempt to use the trainer API if available
            try:
                from gliner.training import Trainer as GLiNERTrainer  # type: ignore[import-untyped]
                from gliner.training import (
                    TrainingArguments as GLiNERArgs,  # type: ignore[import-untyped]
                )

                args = GLiNERArgs(
                    output_dir=str(gliner_out),
                    num_train_epochs=self._NER_EPOCHS,
                    per_device_train_batch_size=min(4, len(train_data)),
                    warmup_ratio=0.1,
                    save_steps=0,
                )
                trainer = GLiNERTrainer(model=model, args=args, train_dataset=train_data)
                trainer.train()
            except ImportError:
                # Fallback: manual training loop if the Trainer API isn't available
                logger.warning("gliner.training not available — using manual fine-tuning loop")
                self._manual_ner_train(model, train_data)

            model.save_pretrained(str(gliner_out))
            logger.info("NER model saved to %s", gliner_out)
            return {
                "examples": len(train_data),
                "epochs": self._NER_EPOCHS,
                "saved_to": "gliner2/",
            }

        except ImportError as e:
            return {"skipped": True, "reason": f"gliner not installed: {e}"}
        except Exception as e:
            # Best-effort stage: a failed NER run must not abort classification.
            logger.warning("NER fine-tuning failed: %s", e, exc_info=True)
            return {"skipped": True, "reason": str(e)}

    def _manual_ner_train(self, model: Any, train_data: list[dict[str, Any]]) -> None:
        """Minimal per-example AdamW loop used when the Trainer API is unavailable."""
        import torch  # type: ignore[import-untyped]

        optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
        model.train()
        for _epoch in range(self._NER_EPOCHS):
            for item in train_data:
                optimizer.zero_grad()
                loss = model.compute_loss(item)
                if loss is not None:
                    loss.backward()
                    optimizer.step()

    # ── Classification (SetFit) ────────────────────────────────────────────────

    def _train_classification(
        self,
        positive: list[TrainingExample],
        classification: dict[str, Any],
    ) -> dict[str, Any]:
        """Train one SetFit model per classification task.

        A task is skipped unless every one of its labels has at least
        ``_MIN_SETFIT_PER_CLASS`` accepted examples.

        Args:
            positive: Accepted examples.
            classification: Mapping of task name to task definition; the
                ``labels`` entry of each definition is read.

        Returns:
            Mapping of task name to that task's metrics or skip record.
        """
        from collections import Counter

        task_results: dict[str, Any] = {}
        for task_name, task_defn in classification.items():
            labels: list[str] = task_defn.get("labels") or []
            task_examples = [e for e in positive if e.label in labels]
            per_label = Counter(e.label for e in task_examples)
            # Enforce the minimum for each label, not only in total: a label
            # with no examples cannot be learned even if the others have plenty.
            short = [str(label) for label in labels if per_label[label] < _MIN_SETFIT_PER_CLASS]
            too_few_total = len(task_examples) < _MIN_SETFIT_PER_CLASS * max(len(labels), 1)
            if too_few_total or short:
                reason = (
                    f"Need ≥{_MIN_SETFIT_PER_CLASS} examples per label "
                    f"(got {len(task_examples)} across {len(labels)} labels"
                )
                if short:
                    reason += f"; under-represented: {', '.join(short)}"
                task_results[task_name] = {"skipped": True, "reason": reason + ")"}
                continue
            try:
                task_results[task_name] = self._train_setfit_task(task_name, task_examples, labels)
            except Exception as e:
                # Best-effort per task: one failing task must not stop the rest.
                logger.warning("SetFit training failed for task '%s': %s", task_name, e)
                task_results[task_name] = {"skipped": True, "reason": str(e)}
        return task_results

    def _train_setfit_task(
        self,
        task_name: str,
        examples: list[TrainingExample],
        labels: list[str],
    ) -> dict[str, Any]:
        """Train and save a SetFit model for one classification task.

        Args:
            task_name: Task name; also the sub-directory under ``setfit/``.
            examples: Accepted examples whose label is one of ``labels``.
            labels: Task labels; their order defines the integer class ids.

        Returns:
            Example count, labels and save location, or a skip record when
            ``setfit``/``datasets`` cannot be imported.
        """
        try:
            from datasets import Dataset  # type: ignore[import-untyped]
            from setfit import (  # type: ignore[import-untyped]
                SetFitModel,
                Trainer,
                TrainingArguments,
            )
        except ImportError as e:
            return {"skipped": True, "reason": f"setfit/datasets not installed: {e}"}

        label2id = {label: i for i, label in enumerate(labels)}
        dataset = Dataset.from_dict(
            {
                "text": [ex.text for ex in examples],
                "label": [label2id.get(ex.label, 0) for ex in examples],
            }
        )

        model = SetFitModel.from_pretrained(
            "sentence-transformers/paraphrase-MiniLM-L6-v2",
            labels=labels,
        )

        save_path = self._output_dir / "setfit" / task_name
        save_path.mkdir(parents=True, exist_ok=True)

        args = TrainingArguments(
            output_dir=str(save_path),
            num_epochs=1,
            batch_size=min(8, len(examples)),
        )
        trainer = Trainer(model=model, args=args, train_dataset=dataset)
        trainer.train()
        model.save_pretrained(str(save_path))

        # Persist label ordering so the runner can decode predictions
        (save_path / "labels.json").write_text(json.dumps(labels))

        logger.info("SetFit model for task '%s' saved to %s", task_name, save_path)
        return {
            "examples": len(examples),
            "labels": labels,
            "saved_to": f"setfit/{task_name}/",
        }
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""Helpers for optional detector dependencies."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import importlib
|
|
6
|
+
import logging
|
|
7
|
+
from types import ModuleType
|
|
8
|
+
|
|
9
|
+
from src.utils.uv_sync import auto_install_enabled, sync_group
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MissingDependencyError(RuntimeError):
    """Signals that an optional detector dependency is missing or unusable.

    The constructor arguments stay available as attributes, and the exception
    message tells the user which ``uv sync --group`` command installs them.
    """

    def __init__(
        self,
        detector_name: str,
        dependencies: list[str],
        uv_groups: list[str],
        detail: str | None = None,
    ) -> None:
        """Build the user-facing message and initialise the exception.

        Args:
            detector_name: Name of the detector that needs the dependencies.
            dependencies: Names of the missing packages.
            uv_groups: uv dependency groups that provide them.
            detail: Optional extra text appended to the message.
        """
        self.detector_name = detector_name
        self.dependencies = dependencies
        self.uv_groups = uv_groups
        self.detail = detail

        install_cmds = [f"`uv sync --group {name}`" for name in uv_groups]
        sentences = [
            f"{detector_name} detector requires optional dependencies ({', '.join(dependencies)}).",
            f"Install with {' or '.join(install_cmds)}.",
        ]
        if detail:
            sentences.append(detail)

        super().__init__(" ".join(sentences))
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _ordered_groups(groups: list[str]) -> list[str]:
|
|
42
|
+
unique = list(dict.fromkeys(groups))
|
|
43
|
+
return sorted(unique, key=lambda group: (group == "detectors", group))
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def require_module(
    module_name: str,
    detector_name: str,
    uv_groups: list[str],
    detail: str | None = None,
) -> ModuleType:
    """Import ``module_name``, optionally auto-installing its uv group first.

    When the import fails and auto-install is enabled, each uv group is synced
    in ``_ordered_groups`` order and the import is retried after every
    successful sync.

    Args:
        module_name: Dotted module path to import.
        detector_name: Detector name used in the error message.
        uv_groups: uv dependency groups that may provide the module.
        detail: Optional lead-in for the error detail text.

    Returns:
        The imported module.

    Raises:
        MissingDependencyError: The module could not be imported, even after
            any auto-install attempts. Chained from the original error.
    """
    try:
        return importlib.import_module(module_name)
    except Exception as exc:  # pragma: no cover - exercised indirectly in integration setups
        notes: list[str] = [f"Original error: {exc}"]

        if auto_install_enabled() and uv_groups:
            for group in _ordered_groups(uv_groups):
                installed, install_note = sync_group(group)
                if install_note:
                    notes.append(install_note)
                if installed:
                    try:
                        # Let the import system see the freshly installed files.
                        importlib.invalidate_caches()
                        return importlib.import_module(module_name)
                    except Exception as retry_exc:  # pragma: no cover
                        notes.append(
                            f"Module '{module_name}' still unavailable after installing '{group}': {retry_exc}"
                        )

        lead_in = detail or "Optional dependency import failed"
        # Report the top-level package, not the dotted submodule path.
        top_level_package = module_name.partition(".")[0]
        raise MissingDependencyError(
            detector_name=detector_name,
            dependencies=[top_level_package],
            uv_groups=uv_groups,
            detail=f"{lead_in}. {'; '.join(notes)}",
        ) from exc
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def ensure_torch(detector_name: str, uv_groups: list[str]) -> ModuleType:
    """Import ``torch`` and check that it looks like a real PyTorch install.

    Args:
        detector_name: Detector name used in error messages.
        uv_groups: uv dependency groups that provide torch.

    Returns:
        The imported ``torch`` module.

    Raises:
        MissingDependencyError: ``torch`` cannot be imported, or the imported
            module lacks ``no_grad`` or ``_utils``.
    """
    torch_module = require_module("torch", detector_name, uv_groups)
    # (attribute that must exist, explanation reported when it is absent)
    sanity_checks = (
        (
            "no_grad",
            "Detected a module named 'torch' but it is missing `no_grad`. "
            "Ensure PyTorch is installed via uv and no local `torch.py` shadows it.",
        ),
        (
            "_utils",
            "Detected an incomplete/broken PyTorch install (`torch._utils` missing). "
            "Reinstall torch via uv for this environment.",
        ),
    )
    for attribute, problem in sanity_checks:
        if hasattr(torch_module, attribute):
            continue
        raise MissingDependencyError(
            detector_name=detector_name,
            dependencies=["torch"],
            uv_groups=uv_groups,
            detail=problem,
        )
    return torch_module
|
|
File without changes
|