dreadnode 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dreadnode/main.py ADDED
@@ -0,0 +1,1042 @@
1
+ import contextlib
2
+ import inspect
3
+ import os
4
+ import random
5
+ import re
6
+ import typing as t
7
+ from dataclasses import dataclass
8
+ from datetime import datetime, timezone
9
+ from pathlib import Path
10
+ from urllib.parse import urljoin
11
+
12
+ import coolname # type: ignore [import-untyped]
13
+ import logfire
14
+ from fsspec.implementations.local import ( # type: ignore [import-untyped]
15
+ LocalFileSystem,
16
+ )
17
+ from logfire._internal.exporters.remove_pending import RemovePendingSpansExporter
18
+ from logfire._internal.stack_info import get_filepath_attribute, warn_at_user_stacklevel
19
+ from logfire._internal.utils import safe_repr
20
+ from opentelemetry.exporter.otlp.proto.http import Compression
21
+ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
22
+ from opentelemetry.sdk.trace.export import BatchSpanProcessor
23
+ from s3fs import S3FileSystem # type: ignore [import-untyped]
24
+
25
+ from dreadnode.api.client import ApiClient
26
+ from dreadnode.constants import (
27
+ DEFAULT_SERVER_URL,
28
+ ENV_API_KEY,
29
+ ENV_API_TOKEN,
30
+ ENV_LOCAL_DIR,
31
+ ENV_PROJECT,
32
+ ENV_SERVER,
33
+ ENV_SERVER_URL,
34
+ )
35
+ from dreadnode.metric import Metric, MetricMode, Scorer, ScorerCallable, T
36
+ from dreadnode.task import P, R, Task
37
+ from dreadnode.tracing.exporters import (
38
+ FileExportConfig,
39
+ FileMetricReader,
40
+ FileSpanExporter,
41
+ )
42
+ from dreadnode.tracing.span import (
43
+ RunSpan,
44
+ Span,
45
+ TaskSpan,
46
+ current_run_span,
47
+ current_task_span,
48
+ )
49
+ from dreadnode.types import (
50
+ AnyDict,
51
+ JsonValue,
52
+ )
53
+ from dreadnode.util import handle_internal_errors
54
+ from dreadnode.version import VERSION
55
+
56
+ if t.TYPE_CHECKING:
57
+ from fsspec import AbstractFileSystem # type: ignore [import-untyped]
58
+ from opentelemetry.sdk.metrics.export import MetricReader
59
+ from opentelemetry.sdk.trace import SpanProcessor
60
+ from opentelemetry.trace import Tracer
61
+
62
+
63
# Target selector for log_* methods: "task-or-run" logs to the nearest
# enclosing task span (falling back to the run); "run" always logs to the run.
ToObject = t.Literal["task-or-run", "run"]
64
+
65
+
66
class DreadnodeConfigWarning(UserWarning):
    """Warning emitted when the SDK configuration is incomplete or won't persist data."""
68
+
69
+
70
class DreadnodeUsageWarning(UserWarning):
    """Warning emitted when the SDK is used in an unexpected or discouraged way."""
72
+
73
+
74
@dataclass
class Dreadnode:
    """
    The core Dreadnode SDK class.

    A default instance of this class is created and can be used directly with `dreadnode.*`.

    Otherwise, you can create your own instance and configure it with `configure()`.
    """

    # Dreadnode server URL (resolved from args/env in configure()).
    server: str | None
    # API token used to authenticate with the server.
    token: str | None
    # Local directory for file-based export, or False to disable local export.
    local_dir: str | Path | t.Literal[False]
    # Default project name to associate new runs with.
    project: str | None
    # OpenTelemetry service name.
    service_name: str | None
    # OpenTelemetry service version.
    service_version: str | None
    # Console span logging: options object, True for defaults, False to disable.
    console: logfire.ConsoleOptions | t.Literal[False, True]
    # Whether to forward data to Logfire.
    send_to_logfire: bool | t.Literal["if-token-present"]
    # OpenTelemetry instrumentation scope name.
    otel_scope: str
93
+
94
    def __init__(
        self,
        *,
        server: str | None = None,
        token: str | None = None,
        local_dir: str | Path | t.Literal[False] = False,
        project: str | None = None,
        service_name: str | None = None,
        service_version: str | None = None,
        console: logfire.ConsoleOptions | t.Literal[False, True] = True,
        send_to_logfire: bool | t.Literal["if-token-present"] = "if-token-present",
        otel_scope: str = "dreadnode",
    ) -> None:
        """
        Create an unconfigured SDK instance.

        See `configure()` for the meaning of each argument; no network or
        telemetry setup happens until `initialize()` runs.
        """
        self.server = server
        self.token = token
        self.local_dir = local_dir
        self.project = project
        self.service_name = service_name
        self.service_version = service_version
        self.console = console
        self.send_to_logfire = send_to_logfire
        self.otel_scope = otel_scope

        # Created lazily in initialize() when a token is available.
        self._api: ApiClient | None = None

        # Reuse logfire's global instance until initialize() reconfigures it;
        # suppress logfire's "not configured" warnings in the meantime.
        self._logfire = logfire.DEFAULT_LOGFIRE_INSTANCE
        self._logfire.config.ignore_no_config = True

        # Object storage defaults to the local filesystem; swapped for S3 in
        # initialize() when server credentials are obtained.
        self._fs: AbstractFileSystem = LocalFileSystem(auto_mkdir=True)
        self._fs_prefix: str = ".dreadnode/storage/"

        self._initialized = False
126
+
127
    def configure(
        self,
        *,
        server: str | None = None,
        token: str | None = None,
        local_dir: str | Path | t.Literal[False] = False,
        project: str | None = None,
        service_name: str | None = None,
        service_version: str | None = None,
        console: logfire.ConsoleOptions | t.Literal[False, True] = True,
        send_to_logfire: bool | t.Literal["if-token-present"] = "if-token-present",
        otel_scope: str = "dreadnode",
    ) -> None:
        """
        Configure the Dreadnode SDK and call `initialize()`.

        This method should always be called before using the SDK.

        If `server` and `token` are not provided, the SDK will look in
        the associated environment variables:

        - `DREADNODE_SERVER_URL` or `DREADNODE_SERVER`
        - `DREADNODE_API_TOKEN` or `DREADNODE_API_KEY`

        Args:
            server: The Dreadnode server URL.
            token: The Dreadnode API token.
            local_dir: The local directory to store data in.
            project: The default project name to associate all runs with.
            service_name: The service name to use for OpenTelemetry.
            service_version: The service version to use for OpenTelemetry.
            console: Whether to log span information to the console.
            send_to_logfire: Whether to send data to Logfire.
            otel_scope: The OpenTelemetry scope name.
        """

        # Reset so initialize() below performs a full (re)configuration.
        self._initialized = False

        self.server = server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER)
        self.token = token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY)

        # Fall back to the env var only when no explicit local_dir was passed.
        if local_dir is False and ENV_LOCAL_DIR in os.environ:
            env_local_dir = os.environ.get(ENV_LOCAL_DIR)
            if env_local_dir:
                self.local_dir = Path(env_local_dir)
            else:
                self.local_dir = False
        else:
            self.local_dir = local_dir

        self.project = project or os.environ.get(ENV_PROJECT)
        self.service_name = service_name
        self.service_version = service_version
        self.console = console
        self.send_to_logfire = send_to_logfire
        self.otel_scope = otel_scope

        self.initialize()
185
+
186
+ def initialize(self) -> None:
187
+ """
188
+ Initialize the Dreadnode SDK.
189
+
190
+ This method is called automatically when you call `configure()`.
191
+ """
192
+ if self._initialized:
193
+ return
194
+
195
+ span_processors: list[SpanProcessor] = []
196
+ metric_readers: list[MetricReader] = []
197
+
198
+ self.server = self.server or DEFAULT_SERVER_URL
199
+ if self.server is None and self.local_dir is False:
200
+ warn_at_user_stacklevel(
201
+ "Your current configuration won't persist run data anywhere. "
202
+ "Use `dreadnode.init(server=..., token=...)`, `dreadnode.init(local_dir=...)`, "
203
+ f"or use environment variables ({ENV_SERVER_URL}, {ENV_API_TOKEN}, {ENV_LOCAL_DIR}).",
204
+ category=DreadnodeConfigWarning,
205
+ )
206
+
207
+ if self.local_dir is not False:
208
+ config = FileExportConfig(
209
+ base_path=self.local_dir,
210
+ prefix=self.project + "-" if self.project else "",
211
+ )
212
+ span_processors.append(BatchSpanProcessor(FileSpanExporter(config)))
213
+ metric_readers.append(FileMetricReader(config))
214
+
215
+ if self.token is not None:
216
+ self._api = ApiClient(self.server, self.token)
217
+
218
+ try:
219
+ self._api.list_projects()
220
+ except Exception as e:
221
+ raise RuntimeError(
222
+ "Failed to authenticate with the provided server and token",
223
+ ) from e
224
+
225
+ headers = {"User-Agent": f"dreadnode/{VERSION}", "X-Api-Key": self.token}
226
+ span_processors.append(
227
+ BatchSpanProcessor(
228
+ RemovePendingSpansExporter( # This will tell Logfire to emit pending spans to us as well
229
+ OTLPSpanExporter(
230
+ endpoint=urljoin(self.server, "/api/otel/traces"),
231
+ headers=headers,
232
+ compression=Compression.Gzip,
233
+ ),
234
+ ),
235
+ ),
236
+ )
237
+ # TODO(nick): Metrics
238
+ # https://linear.app/dreadnode/issue/ENG-1310/sdk-add-metrics-exports
239
+ # metric_readers.append(
240
+ # PeriodicExportingMetricReader(
241
+ # OTLPMetricExporter(
242
+ # endpoint=urljoin(self.server, "/v1/metrics"),
243
+ # headers=headers,
244
+ # compression=Compression.Gzip,
245
+ # # preferred_temporality
246
+ # )
247
+ # )
248
+ # )
249
+
250
+ credentials = self._api.get_user_data_credentials()
251
+ self._fs = S3FileSystem(
252
+ key=credentials.access_key_id,
253
+ secret=credentials.secret_access_key,
254
+ token=credentials.session_token,
255
+ client_kwargs={
256
+ "endpoint_url": credentials.endpoint,
257
+ "region_name": credentials.region,
258
+ },
259
+ )
260
+ self._fs_prefix = f"{credentials.bucket}/{credentials.prefix}/"
261
+
262
+ self._logfire = logfire.configure(
263
+ local=not self.is_default,
264
+ send_to_logfire=self.send_to_logfire,
265
+ additional_span_processors=span_processors,
266
+ metrics=logfire.MetricsOptions(additional_readers=metric_readers),
267
+ service_name=self.service_name,
268
+ service_version=self.service_version,
269
+ console=logfire.ConsoleOptions() if self.console is True else self.console,
270
+ scrubbing=False,
271
+ )
272
+ self._logfire.config.ignore_no_config = True
273
+
274
+ self._initialized = True
275
+
276
    @property
    def is_default(self) -> bool:
        # True when this is the module-level DEFAULT_INSTANCE backing the
        # top-level `dreadnode.*` API (identity check, not equality).
        return self is DEFAULT_INSTANCE
279
+
280
+ def api(self, *, server: str | None = None, token: str | None = None) -> ApiClient:
281
+ """
282
+ Get an API client based on the current configuration or the provided server and token.
283
+
284
+ If the server and token are not provided, the method will use the current configuration
285
+ and `configure()` needs to be called first.
286
+
287
+ Args:
288
+ server: The server URL to use for the API client.
289
+ token: The API token to use for authentication.
290
+
291
+ Returns:
292
+ An ApiClient instance.
293
+ """
294
+ if server is not None and token is not None:
295
+ return ApiClient(server, token)
296
+
297
+ if not self._initialized:
298
+ raise RuntimeError("Call .configure() before accessing the API")
299
+
300
+ if self._api is None:
301
+ raise RuntimeError("API is not available without a server configuration")
302
+
303
+ return self._api
304
+
305
    def _get_tracer(self, *, is_span_tracer: bool = True) -> "Tracer":
        # Reach into logfire's private tracer provider so spans we create flow
        # through the same processors/exporters configured in initialize().
        return self._logfire._tracer_provider.get_tracer(  # noqa: SLF001
            self.otel_scope,
            VERSION,
            is_span_tracer=is_span_tracer,
        )
311
+
312
+ @handle_internal_errors()
313
+ def shutdown(self) -> None:
314
+ """
315
+ Shutdown any associate OpenTelemetry components and flush any pending spans.
316
+
317
+ It is not required to call this method, as the SDK will automatically
318
+ flush and shutdown when the process exits.
319
+
320
+ However, if you want to ensure that all spans are flushed before
321
+ exiting, you can call this method manually.
322
+ """
323
+ if not self._initialized:
324
+ return
325
+
326
+ self._logfire.shutdown()
327
+
328
+ def span(
329
+ self,
330
+ name: str,
331
+ *,
332
+ tags: t.Sequence[str] | None = None,
333
+ **attributes: t.Any,
334
+ ) -> Span:
335
+ """
336
+ Create a new OpenTelemety span.
337
+
338
+ Spans are more lightweight than tasks, but still let you track
339
+ work being performed and view it in the UI. You cannot
340
+ log parameters, inputs, or outputs to spans.
341
+
342
+ Example:
343
+ ```
344
+ with dreadnode.span("my_span") as span:
345
+ # do some work here
346
+ pass
347
+ ```
348
+
349
+ Args:
350
+ name: The name of the span.
351
+ tags: A list of tags to attach to the span.
352
+ **attributes: A dictionary of attributes to attach to the span.
353
+
354
+ Returns:
355
+ A Span object.
356
+ """
357
+ return Span(
358
+ name=name,
359
+ attributes=attributes,
360
+ tracer=self._get_tracer(),
361
+ tags=tags,
362
+ )
363
+
364
    # Some excessive typing here to ensure we can properly
    # overload our decorator for sync/async and cases
    # where we need the return type of the task to align
    # with the scorer inputs

    class TaskDecorator(t.Protocol):
        """Callable protocol returned by `task()` when no scorers are supplied."""

        @t.overload
        def __call__(
            self,
            func: t.Callable[P, t.Awaitable[R]],
        ) -> Task[P, R]: ...

        @t.overload
        def __call__(
            self,
            func: t.Callable[P, R],
        ) -> Task[P, R]: ...

        def __call__(
            self,
            func: t.Callable[P, t.Awaitable[R]] | t.Callable[P, R],
        ) -> Task[P, R]: ...
386
+
387
    class ScoredTaskDecorator(t.Protocol, t.Generic[R]):
        """
        Callable protocol returned by `task()` when scorers are supplied.

        Generic over R so the decorated function's return type must match
        the scorers' input type.
        """

        @t.overload
        def __call__(
            self,
            func: t.Callable[P, t.Awaitable[R]],
        ) -> Task[P, R]: ...

        @t.overload
        def __call__(
            self,
            func: t.Callable[P, R],
        ) -> Task[P, R]: ...

        def __call__(
            self,
            func: t.Callable[P, t.Awaitable[R]] | t.Callable[P, R],
        ) -> Task[P, R]: ...
404
+
405
    @t.overload
    def task(
        self,
        *,
        scorers: None = None,
        name: str | None = None,
        label: str | None = None,
        log_params: t.Sequence[str] | bool = False,
        log_inputs: t.Sequence[str] | bool = True,
        log_output: bool = True,
        tags: t.Sequence[str] | None = None,
        **attributes: t.Any,
    ) -> TaskDecorator: ...

    @t.overload
    def task(
        self,
        *,
        scorers: t.Sequence[Scorer[R] | ScorerCallable[R]],
        name: str | None = None,
        label: str | None = None,
        log_params: t.Sequence[str] | bool = False,
        log_inputs: t.Sequence[str] | bool = True,
        log_output: bool = True,
        tags: t.Sequence[str] | None = None,
        **attributes: t.Any,
    ) -> ScoredTaskDecorator[R]: ...

    def task(
        self,
        *,
        scorers: t.Sequence[Scorer[t.Any] | ScorerCallable[t.Any]] | None = None,
        name: str | None = None,
        label: str | None = None,
        log_params: t.Sequence[str] | bool = False,
        log_inputs: t.Sequence[str] | bool = True,
        log_output: bool = True,
        tags: t.Sequence[str] | None = None,
        **attributes: t.Any,
    ) -> TaskDecorator:
        """
        Create a new task from a function.

        Example:
            ```
            @dreadnode.task(name="my_task")
            async def my_task(x: int) -> int:
                return x * 2

            await my_task(2)
            ```

        Args:
            scorers: A list of scorers to attach to the task. These will be called after every execution
                of the task and will be passed the task's output.
            name: The name of the task.
            label: The label of the task - useful for filtering in the UI.
            log_params: Whether to log all, or specific, incoming arguments to the function as parameters.
            log_inputs: Whether to log all, or specific, incoming arguments to the function as inputs.
            log_output: Whether to log the result of the function as an output.
            tags: A list of tags to attach to the task span.
            **attributes: A dictionary of attributes to attach to the task span.

        Returns:
            A new Task object.
        """

        def make_task(
            func: t.Callable[P, t.Awaitable[R]] | t.Callable[P, R],
        ) -> Task[P, R]:
            # Unwrap decorator layers so we inspect the real function.
            unwrapped = inspect.unwrap(func)

            # Generator functions don't produce a single result to log/score.
            if inspect.isgeneratorfunction(unwrapped) or inspect.isasyncgenfunction(
                unwrapped,
            ):
                raise TypeError("@task cannot be applied to generators")

            func_name = getattr(
                unwrapped,
                "__qualname__",
                getattr(func, "__name__", safe_repr(func)),
            )

            _name = name or func_name
            _label = label or func_name

            # conform our label for sanity
            _label = re.sub(r"[\W_]+", "_", _label.lower())

            _attributes = attributes or {}
            _attributes["code.function"] = func_name
            # Best-effort source metadata; never fail task creation over it.
            with contextlib.suppress(Exception):
                _attributes["code.lineno"] = unwrapped.__code__.co_firstlineno
            with contextlib.suppress(Exception):
                _attributes.update(
                    get_filepath_attribute(
                        inspect.getsourcefile(unwrapped),  # type: ignore [arg-type]
                    ),
                )

            return Task(
                tracer=self._get_tracer(),
                name=_name,
                attributes=_attributes,
                func=t.cast("t.Callable[P, R]", func),
                # Plain callables are promoted to Scorer objects.
                scorers=[
                    scorer
                    if isinstance(scorer, Scorer)
                    else Scorer.from_callable(self._get_tracer(), scorer)
                    for scorer in scorers or []
                ],
                tags=list(tags or []),
                log_params=log_params,
                log_inputs=log_inputs,
                log_output=log_output,
                label=_label,
            )

        return make_task
524
+
525
+ def task_span(
526
+ self,
527
+ name: str,
528
+ *,
529
+ label: str | None = None,
530
+ params: AnyDict | None = None,
531
+ tags: t.Sequence[str] | None = None,
532
+ **attributes: t.Any,
533
+ ) -> TaskSpan[t.Any]:
534
+ """
535
+ Create a task span without an explicit associated function.
536
+
537
+ This is useful for creating tasks on the fly without having to
538
+ define a function.
539
+
540
+ Example:
541
+ ```
542
+ async with dreadnode.task_span("my_task") as task:
543
+ # do some work here
544
+ pass
545
+ ```
546
+ Args:
547
+ name: The name of the task.
548
+ label: The label of the task - useful for filtering in the UI.
549
+ params: A dictionary of parameters to attach to the task span.
550
+ tags: A list of tags to attach to the task span.
551
+ **attributes: A dictionary of attributes to attach to the task span.
552
+
553
+ Returns:
554
+ A TaskSpan object.
555
+ """
556
+ if (run := current_run_span.get()) is None:
557
+ raise RuntimeError("Task spans must be created within a run")
558
+
559
+ label = label or re.sub(r"[\W_]+", "_", name.lower())
560
+ return TaskSpan(
561
+ name=name,
562
+ label=label,
563
+ attributes=attributes,
564
+ params=params,
565
+ tags=tags,
566
+ run_id=run.run_id,
567
+ tracer=self._get_tracer(),
568
+ )
569
+
570
+ def scorer(
571
+ self,
572
+ *,
573
+ name: str | None = None,
574
+ tags: t.Sequence[str] | None = None,
575
+ **attributes: t.Any,
576
+ ) -> t.Callable[[ScorerCallable[T]], Scorer[T]]:
577
+ """
578
+ Make a scorer from a callable function.
579
+
580
+ This is useful when you want to change the name of the scorer
581
+ or add additional attributes to it.
582
+
583
+ Example:
584
+ ```
585
+ @dreadnode.scorer(name="my_scorer")
586
+ async def my_scorer(x: int) -> float:
587
+ return x * 2
588
+
589
+ @dreadnode.task(scorers=[my_scorer])
590
+ async def my_task(x: int) -> int:
591
+ return x * 2
592
+
593
+ await my_task(2)
594
+ ```
595
+
596
+ Args:
597
+ name: The name of the scorer.
598
+ tags: A list of tags to attach to the scorer.
599
+ **attributes: A dictionary of attributes to attach to the scorer.
600
+
601
+ Returns:
602
+ A new Scorer object.
603
+ """
604
+
605
+ def make_scorer(func: ScorerCallable[T]) -> Scorer[T]:
606
+ return Scorer.from_callable(
607
+ self._get_tracer(),
608
+ func,
609
+ name=name,
610
+ tags=tags,
611
+ attributes=attributes,
612
+ )
613
+
614
+ return make_scorer
615
+
616
+ def run(
617
+ self,
618
+ name: str | None = None,
619
+ *,
620
+ tags: t.Sequence[str] | None = None,
621
+ params: AnyDict | None = None,
622
+ project: str | None = None,
623
+ **attributes: t.Any,
624
+ ) -> RunSpan:
625
+ """
626
+ Create a new run.
627
+
628
+ Runs are the main way to track work in Dreadnode. They are
629
+ associated with a specific project and can have parameters,
630
+ inputs, and outputs logged to them.
631
+
632
+ You cannot create runs inside other runs.
633
+
634
+ Example:
635
+ ```
636
+ with dreadnode.run("my_run"):
637
+ # do some work here
638
+ pass
639
+ ```
640
+
641
+ Args:
642
+ name: The name of the run. If not provided, a random name will be generated.
643
+ tags: A list of tags to attach to the run.
644
+ params: A dictionary of parameters to attach to the run.
645
+ project: The project name to associate the run with. If not provided,
646
+ the project passed to `configure()` will be used, or the
647
+ run will be associated with a default project.
648
+ **attributes: Additional attributes to attach to the run span.
649
+ """
650
+ if not self._initialized:
651
+ self.initialize()
652
+
653
+ if name is None:
654
+ name = f"{coolname.generate_slug(2)}-{random.randint(100, 999)}" # noqa: S311
655
+
656
+ return RunSpan(
657
+ name=name,
658
+ project=project or self.project or "default",
659
+ attributes=attributes,
660
+ tracer=self._get_tracer(),
661
+ params=params,
662
+ tags=tags,
663
+ file_system=self._fs,
664
+ prefix_path=self._fs_prefix,
665
+ )
666
+
667
+ @handle_internal_errors()
668
+ def push_update(self) -> None:
669
+ """
670
+ Push any pending metric or parameter data to the server.
671
+
672
+ This is useful for ensuring that the UI is up to date with the
673
+ latest data. Otherwise, all data for the run will be pushed
674
+ automatically when the run is closed.
675
+
676
+ Example:
677
+ ```
678
+ with dreadnode.run("my_run"):
679
+ dreadnode.log_params(...)
680
+ dreadnode.log_metric(...)
681
+ dreadnode.push_update()
682
+ """
683
+ if (run := current_run_span.get()) is None:
684
+ raise RuntimeError("Run updates must be pushed within a run")
685
+
686
+ run.push_update()
687
+
688
+ @handle_internal_errors()
689
+ def log_param(
690
+ self,
691
+ key: str,
692
+ value: JsonValue,
693
+ *,
694
+ to: ToObject = "task-or-run",
695
+ ) -> None:
696
+ """
697
+ Log a single parameter to the current task or run.
698
+
699
+ Parameters are key-value pairs that are associated with the task or run
700
+ and can be used to track configuration values, hyperparameters, or other
701
+ metadata.
702
+
703
+ Example:
704
+ ```
705
+ with dreadnode.run("my_run") as run:
706
+ run.log_param("param_name", "param_value")
707
+ ```
708
+
709
+ Args:
710
+ key: The name of the parameter.
711
+ value: The value of the parameter.
712
+ to: The target object to log the parameter to. Can be "task-or-run" or "run".
713
+ Defaults to "task-or-run". If "task-or-run", the parameter will be logged
714
+ to the current task or run, whichever is the nearest ancestor.
715
+ """
716
+ self.log_params(to=to, **{key: value})
717
+
718
    @handle_internal_errors()
    def log_params(self, to: ToObject = "run", **params: JsonValue) -> None:
        """
        Log multiple parameters to the current task or run.

        Parameters are key-value pairs that are associated with the task or run
        and can be used to track configuration values, hyperparameters, or other
        metadata.

        Example:
            ```
            with dreadnode.run("my_run") as run:
                run.log_params(
                    param1="value1",
                    param2="value2"
                )
            ```

        Args:
            to: The target object to log the parameters to. Can be "task-or-run" or "run".
                Defaults to "run". If "task-or-run", the parameters will be logged
                to the current task or run, whichever is the nearest ancestor.
            **params: The parameters to log. Each parameter is a key-value pair.
        """
        task = current_task_span.get()
        run = current_run_span.get()

        # Nearest task wins for "task-or-run"; otherwise always the run.
        target = (task or run) if to == "task-or-run" else run
        if target is None:
            raise RuntimeError("log_params() must be called within a run")

        target.log_params(**params)
750
+
751
    @t.overload
    def log_metric(
        self,
        key: str,
        value: float | bool,
        *,
        step: int = 0,
        origin: t.Any | None = None,
        timestamp: datetime | None = None,
        mode: MetricMode = "direct",
        to: ToObject = "task-or-run",
    ) -> None:
        """
        Log a single metric to the current task or run.

        Metrics are some measurement or recorded value related to the task or run.
        They can be used to track performance, resource usage, or other quantitative data.

        Example:
            ```
            with dreadnode.run("my_run") as run:
                run.log_metric("metric_name", 42.0)
            ```

        Args:
            key: The name of the metric.
            value: The value of the metric.
            step: The step of the metric.
            origin: The origin of the metric - can be provided any object which was logged
                as an input or output anywhere in the run.
            timestamp: The timestamp of the metric - defaults to the current time.
            mode: The aggregation mode to use for the metric. Helpful when you want to let
                the library take care of translating your raw values into better representations.
                - direct: do not modify the value at all (default)
                - min: the lowest observed value reported for this metric
                - max: the highest observed value reported for this metric
                - avg: the average of all reported values for this metric
                - sum: the cumulative sum of all reported values for this metric
                - count: increment every time this metric is logged - disregard value
            to: The target object to log the metric to. Can be "task-or-run" or "run".
                Defaults to "task-or-run". If "task-or-run", the metric will be logged
                to the current task or run, whichever is the nearest ancestor.
        """

    @t.overload
    def log_metric(
        self,
        key: str,
        value: Metric,
        *,
        origin: t.Any | None = None,
        mode: MetricMode = "direct",
        to: ToObject = "task-or-run",
    ) -> None:
        """
        Log a single metric to the current task or run.

        Metrics are some measurement or recorded value related to the task or run.
        They can be used to track performance, resource usage, or other quantitative data.

        Example:
            ```
            with dreadnode.run("my_run") as run:
                run.log_metric("metric_name", 42.0)
            ```

        Args:
            key: The name of the metric.
            value: The metric object.
            origin: The origin of the metric - can be provided any object which was logged
                as an input or output anywhere in the run.
            mode: The aggregation mode to use for the metric. Helpful when you want to let
                the library take care of translating your raw values into better representations.
                - direct: do not modify the value at all (default)
                - min: always report the lowest observed value for this metric
                - max: always report the highest observed value for this metric
                - avg: report the average of all values for this metric
                - sum: report a rolling sum of all values for this metric
                - count: report the number of times this metric has been logged
            to: The target object to log the metric to. Can be "task-or-run" or "run".
                Defaults to "task-or-run". If "task-or-run", the metric will be logged
                to the current task or run, whichever is the nearest ancestor.
        """

    @handle_internal_errors()
    def log_metric(
        self,
        key: str,
        value: float | bool | Metric,
        *,
        step: int = 0,
        origin: t.Any | None = None,
        timestamp: datetime | None = None,
        mode: MetricMode = "direct",
        to: ToObject = "task-or-run",
    ) -> None:
        # Implementation for both overloads above - see their docstrings.
        task = current_task_span.get()
        run = current_run_span.get()

        # Nearest task wins for "task-or-run"; otherwise always the run.
        target = (task or run) if to == "task-or-run" else run
        if target is None:
            raise RuntimeError("log_metric() must be called within a run")

        # Wrap raw values in a Metric; when a Metric object is passed it
        # already carries its own step/timestamp, so those args are unused.
        metric = (
            value
            if isinstance(value, Metric)
            else Metric(float(value), step, timestamp or datetime.now(timezone.utc))
        )
        target.log_metric(key, metric, origin=origin, mode=mode)
860
+
861
    @handle_internal_errors()
    def log_artifact(
        self,
        local_uri: str | Path,
    ) -> None:
        """
        Log a file or directory artifact to the current run.

        This method uploads a local file or directory to the artifact storage associated with the run.

        Examples:
            Log a single file:
            ```
            with dreadnode.run("my_run") as run:
                # Save a file
                with open("results.json", "w") as f:
                    json.dump(results, f)

                # Log it as an artifact
                run.log_artifact("results.json")
            ```

            Log a directory:
            ```
            with dreadnode.run("my_run") as run:
                # Create a directory with model files
                os.makedirs("model_output", exist_ok=True)
                save_model("model_output/model.pkl")
                save_config("model_output/config.yaml")

                # Log the entire directory as an artifact
                run.log_artifact("model_output")
            ```

        Args:
            local_uri: The local path to the file or directory to upload.

        Raises:
            RuntimeError: If called outside of an active run.
        """
        if (run := current_run_span.get()) is None:
            raise RuntimeError("log_artifact() must be called within a run")

        run.log_artifact(local_uri=local_uri)
903
+
904
+ @handle_internal_errors()
905
+ def log_input(
906
+ self,
907
+ name: str,
908
+ value: JsonValue,
909
+ *,
910
+ label: str | None = None,
911
+ to: ToObject = "task-or-run",
912
+ **attributes: t.Any,
913
+ ) -> None:
914
+ """
915
+ Log a single input to the current task or run.
916
+
917
+ Inputs can be any runtime object, which are serialized, stored, and tracked
918
+ in the Dreadnode UI.
919
+
920
+ Example:
921
+ ```
922
+ @dreadnode.task
923
+ async def my_task(x: int) -> int:
924
+ dreadnode.log_input("input_name", x)
925
+ return x * 2
926
+
927
+ with dreadnode.run("my_run"):
928
+ dreadnode.log_input("input_name", some_dataframe)
929
+
930
+ await my_task(2)
931
+ ```
932
+ """
933
+ task = current_task_span.get()
934
+ run = current_run_span.get()
935
+
936
+ target = (task or run) if to == "task-or-run" else run
937
+ if target is None:
938
+ raise RuntimeError("log_inputs() must be called within a run")
939
+
940
+ target.log_input(name, value, label=label, **attributes)
941
+
942
+ @handle_internal_errors()
943
+ def log_inputs(
944
+ self,
945
+ to: ToObject = "task-or-run",
946
+ **inputs: JsonValue,
947
+ ) -> None:
948
+ """
949
+ Log multiple inputs to the current task or run.
950
+
951
+ See `log_input()` for more details.
952
+ """
953
+ for name, value in inputs.items():
954
+ self.log_input(name, value, to=to)
955
+
956
+ @handle_internal_errors()
957
+ def log_output(
958
+ self,
959
+ name: str,
960
+ value: t.Any,
961
+ *,
962
+ label: str | None = None,
963
+ to: ToObject = "task-or-run",
964
+ **attributes: JsonValue,
965
+ ) -> None:
966
+ """
967
+ Log a single output to the current task or run.
968
+
969
+ Outputs can be any runtime object, which are serialized, stored, and tracked
970
+ in the Dreadnode UI.
971
+
972
+ Example:
973
+ ```
974
+ @dreadnode.task
975
+ async def my_task(x: int) -> int:
976
+ result = x * 2
977
+ dreadnode.log_output("result", x * 2)
978
+ return result
979
+
980
+ with dreadnode.run("my_run"):
981
+ await my_task(2)
982
+
983
+ dreadnode.log_output("other", 123)
984
+ ```
985
+ """
986
+ task = current_task_span.get()
987
+ run = current_run_span.get()
988
+
989
+ target = (task or run) if to == "task-or-run" else run
990
+ if target is None:
991
+ raise RuntimeError(
992
+ "log_output() must be called within a run or a task",
993
+ )
994
+
995
+ target.log_output(name, value, label=label, **attributes)
996
+
997
+ @handle_internal_errors()
998
+ def log_outputs(
999
+ self,
1000
+ to: ToObject = "task-or-run",
1001
+ **outputs: JsonValue,
1002
+ ) -> None:
1003
+ """
1004
+ Log multiple outputs to the current task or run.
1005
+
1006
+ See `log_output()` for more details.
1007
+ """
1008
+ for name, value in outputs.items():
1009
+ self.log_output(name, value, to=to)
1010
+
1011
+ @handle_internal_errors()
1012
+ def link_objects(self, origin: t.Any, link: t.Any, **attributes: JsonValue) -> None:
1013
+ """
1014
+ Associate two runtime objects with each other.
1015
+
1016
+ This is useful for linking any two objects which are related to
1017
+ each other, such as a model and its training data, or an input
1018
+ prompt and the resulting output.
1019
+
1020
+ Example:
1021
+ ```
1022
+ with dreadnode.run("my_run") as run:
1023
+ model = SomeModel()
1024
+ data = SomeData()
1025
+
1026
+ run.link_objects(model, data)
1027
+ ```
1028
+
1029
+ Args:
1030
+ origin: The origin object to link from.
1031
+ link: The linked object to link to.
1032
+ **attributes: Additional attributes to attach to the link.
1033
+ """
1034
+ if (run := current_run_span.get()) is None:
1035
+ raise RuntimeError("link() must be called within a run")
1036
+
1037
+ origin_hash = run.log_object(origin)
1038
+ link_hash = run.log_object(link)
1039
+ run.link_objects(origin_hash, link_hash, **attributes)
1040
+
1041
+
1042
# Module-level singleton backing the top-level `dreadnode.*` API surface.
DEFAULT_INSTANCE = Dreadnode()