netgreener 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- netgreener/__init__.py +51 -0
- netgreener/_metrics_server.py +59 -0
- netgreener/analyzer.py +295 -0
- netgreener/api_client.py +123 -0
- netgreener/cli.py +433 -0
- netgreener/config.py +66 -0
- netgreener/constants.py +3 -0
- netgreener/executor.py +132 -0
- netgreener/nlp_analyzer.py +123 -0
- netgreener/reporter.py +79 -0
- netgreener-0.1.0.dist-info/METADATA +15 -0
- netgreener-0.1.0.dist-info/RECORD +15 -0
- netgreener-0.1.0.dist-info/WHEEL +5 -0
- netgreener-0.1.0.dist-info/entry_points.txt +2 -0
- netgreener-0.1.0.dist-info/top_level.txt +1 -0
netgreener/__init__.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""
|
|
2
|
+
netgreener — public API for use inside ML training scripts.
|
|
3
|
+
|
|
4
|
+
Usage inside a user's train.py:
|
|
5
|
+
import netgreener as ng
|
|
6
|
+
ng.log_metrics(accuracy=0.95, f1=0.88, precision=0.90, recall=0.87, auc=0.96)
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import os
|
|
11
|
+
import socket
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def log_metrics(
|
|
15
|
+
accuracy: float | None = None,
|
|
16
|
+
f1: float | None = None,
|
|
17
|
+
precision: float | None = None,
|
|
18
|
+
recall: float | None = None,
|
|
19
|
+
auc: float | None = None,
|
|
20
|
+
**extra: float,
|
|
21
|
+
) -> None:
|
|
22
|
+
"""
|
|
23
|
+
Report accuracy metrics back to the NetGreener CLI runner.
|
|
24
|
+
Call this anywhere in your script after evaluation is complete.
|
|
25
|
+
Safe to call even when not running under netgreener (silently no-ops).
|
|
26
|
+
"""
|
|
27
|
+
port_str = os.environ.get("NETGREENER_METRICS_PORT")
|
|
28
|
+
if not port_str:
|
|
29
|
+
return
|
|
30
|
+
|
|
31
|
+
payload: dict = {}
|
|
32
|
+
if accuracy is not None:
|
|
33
|
+
payload["model_accuracy"] = accuracy
|
|
34
|
+
if f1 is not None:
|
|
35
|
+
payload["f1_score"] = f1
|
|
36
|
+
if precision is not None:
|
|
37
|
+
payload["model_precision"] = precision
|
|
38
|
+
if recall is not None:
|
|
39
|
+
payload["recall"] = recall
|
|
40
|
+
if auc is not None:
|
|
41
|
+
payload["auc_score"] = auc
|
|
42
|
+
payload.update(extra)
|
|
43
|
+
|
|
44
|
+
if not payload:
|
|
45
|
+
return
|
|
46
|
+
|
|
47
|
+
try:
|
|
48
|
+
with socket.create_connection(("127.0.0.1", int(port_str)), timeout=2) as s:
|
|
49
|
+
s.sendall(json.dumps(payload).encode())
|
|
50
|
+
except OSError:
|
|
51
|
+
pass
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Local TCP server that receives ng.log_metrics() calls from the user's running script.
|
|
3
|
+
|
|
4
|
+
The CLI starts this server before launching the script and injects the port via
|
|
5
|
+
NETGREENER_METRICS_PORT. The user-facing netgreener.log_metrics() connects to it.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import socket
|
|
10
|
+
import threading
|
|
11
|
+
from typing import Callable
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MetricsServer:
    """Listens on a loopback TCP port for metric payloads from the child process.

    Each connection is expected to carry one JSON object. The object is
    decoded and passed to the ``on_metrics`` callback on the server's
    background thread. Connections that carry anything else are ignored.
    """

    def __init__(self, on_metrics: Callable[[dict], None]):
        """Bind an ephemeral port on 127.0.0.1 and start listening.

        Args:
            on_metrics: Called with each decoded JSON object.

        Raises:
            OSError: If the listening socket cannot be set up.
        """
        self._on_metrics = on_metrics
        self._server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self._server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # Port 0 lets the OS choose a free port; it is read back below.
            self._server.bind(("127.0.0.1", 0))
            self._server.listen(5)
            # A short accept timeout lets the worker re-check the stop flag.
            self._server.settimeout(0.5)
            self.port: int = self._server.getsockname()[1]
        except OSError:
            # Do not leak the descriptor when setup fails half-way.
            self._server.close()
            raise
        self._running = False
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Start the background accept loop; a no-op if already running."""
        if self._running:
            return
        self._running = True
        self._thread = threading.Thread(target=self._serve, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Stop the accept loop, wait for the worker, then close the socket.

        The worker is joined before the socket is closed, so the listening
        socket is never closed underneath a pending ``accept()`` and a
        payload that is already being handled reaches the callback before
        this method returns. The wait is bounded by the accept timeout plus
        the per-connection receive timeout.
        """
        self._running = False
        worker = self._thread
        # Joining the current thread raises RuntimeError, so skip the join
        # when stop() is called from inside the callback.
        if worker is not None and worker is not threading.current_thread():
            worker.join(timeout=6.0)
        try:
            self._server.close()
        except OSError:
            pass

    def _serve(self) -> None:
        """Accept connections until the stop flag is cleared."""
        while self._running:
            try:
                conn, _ = self._server.accept()
            except socket.timeout:
                continue
            except OSError:
                # The listening socket is closed or unusable; retrying
                # would spin without ever succeeding.
                break
            self._handle(conn)

    def _handle(self, conn: socket.socket) -> None:
        """Read one connection to EOF, decode it, and invoke the callback."""
        with conn:
            data = bytearray()
            try:
                conn.settimeout(5.0)
                while chunk := conn.recv(4096):
                    data += chunk
            except socket.timeout:
                # The sender stalled; try to parse whatever arrived.
                pass
            except OSError:
                return

            try:
                payload = json.loads(bytes(data).decode())
            except (ValueError, RecursionError):
                # Covers invalid UTF-8, invalid JSON, and absurd nesting.
                return
            if not isinstance(payload, dict):
                # The callback contract is a dict of metrics.
                return

            try:
                self._on_metrics(payload)
            except Exception:
                # A failing callback must not kill the server thread.
                pass
|
netgreener/analyzer.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Headless code analysis — wraps CodeAnalyzer without any PyQt5 dependency.
|
|
3
|
+
Extracts AST features, Halstead metrics, raw Radon metrics, library domains,
|
|
4
|
+
and derives CodePattern records for the supervisor dashboard.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import sys
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
# Candidate location of the CodeAnalyzer sources: the directory
# "NetGreenerCodeAnalyzer/Src" under this file's parents[3] ancestor. When it
# exists it is prepended to sys.path so the "Libs.CodeAnalyzer" imports in
# analyze_file() can resolve.
_ANCESTORS = Path(__file__).resolve().parents
# parents[3] raises IndexError when this module sits fewer than four
# directories deep (e.g. /app/netgreener/analyzer.py). That would break the
# import of this module, so fall back to "not available" (None) instead.
_CODEANALYZER_SRC = (
    _ANCESTORS[3] / "NetGreenerCodeAnalyzer" / "Src" if len(_ANCESTORS) > 3 else None
)
if (
    _CODEANALYZER_SRC is not None
    and _CODEANALYZER_SRC.exists()
    and str(_CODEANALYZER_SRC) not in sys.path
):
    sys.path.insert(0, str(_CODEANALYZER_SRC))
|
|
14
|
+
|
|
15
|
+
# Maps the top-level name of an imported module to a coarse library domain.
# Entries are grouped by domain; key order matches the grouping below.
_LIBRARY_DOMAINS: dict[str, str] = {
    **dict.fromkeys(
        (
            "cv2",
            "PIL",
            "Pillow",
            "skimage",
            "torchvision",
            "imageio",
            "albumentations",
        ),
        "computer_vision",
    ),
    **dict.fromkeys(
        (
            "tensorflow",
            "tf",
            "keras",
            "torch",
            "jax",
            "flax",
            "paddle",
            "mxnet",
            "fastai",
            "lightning",
            "pytorch_lightning",
        ),
        "deep_learning",
    ),
    **dict.fromkeys(
        ("sklearn", "xgboost", "lightgbm", "catboost", "statsmodels"),
        "classical_ml",
    ),
    **dict.fromkeys(
        ("transformers", "spacy", "nltk", "gensim", "sentence_transformers"),
        "nlp",
    ),
    **dict.fromkeys(
        ("pandas", "numpy", "scipy", "polars", "pyarrow"),
        "data",
    ),
    **dict.fromkeys(
        ("matplotlib", "seaborn", "plotly"),
        "visualization",
    ),
}
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _detect_domains(imports_detail: list[tuple]) -> list[str]:
    """Return the sorted, de-duplicated library domains found among imports.

    Each entry of ``imports_detail`` is unpacked as a two-item tuple whose
    first item is a module name (entries with a falsy module are skipped).
    Only the part before the first dot is looked up in ``_LIBRARY_DOMAINS``.
    """
    top_level_names = (
        module.split(".")[0] for module, _ in imports_detail if module
    )
    found = {
        _LIBRARY_DOMAINS[name]
        for name in top_level_names
        if _LIBRARY_DOMAINS.get(name)
    }
    return sorted(found)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _safe_halstead(h) -> dict[str, float]:
|
|
72
|
+
"""Extract scalar Halstead fields defensively — radon's return type varies by version."""
|
|
73
|
+
try:
|
|
74
|
+
# h_visit may return a namedtuple or an object with a .total attribute
|
|
75
|
+
src = h.total if hasattr(h, "total") else h
|
|
76
|
+
return {
|
|
77
|
+
"halstead_vocabulary": int(getattr(src, "vocabulary", 0)),
|
|
78
|
+
"halstead_length": int(getattr(src, "length", 0)),
|
|
79
|
+
"halstead_volume": round(float(getattr(src, "volume", 0.0)), 2),
|
|
80
|
+
"halstead_difficulty": round(float(getattr(src, "difficulty", 0.0)), 2),
|
|
81
|
+
"halstead_effort": round(float(getattr(src, "effort", 0.0)), 2),
|
|
82
|
+
"halstead_bugs": round(float(getattr(src, "bugs", 0.0)), 4),
|
|
83
|
+
}
|
|
84
|
+
except Exception:
|
|
85
|
+
return {k: 0 for k in ("halstead_vocabulary", "halstead_length",
|
|
86
|
+
"halstead_volume", "halstead_difficulty",
|
|
87
|
+
"halstead_effort", "halstead_bugs")}
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def detect_code_patterns(features: dict[str, Any], file_path: str, project_id: int) -> list[dict]:
    """Derive CodePattern records from one file's feature dict.

    Checks run in a fixed order: cyclomatic complexity, maintainability
    index, comment ratio (only for files longer than 20 lines), presence of
    exception handling, and AST depth. Missing features fall back to neutral
    defaults that trigger no pattern.

    Returns:
        A list of dicts, each with ``project_id``, ``file_path``,
        ``pattern_type``, ``pattern_name``, ``severity``, ``description``
        and ``confidence``.
    """
    found: list[dict] = []

    def record(kind, title, severity, text, confidence=0.9):
        found.append({
            "project_id": project_id,
            "file_path": file_path,
            "pattern_type": kind,
            "pattern_name": title,
            "severity": severity,
            "description": text,
            "confidence": confidence,
        })

    # Cyclomatic complexity: the worst function decides the severity.
    worst_cc = features.get("max_cyclomatic_complexity", 0)
    if worst_cc > 20:
        record(
            "code_smell", "High Cyclomatic Complexity", "critical",
            f"Max function complexity {worst_cc} (>20). Very hard to test and maintain.",
        )
    elif worst_cc > 10:
        record(
            "code_smell", "High Cyclomatic Complexity", "high",
            f"Max function complexity {worst_cc} (>10). Consider splitting functions.",
            0.85,
        )

    # Maintainability index: both extremes are reported.
    maintainability = features.get("maintainability_index", 50.0)
    if maintainability < 20:
        record(
            "anti_pattern", "Low Maintainability", "high",
            f"Maintainability index {maintainability:.1f} (<20). Code is difficult to maintain.",
        )
    elif maintainability > 65:
        record(
            "design_pattern", "High Maintainability", "low",
            f"Maintainability index {maintainability:.1f} (>65). Code is clean and well structured.",
            0.85,
        )

    # Documentation: the ratio is only computed for files over 20 lines.
    line_total = features.get("lines_of_code", 0)
    comment_total = features.get("comments_count", 0)
    if line_total > 20:
        ratio = comment_total / line_total
        if ratio > 0.2:
            record(
                "design_pattern", "Well Documented Code", "low",
                f"Comment ratio {ratio:.0%}. Good inline documentation.",
                0.8,
            )
        elif ratio < 0.02:
            record(
                "code_smell", "Insufficient Documentation", "low",
                f"Comment ratio {ratio:.0%}. Consider adding docstrings.",
                0.75,
            )

    # Error handling
    handler_total = features.get("count_exception_handling_nodes", 0)
    if handler_total > 0:
        record(
            "design_pattern", "Error Handling Present", "low",
            f"{handler_total} exception handling block(s) found.",
            0.8,
        )

    # Deep nesting
    tree_depth = features.get("depth_of_ast", 0)
    if tree_depth > 25:
        record(
            "code_smell", "Deep Nesting", "medium",
            f"AST depth {tree_depth} (>25). Deeply nested code is hard to follow.",
            0.8,
        )

    return found
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def analyze_file(source_code: str) -> dict[str, Any]:
    """Run the CodeAnalyzer pipeline on one source string.

    On success the result is a flat dict of AST node counts,
    complexity/maintainability values, raw line metrics, Halstead metrics
    and the detected library domains.

    On failure the result is a dict with a single ``error`` key: when the
    CodeAnalyzer package cannot be imported, when the source raises
    ``SyntaxError``, or when any other exception occurs during analysis.
    """
    try:
        from Libs.CodeAnalyzer.FeatureExtraction.ASTFeatures.ast_base_analysis import AstBaseAnalyzer
        from Libs.CodeAnalyzer.FeatureExtraction.ASTFeatures.node_counter import NodeCounter
        from Libs.CodeAnalyzer.FeatureExtraction.ASTFeatures.node_detail_extractor import NodeDetailExtractor
        from Libs.CodeAnalyzer.FeatureExtraction.ASTFeatures.metrics_calculator import MetricsCalculator
    except ImportError as exc:
        return {"error": f"CodeAnalyzer not available: {exc}"}

    try:
        tree = AstBaseAnalyzer(source_code)
        nodes = NodeCounter(tree)
        details = NodeDetailExtractor(tree)
        calculator = MetricsCalculator(tree)

        import_entries = details.import_nodes_detail()
        _, comment_total = tree.extract_comments()
        node_total = nodes.count_node()
        maintainability = calculator.calculate_maintainability_index_rank()

        # One (name, complexity, rank) triple per function.
        complexities = [
            complexity
            for _name, complexity, _rank in (calculator.calculate_cyclomatic_complexity() or ())
        ]
        if complexities:
            mean_cc = sum(complexities) / len(complexities)
        else:
            mean_cc = 0.0
        peak_cc = max(complexities, default=0)

        halstead_fields = _safe_halstead(calculator.calculate_halstead_metrics())
        raw_metrics = calculator.calculate_raw_metrics()

        # The sections are built in the same order as the keys of the result.
        structure: dict[str, Any] = {
            "lines_of_code": nodes.count_lines_of_code(),
            "number_of_tokens": nodes.count_tokens(),
            "depth_of_ast": nodes.depth_of_ast(),
            "number_of_nodes": node_total,
            "number_of_edges": node_total - 1,
            "number_of_imports": nodes.count_imports(),
            "imports_detail": import_entries,
            "count_class_nodes": nodes.count_class_nodes(),
            "count_function_nodes": nodes.count_function_nodes(),
            "count_call_function_nodes": nodes.count_call_function_nodes(),
            "count_control_flow_nodes": nodes.count_control_flow_nodes(),
            "count_exception_handling_nodes": nodes.count_exception_handling_nodes(),
            "count_expression_nodes": nodes.count_expression_nodes(),
            "count_assignment_nodes": nodes.count_assignment_nodes(),
            "count_binary_operations": nodes.count_binary_operation_nodes(),
            "count_unary_operations": nodes.count_unary_operation_nodes(),
            "count_lambda_function_nodes": nodes.count_lambda_function_nodes(),
            "comments_count": comment_total,
        }
        quality: dict[str, Any] = {
            "avg_cyclomatic_complexity": round(mean_cc, 2),
            "max_cyclomatic_complexity": peak_cc,
            "maintainability_index": maintainability[0],
            "maintainability_index_rank": maintainability[1],
        }
        raw_lines: dict[str, Any] = {
            "lloc": raw_metrics.lloc,
            "sloc": raw_metrics.sloc,
            "blank_lines": raw_metrics.blank,
            "multi_line_strings": raw_metrics.multi,
        }
        domains = {"library_domains": _detect_domains(import_entries)}
        return {**structure, **quality, **raw_lines, **halstead_fields, **domains}

    except SyntaxError as exc:
        return {"error": f"SyntaxError: {exc}"}
    except Exception as exc:
        return {"error": str(exc)}
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def analyze_project(project_dir: str) -> dict[str, Any]:
    """Analyze every ``*.py`` path under a project directory.

    Paths are visited in sorted order so ``per_file`` and ``errors`` are
    deterministic across filesystems. A path that cannot be read, or whose
    analysis reports an error, is listed in ``errors`` under its path
    relative to ``project_dir`` and excluded from the aggregates.

    The caller should call detect_code_patterns() per file (or
    build_patterns()) once the project_id is known.

    Returns:
        A dict with ``file_count``, ``total_lines_of_code``,
        ``total_imports``, the per-file averages
        ``avg_cyclomatic_complexity``, ``avg_maintainability_index``,
        ``avg_halstead_volume`` and ``avg_halstead_difficulty``,
        ``total_halstead_bugs``, the sorted ``library_domains``,
        ``per_file`` (a list of ``{"rel_path", "features"}`` dicts) and
        ``errors`` (a list of messages).
    """
    root = Path(project_dir)

    all_domains: set[str] = set()
    total_loc = 0
    total_imports = 0
    total_avg_cc = 0.0
    total_mi = 0.0
    total_halstead_volume = 0.0
    total_halstead_difficulty = 0.0
    total_halstead_bugs = 0.0
    file_count = 0
    errors: list[str] = []
    per_file: list[dict] = []  # {"rel_path": str, "features": dict}

    for path in sorted(root.rglob("*.py")):
        # The relative path keeps same-named files (e.g. several
        # __init__.py) distinguishable in both outputs.
        rel_path = str(path.relative_to(root))
        try:
            source = path.read_text(encoding="utf-8", errors="replace")
        except OSError as exc:
            errors.append(f"{rel_path}: could not read file: {exc}")
            continue

        result = analyze_file(source)
        if "error" in result:
            errors.append(f"{rel_path}: {result['error']}")
            continue

        file_count += 1
        total_loc += result.get("lines_of_code", 0)
        total_imports += result.get("number_of_imports", 0)
        total_avg_cc += result.get("avg_cyclomatic_complexity", 0.0)
        total_mi += result.get("maintainability_index", 0.0)
        total_halstead_volume += result.get("halstead_volume", 0.0)
        total_halstead_difficulty += result.get("halstead_difficulty", 0.0)
        total_halstead_bugs += result.get("halstead_bugs", 0.0)
        all_domains.update(result.get("library_domains", []))
        per_file.append({"rel_path": rel_path, "features": result})

    def _avg(total: float) -> float:
        """Average over successfully analyzed files; 0.0 when there are none."""
        return round(total / file_count, 2) if file_count else 0.0

    return {
        "file_count": file_count,
        "total_lines_of_code": total_loc,
        "total_imports": total_imports,
        "avg_cyclomatic_complexity": _avg(total_avg_cc),
        "avg_maintainability_index": _avg(total_mi),
        "avg_halstead_volume": _avg(total_halstead_volume),
        "avg_halstead_difficulty": _avg(total_halstead_difficulty),
        "total_halstead_bugs": round(total_halstead_bugs, 3),
        "library_domains": sorted(all_domains),
        "per_file": per_file,
        "errors": errors,
    }
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def build_patterns(per_file: list[dict], project_id: int) -> list[dict]:
    """Flatten cached per-file features into CodePattern records.

    Each entry must provide ``features`` and ``rel_path``; the records for
    all entries are concatenated in input order.
    """
    return [
        pattern
        for entry in per_file
        for pattern in detect_code_patterns(
            entry["features"], entry["rel_path"], project_id
        )
    ]
|
netgreener/api_client.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Lightweight API client for NetGreener CLI — wraps the REST endpoints needed by the CLI.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import requests
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from .config import get_api_url, load_credentials
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class APIError(Exception):
    """Error carrying an HTTP status code and a detail message.

    The exception text is rendered as ``HTTP <status_code>: <detail>``; both
    parts remain available as the ``status_code`` and ``detail`` attributes.
    """

    def __init__(self, status_code: int, detail: str):
        message = f"HTTP {status_code}: {detail}"
        super().__init__(message)
        self.status_code = status_code
        self.detail = detail
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class NetGreenerClient:
|
|
19
|
+
def __init__(self, token: str | None = None, base_url: str | None = None):
|
|
20
|
+
self.base_url = (base_url or get_api_url()) + "/api/v1"
|
|
21
|
+
self._token = token
|
|
22
|
+
|
|
23
|
+
@classmethod
|
|
24
|
+
def from_credentials(cls) -> "NetGreenerClient":
|
|
25
|
+
creds = load_credentials()
|
|
26
|
+
if not creds:
|
|
27
|
+
raise APIError(401, "Not logged in. Run: netgreener login")
|
|
28
|
+
return cls(token=creds["access_token"])
|
|
29
|
+
|
|
30
|
+
def _headers(self) -> dict:
|
|
31
|
+
h = {"Content-Type": "application/json"}
|
|
32
|
+
if self._token:
|
|
33
|
+
h["Authorization"] = f"Bearer {self._token}"
|
|
34
|
+
return h
|
|
35
|
+
|
|
36
|
+
def _raise(self, resp: requests.Response) -> None:
|
|
37
|
+
if not resp.ok:
|
|
38
|
+
try:
|
|
39
|
+
detail = resp.json().get("detail", resp.text)
|
|
40
|
+
except Exception:
|
|
41
|
+
detail = resp.text
|
|
42
|
+
raise APIError(resp.status_code, detail)
|
|
43
|
+
|
|
44
|
+
def login(self, email: str, password: str) -> dict:
|
|
45
|
+
resp = requests.post(
|
|
46
|
+
f"{self.base_url}/auth/login",
|
|
47
|
+
json={"email": email, "password": password},
|
|
48
|
+
timeout=15,
|
|
49
|
+
)
|
|
50
|
+
self._raise(resp)
|
|
51
|
+
data = resp.json()
|
|
52
|
+
self._token = data["access_token"]
|
|
53
|
+
return data
|
|
54
|
+
|
|
55
|
+
def get_or_create_project(self, name: str, directory: str) -> dict:
|
|
56
|
+
resp = requests.get(
|
|
57
|
+
f"{self.base_url}/projects",
|
|
58
|
+
headers=self._headers(),
|
|
59
|
+
timeout=15,
|
|
60
|
+
)
|
|
61
|
+
self._raise(resp)
|
|
62
|
+
projects = resp.json()
|
|
63
|
+
for p in projects:
|
|
64
|
+
if p.get("project_name") == name:
|
|
65
|
+
return p
|
|
66
|
+
resp = requests.post(
|
|
67
|
+
f"{self.base_url}/projects",
|
|
68
|
+
headers=self._headers(),
|
|
69
|
+
json={"project_name": name, "project_directory": directory, "language": "Python"},
|
|
70
|
+
timeout=15,
|
|
71
|
+
)
|
|
72
|
+
self._raise(resp)
|
|
73
|
+
return resp.json()
|
|
74
|
+
|
|
75
|
+
def create_run_session(self, payload: dict) -> dict:
|
|
76
|
+
resp = requests.post(
|
|
77
|
+
f"{self.base_url}/runsessions",
|
|
78
|
+
headers=self._headers(),
|
|
79
|
+
json=payload,
|
|
80
|
+
timeout=15,
|
|
81
|
+
)
|
|
82
|
+
self._raise(resp)
|
|
83
|
+
return resp.json()
|
|
84
|
+
|
|
85
|
+
def update_run_accuracy(self, run_id: int, metrics: dict) -> dict:
|
|
86
|
+
resp = requests.put(
|
|
87
|
+
f"{self.base_url}/runsessions/{run_id}/accuracy",
|
|
88
|
+
headers=self._headers(),
|
|
89
|
+
json=metrics,
|
|
90
|
+
timeout=15,
|
|
91
|
+
)
|
|
92
|
+
self._raise(resp)
|
|
93
|
+
return resp.json()
|
|
94
|
+
|
|
95
|
+
def post_code_patterns_batch(self, patterns: list) -> dict:
|
|
96
|
+
resp = requests.post(
|
|
97
|
+
f"{self.base_url}/code-patterns/batch",
|
|
98
|
+
headers=self._headers(),
|
|
99
|
+
json={"patterns": patterns},
|
|
100
|
+
timeout=30,
|
|
101
|
+
)
|
|
102
|
+
self._raise(resp)
|
|
103
|
+
return resp.json()
|
|
104
|
+
|
|
105
|
+
def post_surrogate_training(self, payload: dict) -> dict:
|
|
106
|
+
resp = requests.post(
|
|
107
|
+
f"{self.base_url}/surrogate-training",
|
|
108
|
+
headers=self._headers(),
|
|
109
|
+
json=payload,
|
|
110
|
+
timeout=15,
|
|
111
|
+
)
|
|
112
|
+
self._raise(resp)
|
|
113
|
+
return resp.json()
|
|
114
|
+
|
|
115
|
+
def post_environmental_impact(self, payload: dict) -> dict:
|
|
116
|
+
resp = requests.post(
|
|
117
|
+
f"{self.base_url}/environmental-impact",
|
|
118
|
+
headers=self._headers(),
|
|
119
|
+
json=payload,
|
|
120
|
+
timeout=15,
|
|
121
|
+
)
|
|
122
|
+
self._raise(resp)
|
|
123
|
+
return resp.json()
|