dreadnode 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dreadnode/__init__.py ADDED
@@ -0,0 +1,51 @@
1
"""Public API surface for the ``dreadnode`` package.

Re-exports the core types and binds the default-instance convenience
functions (``configure``, ``run``, ``log_metric``, ...) at module level.
"""

from dreadnode.main import DEFAULT_INSTANCE, Dreadnode
from dreadnode.metric import Metric, MetricDict, Scorer
from dreadnode.object import Object
from dreadnode.task import Task
from dreadnode.tracing.span import RunSpan, Span, TaskSpan
from dreadnode.version import VERSION

# Lifecycle helpers bound to the shared default instance.
configure = DEFAULT_INSTANCE.configure
shutdown = DEFAULT_INSTANCE.shutdown

# Tracing / execution helpers.
api = DEFAULT_INSTANCE.api
span = DEFAULT_INSTANCE.span
task = DEFAULT_INSTANCE.task
task_span = DEFAULT_INSTANCE.task_span
run = DEFAULT_INSTANCE.run
scorer = DEFAULT_INSTANCE.scorer
push_update = DEFAULT_INSTANCE.push_update

# Logging helpers.
log_metric = DEFAULT_INSTANCE.log_metric
log_param = DEFAULT_INSTANCE.log_param
log_params = DEFAULT_INSTANCE.log_params
log_input = DEFAULT_INSTANCE.log_input
log_inputs = DEFAULT_INSTANCE.log_inputs
log_output = DEFAULT_INSTANCE.log_output
link_objects = DEFAULT_INSTANCE.link_objects
log_artifact = DEFAULT_INSTANCE.log_artifact

__version__ = VERSION

# Every name listed here must actually be defined above — the original
# listed "Run" and "Score", which were never imported and would make
# ``from dreadnode import *`` fail with an AttributeError.
__all__ = [
    "Dreadnode",
    "Metric",
    "MetricDict",
    "Object",
    "RunSpan",
    "Scorer",
    "Span",
    "Task",
    "TaskSpan",
    "__version__",
    "api",
    "configure",
    "link_objects",
    "log_artifact",
    "log_input",
    "log_inputs",
    "log_metric",
    "log_output",
    "log_param",
    "log_params",
    "push_update",
    "run",
    "scorer",
    "shutdown",
    "span",
    "task",
    "task_span",
]
File without changes
@@ -0,0 +1,249 @@
1
+ import io
2
+ import json
3
+ import typing as t
4
+
5
+ import httpx
6
+ import pandas as pd
7
+ from pydantic import BaseModel
8
+ from ulid import ULID
9
+
10
+ from dreadnode.util import logger
11
+ from dreadnode.version import VERSION
12
+
13
+ from .models import (
14
+ MetricAggregationType,
15
+ Project,
16
+ Run,
17
+ StatusFilter,
18
+ Task,
19
+ TimeAggregationType,
20
+ TimeAxisType,
21
+ TraceSpan,
22
+ UserDataCredentials,
23
+ )
24
+
25
+ ModelT = t.TypeVar("ModelT", bound=BaseModel)
26
+
27
+
28
class ApiClient:
    """Client for the Dreadnode API.

    Thin synchronous wrapper over ``httpx.Client`` that handles base-URL
    normalization, API-key auth headers, error translation, and the
    project/run/export endpoints under ``/strikes``.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        *,
        debug: bool = False,
    ):
        """Create a client.

        Args:
            base_url: Server root URL; ``/api`` is appended if not present.
            api_key: Sent with every request via the ``X-API-Key`` header.
            debug: When True, log full request/response traffic at debug level.
        """
        self._base_url = base_url.rstrip("/")
        if not self._base_url.endswith("/api"):
            self._base_url += "/api"

        self._client = httpx.Client(
            headers={
                "User-Agent": f"dreadnode-sdk/{VERSION}",
                "Accept": "application/json",
                "X-API-Key": api_key,
            },
            base_url=self._base_url,
            timeout=30,
        )

        if debug:
            # Event hooks observe traffic without modifying it.
            self._client.event_hooks["request"].append(self._log_request)
            self._client.event_hooks["response"].append(self._log_response)

    def _log_request(self, request: httpx.Request) -> None:
        """Log every request to the console if debug is enabled."""

        logger.debug("-------------------------------------------")
        logger.debug("%s %s", request.method, request.url)
        logger.debug("Headers: %s", request.headers)
        logger.debug("Content: %s", request.content)
        logger.debug("-------------------------------------------")

    def _log_response(self, response: httpx.Response) -> None:
        """Log every response to the console if debug is enabled."""

        logger.debug("-------------------------------------------")
        logger.debug("Response: %s", response.status_code)
        # read() forces the body to be loaded so it can be logged.
        logger.debug("Headers: %s", response.headers)
        logger.debug("Content: %s", response.read())
        logger.debug("--------------------------------------------")

    def _get_error_message(self, response: httpx.Response) -> str:
        """Get the error message from the response."""

        try:
            obj = response.json()
            # Prefer the API's "detail" field; fall back to the raw JSON.
            return f"{response.status_code}: {obj.get('detail', json.dumps(obj))}"
        except Exception:  # noqa: BLE001
            # Body was not JSON — return it verbatim.
            return str(response.content)

    def _request(
        self,
        method: str,
        path: str,
        params: dict[str, t.Any] | None = None,
        json_data: dict[str, t.Any] | None = None,
    ) -> httpx.Response:
        """Make a raw request to the API."""

        return self._client.request(method, path, json=json_data, params=params)

    def request(
        self,
        method: str,
        path: str,
        params: dict[str, t.Any] | None = None,
        json_data: dict[str, t.Any] | None = None,
    ) -> httpx.Response:
        """Make a request to the API. Raise an exception for non-200 status codes."""

        response = self._request(method, path, params, json_data)
        # 401 gets a friendlier message than the generic HTTP error below.
        if response.status_code == 401:  # noqa: PLR2004
            raise RuntimeError("Authentication failed, please check your API token.")

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise RuntimeError(self._get_error_message(response)) from e

        return response

    # This currently won't work with API keys
    # def get_user(self) -> UserResponse:
    #     response = self.request("GET", "/user")
    #     return UserResponse(**response.json())

    def list_projects(self) -> list[Project]:
        """Return all projects visible to the authenticated user."""
        response = self.request("GET", "/strikes/projects")
        return [Project(**project) for project in response.json()]

    def get_project(self, project: str) -> Project:
        """Fetch a single project by key or id."""
        response = self.request("GET", f"/strikes/projects/{project!s}")
        return Project(**response.json())

    def list_runs(self, project: str) -> list[Run]:
        """Return all runs recorded for the given project."""
        response = self.request("GET", f"/strikes/projects/{project!s}/runs")
        return [Run(**run) for run in response.json()]

    def get_run(self, run: str | ULID) -> Run:
        """Fetch a single run by id."""
        response = self.request("GET", f"/strikes/projects/runs/{run!s}")
        return Run(**response.json())

    def get_run_tasks(self, run: str | ULID) -> list[Task]:
        """Return the tasks executed within a run."""
        response = self.request("GET", f"/strikes/projects/runs/{run!s}/tasks")
        return [Task(**task) for task in response.json()]

    def get_run_trace(self, run: str | ULID) -> list[Task | TraceSpan]:
        """Return the full span trace of a run as a mix of tasks and raw spans."""
        response = self.request("GET", f"/strikes/projects/runs/{run!s}/spans")
        spans: list[Task | TraceSpan] = []
        for item in response.json():
            # Task payloads are distinguished by the presence of
            # "parent_task_span_id"; everything else is a generic span.
            if "parent_task_span_id" in item:
                spans.append(Task(**item))
            else:
                spans.append(TraceSpan(**item))
        return spans

    # Data exports

    def export_runs(
        self,
        project: str,
        *,
        filter: str | None = None,
        # format: ExportFormat = "parquet",
        status: StatusFilter = "completed",
        aggregations: list[MetricAggregationType] | None = None,
    ) -> pd.DataFrame:
        """Export run-level data for a project as a DataFrame (parquet transport)."""
        response = self.request(
            "GET",
            f"/strikes/projects/{project!s}/export",
            params={
                "format": "parquet",
                "status": status,
                # filter/aggregations are only sent when provided.
                **({"filter": filter} if filter else {}),
                **({"aggregations": aggregations} if aggregations else {}),
            },
        )
        return pd.read_parquet(io.BytesIO(response.content))

    def export_metrics(
        self,
        project: str,
        *,
        filter: str | None = None,
        # format: ExportFormat = "parquet",
        status: StatusFilter = "completed",
        metrics: list[str] | None = None,
        aggregations: list[MetricAggregationType] | None = None,
    ) -> pd.DataFrame:
        """Export metric data for a project as a DataFrame (parquet transport)."""
        response = self.request(
            "GET",
            f"/strikes/projects/{project!s}/export/metrics",
            params={
                "format": "parquet",
                "status": status,
                # NOTE(review): unlike export_runs, "filter" is sent even
                # when None — confirm the endpoint treats an empty filter
                # value as "no filter".
                "filter": filter,
                **({"metrics": metrics} if metrics else {}),
                **({"aggregations": aggregations} if aggregations else {}),
            },
        )
        return pd.read_parquet(io.BytesIO(response.content))

    def export_parameters(
        self,
        project: str,
        *,
        filter: str | None = None,
        # format: ExportFormat = "parquet",
        status: StatusFilter = "completed",
        parameters: list[str] | None = None,
        metrics: list[str] | None = None,
        aggregations: list[MetricAggregationType] | None = None,
    ) -> pd.DataFrame:
        """Export parameter data for a project as a DataFrame (parquet transport)."""
        response = self.request(
            "GET",
            f"/strikes/projects/{project!s}/export/parameters",
            params={
                "format": "parquet",
                "status": status,
                # NOTE(review): "filter" sent unconditionally here too — see
                # export_metrics.
                "filter": filter,
                **({"parameters": parameters} if parameters else {}),
                **({"metrics": metrics} if metrics else {}),
                **({"aggregations": aggregations} if aggregations else {}),
            },
        )
        return pd.read_parquet(io.BytesIO(response.content))

    def export_timeseries(
        self,
        project: str,
        *,
        filter: str | None = None,
        # format: ExportFormat = "parquet",
        status: StatusFilter = "completed",
        metrics: list[str] | None = None,
        time_axis: TimeAxisType = "relative",
        aggregations: list[TimeAggregationType] | None = None,
    ) -> pd.DataFrame:
        """Export time-series metric data for a project as a DataFrame."""
        response = self.request(
            "GET",
            f"/strikes/projects/{project!s}/export/timeseries",
            params={
                "format": "parquet",
                "status": status,
                "filter": filter,
                "time_axis": time_axis,
                **({"metrics": metrics} if metrics else {}),
                # NOTE(review): key is singular "aggregation" here but plural
                # "aggregations" in the other export methods — confirm which
                # name the endpoint actually expects.
                **({"aggregation": aggregations} if aggregations else {}),
            },
        )
        return pd.read_parquet(io.BytesIO(response.content))

    # User data access

    def get_user_data_credentials(self) -> UserDataCredentials:
        """Fetch temporary credentials for the user-data bucket."""
        response = self.request("GET", "/user-data/credentials")
        return UserDataCredentials(**response.json())
@@ -0,0 +1,210 @@
1
+ import typing as t
2
+ from datetime import datetime
3
+ from uuid import UUID
4
+
5
+ from pydantic import BaseModel, Field
6
+ from ulid import ULID
7
+
8
# Convenience alias for arbitrary JSON-like attribute mappings.
AnyDict = dict[str, t.Any]
9
+
10
+ # User
11
+
12
+
13
class UserAPIKey(BaseModel):
    """API key attached to a user account."""

    key: str
+
16
+
17
class UserResponse(BaseModel):
    """User account details as returned by the API."""

    id: UUID
    email_address: str
    username: str
    api_key: UserAPIKey
+
23
+
24
+ # Strikes
25
+
26
# Lifecycle status reported for a span.
SpanStatus = t.Literal[
    "pending",  # A pending span has been created
    "completed",  # The span has been finished
    "failed",  # The span raised an exception
]

# Serialization formats supported by the export endpoints.
ExportFormat = t.Literal["csv", "json", "jsonl", "parquet"]
# Run-status filter accepted by the export endpoints.
StatusFilter = t.Literal["all", "completed", "failed"]
# X-axis options for time-series exports.
TimeAxisType = t.Literal["wall", "relative", "step"]
# Aggregations applied over time buckets in time-series exports.
TimeAggregationType = t.Literal["max", "min", "sum", "count"]
# Aggregations applied over metric values in the other exports.
MetricAggregationType = t.Literal[
    "avg",
    "median",
    "min",
    "max",
    "sum",
    "first",
    "last",
    "count",
    "std",
    "var",
]
48
+
49
+
50
class SpanException(BaseModel):
    """Exception captured on a failed span."""

    type: str
    message: str
    stacktrace: str
+
55
+
56
class SpanEvent(BaseModel):
    """Point-in-time event recorded within a span."""

    timestamp: datetime
    name: str
    attributes: AnyDict
+
61
+
62
class SpanLink(BaseModel):
    """Reference from one span to another (possibly in a different trace)."""

    trace_id: str
    span_id: str
    attributes: AnyDict
+
67
+
68
class TraceLog(BaseModel):
    """Log record correlated with a trace/span, if any."""

    timestamp: datetime
    body: str
    severity: str
    service: str | None
    trace_id: str | None  # None when the log is not tied to a trace
    span_id: str | None
    attributes: AnyDict
    container: str | None
+
78
+
79
class TraceSpan(BaseModel):
    """Generic trace span with its events and links."""

    timestamp: datetime
    duration: int  # span duration (time unit not specified here — TODO confirm)
    trace_id: str
    span_id: str
    parent_span_id: str | None  # None for a root span
    service_name: str | None
    status: SpanStatus
    exception: SpanException | None  # populated only for failed spans
    name: str
    attributes: AnyDict
    resource_attributes: AnyDict
    events: list[SpanEvent]
    links: list[SpanLink]
+
94
+
95
class Metric(BaseModel):
    """Single metric sample at a given step/time."""

    value: float
    step: int
    timestamp: datetime
    attributes: AnyDict
+
101
+
102
class ObjectRef(BaseModel):
    """Lightweight reference to a logged object by content hash."""

    name: str
    label: str
    hash: str
+
107
+
108
class ObjectUri(BaseModel):
    """Object stored externally and referenced by URI."""

    hash: str
    schema_hash: str
    uri: str
    size: int  # presumably size in bytes — TODO confirm
    type: t.Literal["uri"]  # discriminator for the Object union
+
115
+
116
class ObjectVal(BaseModel):
    """Object stored inline with its value."""

    hash: str
    schema_hash: str
    value: t.Any
    type: t.Literal["val"]  # discriminator for the Object union
+
122
+
123
# Discriminated union over the "type" field ("uri" vs "val").
Object = ObjectUri | ObjectVal
124
+
125
+
126
class V0Object(BaseModel):
    """Legacy (v0) inline object shape, kept for backward compatibility."""

    name: str
    label: str
    value: t.Any
+
131
+
132
class Run(BaseModel):
    """Top-level run with its parameters, metrics, and logged objects."""

    id: ULID
    name: str
    span_id: str
    trace_id: str
    timestamp: datetime
    duration: int  # run duration (time unit not specified here — TODO confirm)
    status: SpanStatus
    exception: SpanException | None  # populated only for failed runs
    tags: set[str]
    params: AnyDict
    metrics: dict[str, list[Metric]]  # metric name -> samples
    inputs: list[ObjectRef]
    outputs: list[ObjectRef]
    objects: dict[str, Object]  # object hash -> object payload/reference
    object_schemas: AnyDict
    # Trailing underscore avoids clashing with BaseModel.schema;
    # serialized as "schema" via the alias.
    schema_: AnyDict = Field(alias="schema")
+
150
+
151
class Task(BaseModel):
    """Task span executed within a run, with full span details."""

    name: str
    span_id: str
    trace_id: str
    parent_span_id: str | None
    parent_task_span_id: str | None  # None for a top-level task
    timestamp: datetime
    duration: int  # task duration (time unit not specified here — TODO confirm)
    status: SpanStatus
    exception: SpanException | None
    tags: set[str]
    params: AnyDict
    metrics: dict[str, list[Metric]]  # metric name -> samples
    inputs: list[ObjectRef] | list[V0Object]  # v0 compat
    outputs: list[ObjectRef] | list[V0Object]  # v0 compat
    # Serialized as "schema"; underscore avoids shadowing BaseModel.schema.
    schema_: AnyDict = Field(alias="schema")
    attributes: AnyDict
    resource_attributes: AnyDict
    events: list[SpanEvent]
    links: list[SpanLink]
+
172
+
173
class Project(BaseModel):
    """Project metadata with run statistics."""

    id: UUID
    key: str
    name: str
    description: str | None
    created_at: datetime
    updated_at: datetime
    run_count: int
    last_run: Run | None  # None when the project has no runs yet
+
183
+
184
+ # Derived types
185
+
186
+
187
class TaskTree(BaseModel):
    """Tree of a task and its child tasks."""

    task: Task
    # Explicit factory instead of a `= []` default: pydantic copies mutable
    # defaults, but default_factory makes the per-instance list unambiguous
    # and avoids the mutable-default smell.
    children: list["TaskTree"] = Field(default_factory=list)
+
191
+
192
class SpanTree(BaseModel):
    """Tree representation of a trace span with its children"""

    span: Task | TraceSpan
    # Explicit factory instead of a `= []` default: pydantic copies mutable
    # defaults, but default_factory makes the per-instance list unambiguous
    # and avoids the mutable-default smell.
    children: list["SpanTree"] = Field(default_factory=list)
+
198
+
199
+ # User data credentials
200
+
201
+
202
class UserDataCredentials(BaseModel):
    """Temporary credentials for the user-data object store.

    Field names mirror S3-style temporary credentials — presumably consumed
    by an S3-compatible client; confirm against the caller.
    """

    access_key_id: str
    secret_access_key: str
    session_token: str
    expiration: datetime  # when these credentials stop being valid
    region: str
    bucket: str
    prefix: str  # key prefix the credentials are scoped to
    endpoint: str | None  # None implies the provider's default endpoint
File without changes