dreadnode 1.12.1__tar.gz → 1.13.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. {dreadnode-1.12.1 → dreadnode-1.13.0}/PKG-INFO +2 -1
  2. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/__init__.py +7 -2
  3. dreadnode-1.13.0/dreadnode/__main__.py +10 -0
  4. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/api/client.py +113 -26
  5. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/api/models.py +34 -10
  6. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/api/util.py +1 -1
  7. dreadnode-1.13.0/dreadnode/cli/__init__.py +3 -0
  8. dreadnode-1.13.0/dreadnode/cli/api.py +79 -0
  9. dreadnode-1.13.0/dreadnode/cli/github.py +273 -0
  10. dreadnode-1.13.0/dreadnode/cli/main.py +200 -0
  11. dreadnode-1.13.0/dreadnode/cli/profile/__init__.py +3 -0
  12. dreadnode-1.13.0/dreadnode/cli/profile/cli.py +101 -0
  13. dreadnode-1.13.0/dreadnode/config.py +108 -0
  14. dreadnode-1.13.0/dreadnode/constants.py +57 -0
  15. dreadnode-1.13.0/dreadnode/data_types/__init__.py +9 -0
  16. dreadnode-1.13.0/dreadnode/lookup.py +146 -0
  17. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/main.py +87 -102
  18. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/object.py +3 -1
  19. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/__init__.py +27 -3
  20. dreadnode-1.13.0/dreadnode/scorers/classification.py +102 -0
  21. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/consistency.py +16 -18
  22. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/contains.py +50 -31
  23. dreadnode-1.13.0/dreadnode/scorers/format.py +61 -0
  24. dreadnode-1.13.0/dreadnode/scorers/harm.py +55 -0
  25. dreadnode-1.12.1/dreadnode/scorers/llm_judge.py → dreadnode-1.13.0/dreadnode/scorers/judge.py +18 -20
  26. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/length.py +28 -23
  27. dreadnode-1.13.0/dreadnode/scorers/lexical.py +67 -0
  28. dreadnode-1.13.0/dreadnode/scorers/operators.py +122 -0
  29. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/pii.py +1 -2
  30. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/readability.py +6 -2
  31. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/sentiment.py +14 -2
  32. dreadnode-1.13.0/dreadnode/scorers/similarity.py +293 -0
  33. dreadnode-1.13.0/dreadnode/scorers/util.py +33 -0
  34. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/task.py +0 -84
  35. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/tracing/span.py +6 -7
  36. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/util.py +119 -0
  37. {dreadnode-1.12.1 → dreadnode-1.13.0}/pyproject.toml +13 -2
  38. dreadnode-1.12.1/dreadnode/constants.py +0 -16
  39. dreadnode-1.12.1/dreadnode/data_types/__init__.py +0 -9
  40. dreadnode-1.12.1/dreadnode/scorers/similarity.py +0 -180
  41. {dreadnode-1.12.1 → dreadnode-1.13.0}/README.md +0 -0
  42. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/api/__init__.py +0 -0
  43. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/artifact/__init__.py +0 -0
  44. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/artifact/merger.py +0 -0
  45. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/artifact/storage.py +0 -0
  46. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/artifact/tree_builder.py +0 -0
  47. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/convert.py +0 -0
  48. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/data_types/audio.py +0 -0
  49. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/data_types/base.py +0 -0
  50. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/data_types/image.py +0 -0
  51. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/data_types/object_3d.py +0 -0
  52. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/data_types/table.py +0 -0
  53. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/data_types/text.py +0 -0
  54. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/data_types/video.py +0 -0
  55. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/integrations/__init__.py +0 -0
  56. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/integrations/transformers.py +0 -0
  57. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/metric.py +0 -0
  58. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/py.typed +0 -0
  59. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/rigging.py +0 -0
  60. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/serialization.py +0 -0
  61. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/tracing/__init__.py +0 -0
  62. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/tracing/constants.py +0 -0
  63. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/tracing/exporters.py +0 -0
  64. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/types.py +0 -0
  65. {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/version.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: dreadnode
3
- Version: 1.12.1
3
+ Version: 1.13.0
4
4
  Summary: Dreadnode SDK
5
5
  Author: Nick Landers
6
6
  Author-email: monoxgas@gmail.com
@@ -14,6 +14,7 @@ Provides-Extra: all
14
14
  Provides-Extra: multimodal
15
15
  Provides-Extra: training
16
16
  Requires-Dist: coolname (>=2.2.0,<3.0.0)
17
+ Requires-Dist: cyclopts (>=3.22.2,<4.0.0)
17
18
  Requires-Dist: fsspec[s3] (>=2023.1.0,<=2025.3.0)
18
19
  Requires-Dist: httpx (>=0.28.0,<0.29.0)
19
20
  Requires-Dist: logfire (>=3.5.3,<=3.20.0)
@@ -1,9 +1,10 @@
1
1
  from dreadnode import convert, data_types, scorers
2
2
  from dreadnode.data_types import Audio, Code, Image, Markdown, Object3D, Table, Text, Video
3
+ from dreadnode.lookup import Lookup, lookup_input, lookup_output, lookup_param, resolve_lookup
3
4
  from dreadnode.main import DEFAULT_INSTANCE, Dreadnode
4
5
  from dreadnode.metric import Metric, MetricDict, Scorer
5
6
  from dreadnode.object import Object
6
- from dreadnode.task import Task, TaskInput
7
+ from dreadnode.task import Task
7
8
  from dreadnode.tracing.span import RunSpan, Span, TaskSpan
8
9
  from dreadnode.version import VERSION
9
10
 
@@ -39,6 +40,7 @@ __all__ = [
39
40
  "Code",
40
41
  "Dreadnode",
41
42
  "Image",
43
+ "Lookup",
42
44
  "Markdown",
43
45
  "Metric",
44
46
  "MetricDict",
@@ -50,7 +52,6 @@ __all__ = [
50
52
  "Span",
51
53
  "Table",
52
54
  "Task",
53
- "TaskInput",
54
55
  "TaskSpan",
55
56
  "Text",
56
57
  "Video",
@@ -69,7 +70,11 @@ __all__ = [
69
70
  "log_output",
70
71
  "log_param",
71
72
  "log_params",
73
+ "lookup_input",
74
+ "lookup_output",
75
+ "lookup_param",
72
76
  "push_update",
77
+ "resolve_lookup",
73
78
  "run",
74
79
  "scorer",
75
80
  "scorers",
@@ -0,0 +1,10 @@
1
+ from dreadnode.cli import cli
2
+
3
+
4
+ def run() -> None:
5
+ """Run the Dreadnode CLI."""
6
+ cli.meta()
7
+
8
+
9
+ if __name__ == "__main__":
10
+ run()
@@ -1,22 +1,19 @@
1
1
  import io
2
2
  import json
3
+ import time
3
4
  import typing as t
5
+ from datetime import datetime, timezone
6
+ from urllib.parse import urlparse
4
7
 
5
8
  import httpx
6
9
  import pandas as pd
7
10
  from pydantic import BaseModel
8
11
  from ulid import ULID
9
12
 
10
- from dreadnode.api.util import (
11
- convert_flat_tasks_to_tree,
12
- convert_flat_trace_to_tree,
13
- process_run,
14
- process_task,
15
- )
16
- from dreadnode.util import logger
17
- from dreadnode.version import VERSION
18
-
19
- from .models import (
13
+ from dreadnode.api.models import (
14
+ AccessRefreshTokenResponse,
15
+ DeviceCodeResponse,
16
+ GithubTokenResponse,
20
17
  MetricAggregationType,
21
18
  Project,
22
19
  RawRun,
@@ -31,7 +28,17 @@ from .models import (
31
28
  TraceSpan,
32
29
  TraceTree,
33
30
  UserDataCredentials,
31
+ UserResponse,
34
32
  )
33
+ from dreadnode.api.util import (
34
+ convert_flat_tasks_to_tree,
35
+ convert_flat_trace_to_tree,
36
+ process_run,
37
+ process_task,
38
+ )
39
+ from dreadnode.constants import DEFAULT_MAX_POLL_TIME, DEFAULT_POLL_INTERVAL
40
+ from dreadnode.util import logger
41
+ from dreadnode.version import VERSION
35
42
 
36
43
  ModelT = t.TypeVar("ModelT", bound=BaseModel)
37
44
 
@@ -47,11 +54,13 @@ class ApiClient:
47
54
  def __init__(
48
55
  self,
49
56
  base_url: str,
50
- api_key: str,
51
57
  *,
58
+ api_key: str | None = None,
59
+ cookies: dict[str, str] | None = None,
52
60
  debug: bool = False,
53
61
  ):
54
- """Initializes the API client.
62
+ """
63
+ Initializes the API client.
55
64
 
56
65
  Args:
57
66
  base_url (str): The base URL of the Dreadnode API.
@@ -62,12 +71,28 @@ class ApiClient:
62
71
  if not self._base_url.endswith("/api"):
63
72
  self._base_url += "/api"
64
73
 
74
+ _cookies = httpx.Cookies()
75
+ cookie_domain = urlparse(base_url).hostname
76
+ if cookie_domain is None:
77
+ raise ValueError(f"Invalid URL: {base_url}")
78
+
79
+ if cookie_domain == "localhost":
80
+ cookie_domain = "localhost.local"
81
+
82
+ for key, value in (cookies or {}).items():
83
+ _cookies.set(key, value, domain=cookie_domain)
84
+
85
+ headers = {
86
+ "User-Agent": f"dreadnode-sdk/{VERSION}",
87
+ "Accept": "application/json",
88
+ }
89
+
90
+ if api_key:
91
+ headers["X-Api-Key"] = api_key
92
+
65
93
  self._client = httpx.Client(
66
- headers={
67
- "User-Agent": f"dreadnode-sdk/{VERSION}",
68
- "Accept": "application/json",
69
- "X-API-Key": api_key,
70
- },
94
+ headers=headers,
95
+ cookies=_cookies,
71
96
  base_url=self._base_url,
72
97
  timeout=30,
73
98
  )
@@ -77,7 +102,8 @@ class ApiClient:
77
102
  self._client.event_hooks["response"].append(self._log_response)
78
103
 
79
104
  def _log_request(self, request: httpx.Request) -> None:
80
- """Logs HTTP requests if debug mode is enabled.
105
+ """
106
+ Logs HTTP requests if debug mode is enabled.
81
107
 
82
108
  Args:
83
109
  request (httpx.Request): The HTTP request object.
@@ -90,7 +116,8 @@ class ApiClient:
90
116
  logger.debug("-------------------------------------------")
91
117
 
92
118
  def _log_response(self, response: httpx.Response) -> None:
93
- """Logs HTTP responses if debug mode is enabled.
119
+ """
120
+ Logs HTTP responses if debug mode is enabled.
94
121
 
95
122
  Args:
96
123
  response (httpx.Response): The HTTP response object.
@@ -103,7 +130,8 @@ class ApiClient:
103
130
  logger.debug("--------------------------------------------")
104
131
 
105
132
  def _get_error_message(self, response: httpx.Response) -> str:
106
- """Extracts the error message from an HTTP response.
133
+ """
134
+ Extracts the error message from an HTTP response.
107
135
 
108
136
  Args:
109
137
  response (httpx.Response): The HTTP response object.
@@ -125,7 +153,8 @@ class ApiClient:
125
153
  params: dict[str, t.Any] | None = None,
126
154
  json_data: dict[str, t.Any] | None = None,
127
155
  ) -> httpx.Response:
128
- """Makes a raw HTTP request to the API.
156
+ """
157
+ Makes a raw HTTP request to the API.
129
158
 
130
159
  Args:
131
160
  method (str): The HTTP method (e.g., "GET", "POST").
@@ -146,7 +175,8 @@ class ApiClient:
146
175
  params: dict[str, t.Any] | None = None,
147
176
  json_data: dict[str, t.Any] | None = None,
148
177
  ) -> httpx.Response:
149
- """Makes an HTTP request to the API and raises exceptions for errors.
178
+ """
179
+ Makes an HTTP request to the API and raises exceptions for errors.
150
180
 
151
181
  Args:
152
182
  method (str): The HTTP method (e.g., "GET", "POST").
@@ -170,6 +200,59 @@ class ApiClient:
170
200
 
171
201
  return response
172
202
 
203
+ # Auth
204
+
205
+ def url_for_user_code(self, user_code: str) -> str:
206
+ """Get the URL to verify the user code."""
207
+
208
+ return f"{self._base_url.removesuffix('/api')}/account/device?code={user_code}"
209
+
210
+ def get_device_codes(self) -> DeviceCodeResponse:
211
+ """Start the authentication flow by requesting user and device codes."""
212
+
213
+ response = self.request("POST", "/auth/device/code")
214
+ return DeviceCodeResponse(**response.json())
215
+
216
+ def poll_for_token(
217
+ self,
218
+ device_code: str,
219
+ interval: int = DEFAULT_POLL_INTERVAL,
220
+ max_poll_time: int = DEFAULT_MAX_POLL_TIME,
221
+ ) -> AccessRefreshTokenResponse:
222
+ """Poll for the access token with the given device code."""
223
+
224
+ start_time = datetime.now(timezone.utc)
225
+ while (datetime.now(timezone.utc) - start_time).total_seconds() < max_poll_time:
226
+ response = self._request(
227
+ "POST", "/auth/device/token", json_data={"device_code": device_code}
228
+ )
229
+
230
+ if response.status_code == 200: # noqa: PLR2004
231
+ return AccessRefreshTokenResponse(**response.json())
232
+ if response.status_code != 401: # noqa: PLR2004
233
+ raise RuntimeError(self._get_error_message(response))
234
+
235
+ time.sleep(interval)
236
+
237
+ raise RuntimeError("Polling for token timed out")
238
+
239
+ # User
240
+
241
+ def get_user(self) -> UserResponse:
242
+ """Get the user email and username."""
243
+
244
+ response = self.request("GET", "/user")
245
+ return UserResponse(**response.json())
246
+
247
+ # Github
248
+
249
+ def get_github_access_token(self, repos: list[str]) -> GithubTokenResponse:
250
+ """Try to get a GitHub access token for the given repositories."""
251
+ response = self.request("POST", "/github/token", json_data={"repos": repos})
252
+ return GithubTokenResponse(**response.json())
253
+
254
+ # Strikes
255
+
173
256
  def list_projects(self) -> list[Project]:
174
257
  """Retrieves a list of projects.
175
258
 
@@ -294,7 +377,8 @@ class ApiClient:
294
377
  status: StatusFilter = "completed",
295
378
  aggregations: list[MetricAggregationType] | None = None,
296
379
  ) -> pd.DataFrame:
297
- """Exports run data for a specific project.
380
+ """
381
+ Exports run data for a specific project.
298
382
 
299
383
  Args:
300
384
  project: The project identifier.
@@ -327,7 +411,8 @@ class ApiClient:
327
411
  metrics: list[str] | None = None,
328
412
  aggregations: list[MetricAggregationType] | None = None,
329
413
  ) -> pd.DataFrame:
330
- """Exports metric data for a specific project.
414
+ """
415
+ Exports metric data for a specific project.
331
416
 
332
417
  Args:
333
418
  project: The project identifier.
@@ -363,7 +448,8 @@ class ApiClient:
363
448
  metrics: list[str] | None = None,
364
449
  aggregations: list[MetricAggregationType] | None = None,
365
450
  ) -> pd.DataFrame:
366
- """Exports parameter data for a specific project.
451
+ """
452
+ Exports parameter data for a specific project.
367
453
 
368
454
  Args:
369
455
  project: The project identifier.
@@ -401,7 +487,8 @@ class ApiClient:
401
487
  time_axis: TimeAxisType = "relative",
402
488
  aggregations: list[TimeAggregationType] | None = None,
403
489
  ) -> pd.DataFrame:
404
- """Exports timeseries data for a specific project.
490
+ """
491
+ Exports timeseries data for a specific project.
405
492
 
406
493
  Args:
407
494
  project: The project identifier.
@@ -32,6 +32,35 @@ class UserResponse(BaseModel):
32
32
  api_key: UserAPIKey
33
33
 
34
34
 
35
+ class UserDataCredentials(BaseModel):
36
+ access_key_id: str
37
+ secret_access_key: str
38
+ session_token: str
39
+ expiration: datetime
40
+ region: str
41
+ bucket: str
42
+ prefix: str
43
+ endpoint: str | None
44
+
45
+
46
+ # Auth
47
+
48
+
49
+ class DeviceCodeResponse(BaseModel):
50
+ id: UUID
51
+ completed: bool
52
+ device_code: str
53
+ expires_at: datetime
54
+ expires_in: int
55
+ user_code: str
56
+ verification_url: str
57
+
58
+
59
+ class AccessRefreshTokenResponse(BaseModel):
60
+ access_token: str
61
+ refresh_token: str
62
+
63
+
35
64
  # Strikes
36
65
 
37
66
  SpanStatus = t.Literal[
@@ -406,15 +435,10 @@ class TraceTree(BaseModel):
406
435
  """Children of this span, representing nested spans or tasks."""
407
436
 
408
437
 
409
- # User data credentials
438
+ # Github
410
439
 
411
440
 
412
- class UserDataCredentials(BaseModel):
413
- access_key_id: str
414
- secret_access_key: str
415
- session_token: str
416
- expiration: datetime
417
- region: str
418
- bucket: str
419
- prefix: str
420
- endpoint: str | None
441
+ class GithubTokenResponse(BaseModel):
442
+ token: str
443
+ expires_at: datetime
444
+ repos: list[str]
@@ -1,6 +1,6 @@
1
1
  from logging import getLogger
2
2
 
3
- from .models import (
3
+ from dreadnode.api.models import (
4
4
  Object,
5
5
  ObjectUri,
6
6
  ObjectVal,
@@ -0,0 +1,3 @@
1
+ from dreadnode.cli.main import cli
2
+
3
+ __all__ = ["cli"]
@@ -0,0 +1,79 @@
1
+ import atexit
2
+ import base64
3
+ import json
4
+ from datetime import datetime, timezone
5
+
6
+ from dreadnode.api.client import ApiClient
7
+ from dreadnode.config import UserConfig
8
+ from dreadnode.constants import (
9
+ DEFAULT_TOKEN_MAX_TTL,
10
+ )
11
+
12
+
13
+ class Token:
14
+ """A JWT token with an expiration time."""
15
+
16
+ data: str
17
+ expires_at: datetime
18
+
19
+ @staticmethod
20
+ def parse_jwt_token_expiration(token: str) -> datetime:
21
+ """Return the expiration date from a JWT token."""
22
+
23
+ _, b64payload, _ = token.split(".")
24
+ payload = base64.urlsafe_b64decode(b64payload + "==").decode("utf-8")
25
+ return datetime.fromtimestamp(json.loads(payload).get("exp"), tz=timezone.utc)
26
+
27
+ def __init__(self, token: str):
28
+ self.data = token
29
+ self.expires_at = Token.parse_jwt_token_expiration(token)
30
+
31
+ def ttl(self) -> int:
32
+ """Get number of seconds left until the token expires."""
33
+ return int((self.expires_at - datetime.now(tz=timezone.utc)).total_seconds())
34
+
35
+ def is_expired(self) -> bool:
36
+ """Return True if the token is expired."""
37
+ return self.ttl() <= 0
38
+
39
+ def is_close_to_expiry(self) -> bool:
40
+ """Return True if the token is close to expiry."""
41
+ return self.ttl() <= DEFAULT_TOKEN_MAX_TTL
42
+
43
+
44
+ def create_api_client(*, profile: str | None = None) -> ApiClient:
45
+ """Create an authenticated API client using stored configuration data."""
46
+
47
+ user_config = UserConfig.read()
48
+ config = user_config.get_server_config(profile)
49
+
50
+ client = ApiClient(
51
+ config.url,
52
+ cookies={"access_token": config.access_token, "refresh_token": config.refresh_token},
53
+ )
54
+
55
+ # Preemptively check if the token is expired
56
+ if Token(config.refresh_token).is_expired():
57
+ raise RuntimeError("Authentication expired, use [bold]dreadnode login[/]")
58
+
59
+ def _flush_auth_changes() -> None:
60
+ """Flush the authentication data to disk if it has been updated."""
61
+
62
+ access_token = client._client.cookies.get("access_token") # noqa: SLF001
63
+ refresh_token = client._client.cookies.get("refresh_token") # noqa: SLF001
64
+
65
+ changed: bool = False
66
+ if access_token and access_token != config.access_token:
67
+ changed = True
68
+ config.access_token = access_token
69
+
70
+ if refresh_token and refresh_token != config.refresh_token:
71
+ changed = True
72
+ config.refresh_token = refresh_token
73
+
74
+ if changed:
75
+ user_config.set_server_config(config, profile).write()
76
+
77
+ atexit.register(_flush_auth_changes)
78
+
79
+ return client