dreadnode 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,819 @@
1
+ import logging
2
+ import re
3
+ import types
4
+ import typing as t
5
+ from contextvars import ContextVar, Token
6
+ from copy import deepcopy
7
+ from datetime import datetime, timezone
8
+ from pathlib import Path
9
+
10
+ import typing_extensions as te
11
+ from fsspec import AbstractFileSystem # type: ignore [import-untyped]
12
+ from logfire._internal.json_encoder import logfire_json_dumps as json_dumps
13
+ from logfire._internal.json_schema import (
14
+ JsonSchemaProperties,
15
+ attributes_json_schema,
16
+ create_json_schema,
17
+ )
18
+ from logfire._internal.tracer import OPEN_SPANS
19
+ from logfire._internal.utils import uniquify_sequence
20
+ from opentelemetry import context as context_api
21
+ from opentelemetry import trace as trace_api
22
+ from opentelemetry.sdk.trace import ReadableSpan
23
+ from opentelemetry.trace import Tracer
24
+ from opentelemetry.util import types as otel_types
25
+ from ulid import ULID
26
+
27
+ from dreadnode.artifact.merger import ArtifactMerger
28
+ from dreadnode.artifact.storage import ArtifactStorage
29
+ from dreadnode.artifact.tree_builder import ArtifactTreeBuilder, DirectoryNode
30
+ from dreadnode.constants import MAX_INLINE_OBJECT_BYTES
31
+ from dreadnode.metric import Metric, MetricAggMode, MetricDict
32
+ from dreadnode.object import Object, ObjectRef, ObjectUri, ObjectVal
33
+ from dreadnode.serialization import Serialized, serialize
34
+ from dreadnode.types import UNSET, AnyDict, JsonDict, JsonValue, Unset
35
+ from dreadnode.version import VERSION
36
+
37
+ from .constants import (
38
+ EVENT_ATTRIBUTE_LINK_HASH,
39
+ EVENT_ATTRIBUTE_OBJECT_HASH,
40
+ EVENT_ATTRIBUTE_OBJECT_LABEL,
41
+ EVENT_ATTRIBUTE_ORIGIN_SPAN_ID,
42
+ EVENT_NAME_OBJECT,
43
+ EVENT_NAME_OBJECT_INPUT,
44
+ EVENT_NAME_OBJECT_LINK,
45
+ EVENT_NAME_OBJECT_METRIC,
46
+ EVENT_NAME_OBJECT_OUTPUT,
47
+ METRIC_ATTRIBUTE_SOURCE_HASH,
48
+ SPAN_ATTRIBUTE_ARTIFACTS,
49
+ SPAN_ATTRIBUTE_INPUTS,
50
+ SPAN_ATTRIBUTE_LABEL,
51
+ SPAN_ATTRIBUTE_LARGE_ATTRIBUTES,
52
+ SPAN_ATTRIBUTE_METRICS,
53
+ SPAN_ATTRIBUTE_OBJECT_SCHEMAS,
54
+ SPAN_ATTRIBUTE_OBJECTS,
55
+ SPAN_ATTRIBUTE_OUTPUTS,
56
+ SPAN_ATTRIBUTE_PARAMS,
57
+ SPAN_ATTRIBUTE_PARENT_TASK_ID,
58
+ SPAN_ATTRIBUTE_PROJECT,
59
+ SPAN_ATTRIBUTE_RUN_ID,
60
+ SPAN_ATTRIBUTE_SCHEMA,
61
+ SPAN_ATTRIBUTE_TAGS_,
62
+ SPAN_ATTRIBUTE_TYPE,
63
+ SPAN_ATTRIBUTE_VERSION,
64
+ SpanType,
65
+ )
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+ R = t.TypeVar("R")
70
+
71
+
72
+ current_task_span: ContextVar["TaskSpan[t.Any] | None"] = ContextVar(
73
+ "current_task_span",
74
+ default=None,
75
+ )
76
+ current_run_span: ContextVar["RunSpan | None"] = ContextVar(
77
+ "current_run_span",
78
+ default=None,
79
+ )
80
+
81
+
82
class Span(ReadableSpan):
    """Base wrapper around an OpenTelemetry span.

    Attributes set before the span starts are staged in ``_pre_attributes``
    and flushed to the tracer when the underlying OTel span is created on
    ``__enter__``; after that, attribute writes go through to the live span
    directly.
    """

    def __init__(
        self,
        name: str,
        attributes: AnyDict,
        tracer: Tracer,
        *,
        label: str | None = None,
        type: SpanType = "span",
        tags: t.Sequence[str] | None = None,
    ) -> None:
        """
        Args:
            name: Span name reported to the tracer.
            attributes: Initial attributes, merged over the built-in defaults.
            tracer: OTel tracer used to start the span.
            label: Optional short label stored as a span attribute.
            type: Dreadnode span-type discriminator attribute.
            tags: Optional tags, de-duplicated while preserving order.
        """
        self._label = label or ""
        self._span_name = name
        # Staged attributes, written to the OTel span once it starts.
        self._pre_attributes = {
            SPAN_ATTRIBUTE_VERSION: VERSION,
            SPAN_ATTRIBUTE_TYPE: type,
            SPAN_ATTRIBUTE_LABEL: self._label,
            SPAN_ATTRIBUTE_TAGS_: uniquify_sequence(tags or ()),
            **attributes,
        }
        self._tracer = tracer

        self._schema: JsonSchemaProperties = JsonSchemaProperties({})
        self._token: object | None = None  # trace sdk context
        self._span: trace_api.Span | None = None

    if not t.TYPE_CHECKING:
        # At runtime, delegate unknown attribute access to the underlying
        # OTel span. Hidden from type checkers so ReadableSpan's declared
        # interface remains authoritative.
        def __getattr__(self, name: str) -> t.Any:
            return getattr(self._span, name)

    def __enter__(self) -> te.Self:
        """Start the underlying OTel span (once) and attach it to the context."""
        if self._span is None:
            self._span = self._tracer.start_span(
                name=self._span_name,
                attributes=prepare_otlp_attributes(self._pre_attributes),
            )

            self._span.__enter__()

            OPEN_SPANS.add(self._span)  # type: ignore [arg-type]

        if self._token is None:
            # Make this span the current span in the OTel context; the token
            # is kept so __exit__ can restore the previous context.
            self._token = context_api.attach(trace_api.set_span_in_context(self._span))

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        """Detach from the context and end the span if it is recording."""
        if self._token is None or self._span is None:
            return

        context_api.detach(self._token)  # type: ignore [arg-type]
        self._token = None

        # NOTE(review): a non-recording span is never discarded from
        # OPEN_SPANS here — presumably OPEN_SPANS tolerates that (e.g. holds
        # weak references); confirm against logfire internals.
        if not self._span.is_recording():
            return

        self._span.set_attribute(
            SPAN_ATTRIBUTE_SCHEMA,
            attributes_json_schema(self._schema) if self._schema else r"{}",
        )
        self._span.__exit__(exc_type, exc_value, traceback)

        OPEN_SPANS.discard(self._span)  # type: ignore [arg-type]

    @property
    def span_id(self) -> str:
        """Hex-formatted id of the active span.

        Raises:
            ValueError: If the span has not been started.
        """
        if self._span is None:
            raise ValueError("Span is not active")
        return trace_api.format_span_id(self._span.get_span_context().span_id)

    @property
    def trace_id(self) -> str:
        """Hex-formatted trace id of the active span.

        Raises:
            ValueError: If the span has not been started.
        """
        if self._span is None:
            raise ValueError("Span is not active")
        return trace_api.format_trace_id(self._span.get_span_context().trace_id)

    @property
    def is_recording(self) -> bool:
        """Whether the underlying span exists and is recording."""
        if self._span is None:
            return False
        return self._span.is_recording()

    @property
    def tags(self) -> tuple[str, ...]:
        """Current tags stored on the span."""
        return tuple(self.get_attribute(SPAN_ATTRIBUTE_TAGS_, ()))

    @tags.setter
    def tags(self, new_tags: t.Sequence[str]) -> None:
        self.set_attribute(SPAN_ATTRIBUTE_TAGS_, uniquify_sequence(new_tags))

    def set_attribute(
        self,
        key: str,
        value: t.Any,
        *,
        schema: bool = True,
        raw: bool = False,
    ) -> None:
        """Set a span attribute.

        Args:
            key: Attribute name.
            value: Attribute value; serialized to an OTLP-compatible form
                unless ``raw`` is true.
            schema: Also record a JSON schema for the value (skipped when raw).
            raw: Store the value without OTLP serialization.
        """
        self._added_attributes = True
        if schema and not raw:
            self._schema[key] = create_json_schema(value, set())
        otel_value = value if raw else prepare_otlp_attribute(value)
        # Keep the staged copy in sync so get_attributes() agrees whether or
        # not the span has started. (Previously this was assigned twice —
        # once before and once after the write-through — with the same value.)
        self._pre_attributes[key] = otel_value
        if self._span is not None:
            self._span.set_attribute(key, otel_value)

    def set_attributes(self, attributes: AnyDict) -> None:
        """Set multiple attributes via :meth:`set_attribute`."""
        for key, value in attributes.items():
            self.set_attribute(key, value)

    def get_attributes(self) -> AnyDict:
        """Return the live span's attributes, or the staged ones pre-start."""
        if self._span is not None:
            return getattr(self._span, "attributes", {})
        return self._pre_attributes

    def get_attribute(self, key: str, default: t.Any) -> t.Any:
        """Look up a single attribute, falling back to ``default``."""
        return self.get_attributes().get(key, default)

    def log_event(
        self,
        name: str,
        attributes: AnyDict | None = None,
    ) -> None:
        """Add an event to the active span; a no-op before the span starts."""
        if self._span is not None:
            self._span.add_event(
                name,
                attributes=prepare_otlp_attributes(attributes or {}),
            )
216
+
217
+
218
class RunUpdateSpan(Span):
    """Short-lived span that carries an incremental update for an existing run.

    Emitted by :meth:`RunSpan.push_update` so that changed metrics, params,
    or inputs can be reported mid-run without waiting for the run span to end.
    """

    def __init__(
        self,
        run_id: str,
        tracer: Tracer,
        project: str,
        *,
        metrics: MetricDict | None = None,
        params: JsonDict | None = None,
        inputs: JsonDict | None = None,
    ) -> None:
        """
        Args:
            run_id: Id of the run being updated.
            tracer: OTel tracer used to emit the update span.
            project: Project the run belongs to.
            metrics: Changed metrics, if any.
            params: Changed params, if any.
            inputs: Changed inputs, if any.
        """
        attributes: AnyDict = {
            SPAN_ATTRIBUTE_RUN_ID: run_id,
            SPAN_ATTRIBUTE_PROJECT: project,
        }

        # Only attach the sections that actually carry data.
        for attribute_key, payload in (
            (SPAN_ATTRIBUTE_METRICS, metrics),
            (SPAN_ATTRIBUTE_PARAMS, params),
            (SPAN_ATTRIBUTE_INPUTS, inputs),
        ):
            if payload:
                attributes[attribute_key] = payload

        super().__init__(f"run.{run_id}.update", attributes, tracer, type="run_update")
242
+
243
+
244
class RunSpan(Span):
    """Top-level span for a dreadnode run.

    Owns run-scoped state: params, metrics, logged objects (deduplicated by
    content hash) and their schemas, input/output references, and artifacts
    offloaded to the configured file system. Only one run may be active per
    context at a time (enforced in ``__enter__``).
    """

    def __init__(
        self,
        name: str,
        project: str,
        attributes: AnyDict,
        tracer: Tracer,
        file_system: AbstractFileSystem,
        prefix_path: str,
        params: AnyDict | None = None,
        metrics: MetricDict | None = None,
        run_id: str | None = None,
        tags: t.Sequence[str] | None = None,
    ) -> None:
        """
        Args:
            name: Span name.
            project: Project this run belongs to.
            attributes: Extra attributes merged over the run defaults.
            tracer: OTel tracer used to start the span.
            file_system: fsspec file system for offloaded objects/artifacts.
            prefix_path: Base path within the file system for stored objects.
            params: Initial run params.
            metrics: Initial run metrics.
            run_id: Explicit run id; a new ULID is generated when omitted.
            tags: Optional tags for the span.
        """
        self._params = params or {}
        self._metrics = metrics or {}
        # Objects are deduplicated by their content hash.
        self._objects: dict[str, Object] = {}
        self._object_schemas: dict[str, JsonDict] = {}
        self._inputs: list[ObjectRef] = []
        self._outputs: list[ObjectRef] = []
        self._artifact_storage = ArtifactStorage(file_system=file_system)
        self._artifacts: list[DirectoryNode] = []
        self._artifact_merger = ArtifactMerger()
        self._artifact_tree_builder = ArtifactTreeBuilder(
            storage=self._artifact_storage,
            prefix_path=prefix_path,
        )
        self.project = project

        # Snapshots used by push_update() to detect changes since last push.
        self._last_pushed_params = deepcopy(self._params)
        self._last_pushed_metrics = deepcopy(self._metrics)

        self._context_token: Token[RunSpan | None] | None = None  # contextvars context
        self._file_system = file_system
        self._prefix_path = prefix_path

        attributes = {
            SPAN_ATTRIBUTE_RUN_ID: str(run_id or ULID()),
            SPAN_ATTRIBUTE_PROJECT: project,
            SPAN_ATTRIBUTE_PARAMS: self._params,
            SPAN_ATTRIBUTE_METRICS: self._metrics,
            **attributes,
        }
        super().__init__(name, attributes, tracer, type="run", tags=tags)

    def __enter__(self) -> te.Self:
        """Activate the run span.

        Raises:
            RuntimeError: If another run is already active in this context.
        """
        if current_run_span.get() is not None:
            raise RuntimeError("You cannot start a run span within another run")

        self._context_token = current_run_span.set(self)
        return super().__enter__()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        """Flush accumulated run state onto the span, then close it."""
        self.set_attribute(SPAN_ATTRIBUTE_PARAMS, self._params)
        self.set_attribute(SPAN_ATTRIBUTE_INPUTS, self._inputs, schema=False)
        self.set_attribute(SPAN_ATTRIBUTE_METRICS, self._metrics, schema=False)
        self.set_attribute(SPAN_ATTRIBUTE_OBJECTS, self._objects, schema=False)
        self.set_attribute(
            SPAN_ATTRIBUTE_OBJECT_SCHEMAS,
            self._object_schemas,
            schema=False,
        )
        self.set_attribute(SPAN_ATTRIBUTE_ARTIFACTS, self._artifacts, schema=False)

        # Mark our objects attribute as large so it's stored separately
        self.set_attribute(
            SPAN_ATTRIBUTE_LARGE_ATTRIBUTES,
            [SPAN_ATTRIBUTE_OBJECTS, SPAN_ATTRIBUTE_OBJECT_SCHEMAS],
            raw=True,
        )

        super().__exit__(exc_type, exc_value, traceback)
        if self._context_token is not None:
            current_run_span.reset(self._context_token)

    def push_update(self) -> None:
        """Emit a RunUpdateSpan carrying any params/metrics changed since the
        last push; a no-op when nothing changed or the span isn't started."""
        if self._span is None:
            return

        metrics: MetricDict | None = None
        if self._last_pushed_metrics != self._metrics:
            metrics = self._metrics
            self._last_pushed_metrics = deepcopy(self._metrics)

        params: JsonDict | None = None
        if self._last_pushed_params != self._params:
            params = self._params
            self._last_pushed_params = deepcopy(self._params)

        if metrics is None and params is None:
            return

        # The update span exists only to transport the changed attributes.
        with RunUpdateSpan(
            run_id=self.run_id,
            project=self.project,
            tracer=self._tracer,
            params=params,
            metrics=metrics,
        ):
            pass

    @property
    def run_id(self) -> str:
        """The run's id attribute (empty string if unset)."""
        return str(self.get_attribute(SPAN_ATTRIBUTE_RUN_ID, ""))

    def log_object(
        self,
        value: t.Any,
        *,
        label: str | None = None,
        event_name: str = EVENT_NAME_OBJECT,
        **attributes: JsonValue,
    ) -> str:
        """Serialize and store *value* (deduplicated by content hash), record
        its schema, and emit a span event referencing it.

        Args:
            value: Arbitrary value to serialize and store.
            label: Optional label attached to the event.
            event_name: Event name used for the span event.
            **attributes: Extra event attributes.

        Returns:
            The stored object's content hash.
        """
        serialized = serialize(value)
        data_hash = serialized.data_hash
        schema_hash = serialized.schema_hash

        # Store object if we haven't already
        if data_hash not in self._objects:
            self._objects[data_hash] = self._create_object(serialized)

        object_ = self._objects[data_hash]

        # Store schema if new
        if schema_hash not in self._object_schemas:
            self._object_schemas[schema_hash] = serialized.schema

        # Build event attributes
        event_attributes = {
            **attributes,
            EVENT_ATTRIBUTE_OBJECT_HASH: object_.hash,
            EVENT_ATTRIBUTE_ORIGIN_SPAN_ID: trace_api.format_span_id(
                trace_api.get_current_span().get_span_context().span_id,
            ),
        }
        if label is not None:
            event_attributes[EVENT_ATTRIBUTE_OBJECT_LABEL] = label

        self.log_event(name=event_name, attributes=event_attributes)
        return object_.hash

    def _store_file_by_hash(self, data: bytes, full_path: str) -> str:
        """
        Writes data to the given full_path in the object store if it doesn't already exist.

        Args:
            data: Content to write.
            full_path: The path in the object store (e.g., S3 key or local path).

        Returns:
            The unstrip_protocol version of the full path (for object store URI).
        """
        if not self._file_system.exists(full_path):
            logger.debug("Storing new object at: %s", full_path)
            with self._file_system.open(full_path, "wb") as f:
                f.write(data)

        return str(self._file_system.unstrip_protocol(full_path))

    def _create_object(self, serialized: Serialized) -> Object:
        """Create an ObjectVal or ObjectUri depending on size."""
        data = serialized.data
        data_bytes = serialized.data_bytes
        data_len = serialized.data_len
        data_hash = serialized.data_hash
        schema_hash = serialized.schema_hash

        # Small (or unserializable-to-bytes) values are kept inline.
        if data is None or data_bytes is None or data_len <= MAX_INLINE_OBJECT_BYTES:
            return ObjectVal(
                hash=data_hash,
                value=data,
                schema_hash=schema_hash,
            )

        # Offload to file system (e.g., S3)
        full_path = f"{self._prefix_path.rstrip('/')}/{data_hash}"
        object_uri = self._store_file_by_hash(data_bytes, full_path)

        return ObjectUri(
            hash=data_hash,
            uri=object_uri,
            schema_hash=schema_hash,
            size=data_len,
        )

    def get_object(self, hash_: str) -> t.Any:
        """Return the stored Object for a content hash.

        Raises:
            KeyError: If no object with this hash was logged.
        """
        return self._objects[hash_]

    def link_objects(
        self,
        object_hash: str,
        link_hash: str,
        **attributes: JsonValue,
    ) -> None:
        """Emit an event linking two previously logged objects by hash."""
        self.log_event(
            name=EVENT_NAME_OBJECT_LINK,
            attributes={
                **attributes,
                EVENT_ATTRIBUTE_OBJECT_HASH: object_hash,
                EVENT_ATTRIBUTE_LINK_HASH: link_hash,
                EVENT_ATTRIBUTE_ORIGIN_SPAN_ID: (
                    trace_api.format_span_id(
                        trace_api.get_current_span().get_span_context().span_id,
                    )
                ),
            },
        )

    @property
    def params(self) -> AnyDict:
        """Mutable mapping of run params."""
        return self._params

    def log_param(self, key: str, value: t.Any) -> None:
        """Log a single run param (delegates to :meth:`log_params`)."""
        self.log_params(**{key: value})

    def log_params(self, **params: t.Any) -> None:
        """Log run params and immediately push an update span."""
        for key, value in params.items():
            self._params[key] = value

        # Always push updates for run params
        self.push_update()

    @property
    def inputs(self) -> AnyDict:
        """Logged inputs resolved back to their stored objects, keyed by name."""
        return {ref.name: self.get_object(ref.hash) for ref in self._inputs}

    def log_input(
        self,
        name: str,
        value: t.Any,
        *,
        label: str | None = None,
        **attributes: JsonValue,
    ) -> None:
        """Log *value* as a run input.

        Args:
            name: Input name.
            value: Input value to store.
            label: Optional label; defaults to a sanitized lowercase name.
            **attributes: Extra event attributes.
        """
        label = label or re.sub(r"\W+", "_", name.lower())
        hash_ = self.log_object(
            value,
            label=label,
            event_name=EVENT_NAME_OBJECT_INPUT,
            **attributes,
        )
        self._inputs.append(ObjectRef(name, label=label, hash=hash_))

    def log_artifact(
        self,
        local_uri: str | Path,
    ) -> None:
        """
        Logs a local file or directory as an artifact to the object store.
        Preserves directory structure and uses content hashing for deduplication.

        Args:
            local_uri: Path to the local file or directory

        Returns:
            None. The merged artifact trees are stored on ``self._artifacts``.

        Raises:
            FileNotFoundError: If the path doesn't exist
        """

        artifact_tree = self._artifact_tree_builder.process_artifact(local_uri)

        self._artifact_merger.add_tree(artifact_tree)

        self._artifacts = self._artifact_merger.get_merged_trees()

    @property
    def metrics(self) -> MetricDict:
        """Mutable mapping of metric name to recorded Metric list."""
        return self._metrics

    @t.overload
    def log_metric(
        self,
        key: str,
        value: float | bool,
        *,
        step: int = 0,
        origin: t.Any | None = None,
        timestamp: datetime | None = None,
        mode: MetricAggMode | None = None,
        attributes: JsonDict | None = None,
    ) -> None: ...

    @t.overload
    def log_metric(
        self,
        key: str,
        value: Metric,
        *,
        origin: t.Any | None = None,
        mode: MetricAggMode | None = None,
    ) -> None: ...

    def log_metric(
        self,
        key: str,
        value: float | bool | Metric,
        *,
        step: int = 0,
        origin: t.Any | None = None,
        timestamp: datetime | None = None,
        mode: MetricAggMode | None = None,
        attributes: JsonDict | None = None,
    ) -> None:
        """Record a metric value for *key*.

        Args:
            key: Metric name.
            value: Raw value (coerced to float) or a pre-built Metric.
            step: Step index (ignored when value is a Metric).
            origin: Optional object logged alongside the metric; its hash
                is stored in the metric's attributes.
            timestamp: Timestamp; defaults to now (UTC).
            mode: Optional aggregation mode applied against prior metrics.
            attributes: Extra metric attributes.
        """
        metric = (
            value
            if isinstance(value, Metric)
            else Metric(
                float(value), step, timestamp or datetime.now(timezone.utc), attributes or {}
            )
        )

        if origin is not None:
            origin_hash = self.log_object(
                origin,
                label=key,
                event_name=EVENT_NAME_OBJECT_METRIC,
            )
            metric.attributes[METRIC_ATTRIBUTE_SOURCE_HASH] = origin_hash

        metrics = self._metrics.setdefault(key, [])
        if mode is not None:
            metric = metric.apply_mode(mode, metrics)
        metrics.append(metric)

    @property
    def outputs(self) -> AnyDict:
        """Logged outputs resolved back to their stored objects, keyed by name."""
        return {ref.name: self.get_object(ref.hash) for ref in self._outputs}

    def log_output(
        self,
        name: str,
        value: t.Any,
        *,
        label: str | None = None,
        **attributes: JsonValue,
    ) -> None:
        """Log *value* as a run output.

        Args:
            name: Output name.
            value: Output value to store.
            label: Optional label; defaults to a sanitized lowercase name.
            **attributes: Extra event attributes.
        """
        label = label or re.sub(r"\W+", "_", name.lower())
        hash_ = self.log_object(
            value,
            label=label,
            event_name=EVENT_NAME_OBJECT_OUTPUT,
            **attributes,
        )
        self._outputs.append(ObjectRef(name, label=label, hash=hash_))
595
+
596
+
597
class TaskSpan(Span, t.Generic[R]):
    """Span representing a single task executed inside an active run.

    Tracks task-local params, metrics, and input/output references. Object
    payloads themselves are stored on the enclosing :class:`RunSpan`; this
    span only keeps :class:`ObjectRef` entries pointing at them. ``R`` is
    the type of the task's Python return value (see :attr:`output`).
    """

    def __init__(
        self,
        name: str,
        attributes: AnyDict,
        run_id: str,
        tracer: Tracer,
        *,
        label: str | None = None,
        params: AnyDict | None = None,
        metrics: MetricDict | None = None,
        tags: t.Sequence[str] | None = None,
    ) -> None:
        """
        Args:
            name: Span name.
            attributes: Extra attributes merged over the task defaults.
            run_id: Id of the run this task belongs to.
            tracer: OTel tracer used to start the span.
            label: Optional label; also used to prefix run-level metrics.
            params: Initial task params.
            metrics: Initial task metrics.
            tags: Optional tags for the span.
        """
        self._params = params or {}
        self._metrics = metrics or {}
        self._inputs: list[ObjectRef] = []
        self._outputs: list[ObjectRef] = []

        self._output: R | Unset = UNSET  # For the python output

        self._context_token: Token[TaskSpan[t.Any] | None] | None = None  # contextvars context

        attributes = {
            SPAN_ATTRIBUTE_RUN_ID: str(run_id),
            SPAN_ATTRIBUTE_PARAMS: self._params,
            SPAN_ATTRIBUTE_INPUTS: self._inputs,
            SPAN_ATTRIBUTE_METRICS: self._metrics,
            SPAN_ATTRIBUTE_OUTPUTS: self._outputs,
            **attributes,
        }
        super().__init__(name, attributes, tracer, type="task", label=label, tags=tags)

    def __enter__(self) -> te.Self:
        """Activate the task span.

        Records the parent task (if any) and binds to the current run.

        Raises:
            RuntimeError: If no run is active in this context.
        """
        self._parent_task = current_task_span.get()
        if self._parent_task is not None:
            self.set_attribute(SPAN_ATTRIBUTE_PARENT_TASK_ID, self._parent_task.span_id)

        self._run = current_run_span.get()
        if self._run is None:
            raise RuntimeError("You cannot start a task span without a run")

        self._context_token = current_task_span.set(self)
        return super().__enter__()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        """Flush accumulated task state onto the span, then close it."""
        self.set_attribute(SPAN_ATTRIBUTE_PARAMS, self._params)
        self.set_attribute(SPAN_ATTRIBUTE_INPUTS, self._inputs, schema=False)
        self.set_attribute(SPAN_ATTRIBUTE_METRICS, self._metrics, schema=False)
        self.set_attribute(SPAN_ATTRIBUTE_OUTPUTS, self._outputs, schema=False)
        super().__exit__(exc_type, exc_value, traceback)
        if self._context_token is not None:
            current_task_span.reset(self._context_token)

    @property
    def run_id(self) -> str:
        """Id of the run this task belongs to (empty string if unset)."""
        return str(self.get_attribute(SPAN_ATTRIBUTE_RUN_ID, ""))

    @property
    def parent_task_id(self) -> str:
        """Span id of the parent task (empty string if top-level)."""
        return str(self.get_attribute(SPAN_ATTRIBUTE_PARENT_TASK_ID, ""))

    @property
    def run(self) -> RunSpan:
        """The enclosing run span.

        Raises:
            ValueError: If the task was never entered within an active run.
        """
        if self._run is None:
            raise ValueError("Task span is not in an active run")
        return self._run

    @property
    def outputs(self) -> AnyDict:
        """Logged outputs resolved from the run's object store, keyed by name."""
        return {ref.name: self.run.get_object(ref.hash) for ref in self._outputs}

    @property
    def output(self) -> R:
        """The task's Python return value.

        Raises:
            TypeError: If the output has not been set.
        """
        if isinstance(self._output, Unset):
            raise TypeError("Task output is not set")
        return self._output

    @output.setter
    def output(self, value: R) -> None:
        self._output = value

    def log_output(
        self,
        name: str,
        value: t.Any,
        *,
        label: str | None = None,
        **attributes: JsonValue,
    ) -> str:
        """Log *value* as a task output (stored on the run).

        Args:
            name: Output name.
            value: Output value to store.
            label: Optional label; defaults to a sanitized lowercase name.
            **attributes: Extra event attributes.

        Returns:
            The stored object's content hash.
        """
        label = label or re.sub(r"\W+", "_", name.lower())
        hash_ = self.run.log_object(
            value,
            label=label,
            event_name=EVENT_NAME_OBJECT_OUTPUT,
            **attributes,
        )
        self._outputs.append(ObjectRef(name, label=label, hash=hash_))
        return hash_

    @property
    def params(self) -> AnyDict:
        """Mutable mapping of task params."""
        return self._params

    def log_param(self, key: str, value: t.Any) -> None:
        """Log a single task param (delegates to :meth:`log_params`)."""
        self.log_params(**{key: value})

    def log_params(self, **params: t.Any) -> None:
        """Log task params (no run update is pushed, unlike RunSpan)."""
        self._params.update(params)

    @property
    def inputs(self) -> AnyDict:
        """Logged inputs resolved from the run's object store, keyed by name."""
        return {ref.name: self.run.get_object(ref.hash) for ref in self._inputs}

    def log_input(
        self,
        name: str,
        value: t.Any,
        *,
        label: str | None = None,
        **attributes: JsonValue,
    ) -> str:
        """Log *value* as a task input (stored on the run).

        Args:
            name: Input name.
            value: Input value to store.
            label: Optional label; defaults to a sanitized lowercase name.
            **attributes: Extra event attributes.

        Returns:
            The stored object's content hash.
        """
        label = label or re.sub(r"\W+", "_", name.lower())
        hash_ = self.run.log_object(
            value,
            label=label,
            event_name=EVENT_NAME_OBJECT_INPUT,
            **attributes,
        )
        self._inputs.append(ObjectRef(name, label=label, hash=hash_))
        return hash_

    @property
    def metrics(self) -> dict[str, list[Metric]]:
        """Mutable mapping of metric name to recorded Metric list."""
        return self._metrics

    @t.overload
    def log_metric(
        self,
        key: str,
        value: float | bool,
        *,
        step: int = 0,
        origin: t.Any | None = None,
        timestamp: datetime | None = None,
        mode: MetricAggMode | None = None,
        attributes: JsonDict | None = None,
    ) -> None: ...

    @t.overload
    def log_metric(
        self,
        key: str,
        value: Metric,
        *,
        origin: t.Any | None = None,
        mode: MetricAggMode | None = None,
    ) -> None: ...

    def log_metric(
        self,
        key: str,
        value: float | bool | Metric,
        *,
        step: int = 0,
        origin: t.Any | None = None,
        timestamp: datetime | None = None,
        mode: MetricAggMode | None = None,
        attributes: JsonDict | None = None,
    ) -> None:
        """Record a metric for *key*, mirroring it onto the current run.

        Args:
            key: Metric name.
            value: Raw value (coerced to float) or a pre-built Metric.
            step: Step index (ignored when value is a Metric).
            origin: Optional object logged alongside the metric; its hash
                is stored in the metric's attributes.
            timestamp: Timestamp; defaults to now (UTC).
            mode: Optional aggregation mode applied against prior metrics.
            attributes: Extra metric attributes.
        """
        metric = (
            value
            if isinstance(value, Metric)
            else Metric(
                float(value), step, timestamp or datetime.now(timezone.utc), attributes or {}
            )
        )

        if origin is not None:
            origin_hash = self.run.log_object(
                origin,
                label=key,
                event_name=EVENT_NAME_OBJECT_METRIC,
            )
            metric.attributes[METRIC_ATTRIBUTE_SOURCE_HASH] = origin_hash

        metrics = self._metrics.setdefault(key, [])
        if mode is not None:
            metric = metric.apply_mode(mode, metrics)
        metrics.append(metric)

        # For every metric we log, also log it to the run
        # with our `label` as a prefix.
        #
        # Don't include `source` and `mode` as we handled it here.
        if (run := current_run_span.get()) is not None:
            run.log_metric(f"{self._label}.{key}", metric)

    def get_average_metric_value(self, key: str | None = None) -> float:
        """Average recorded metric value.

        Args:
            key: A specific metric name, or None to average across all
                recorded metrics.

        Returns:
            The arithmetic mean of the matching metric values, or 0.0 when
            there are none (previously this raised ZeroDivisionError).
        """
        metrics = (
            self._metrics.get(key, [])
            if key is not None
            else [m for ms in self._metrics.values() for m in ms]
        )
        if not metrics:
            return 0.0
        return sum(metric.value for metric in metrics) / len(metrics)
808
+
809
+
810
def prepare_otlp_attributes(
    attributes: AnyDict,
) -> dict[str, otel_types.AttributeValue]:
    """Convert every value in *attributes* to an OTLP-compatible form."""
    converted: dict[str, otel_types.AttributeValue] = {}
    for attr_key, attr_value in attributes.items():
        converted[attr_key] = prepare_otlp_attribute(attr_value)
    return converted
814
+
815
+
816
def prepare_otlp_attribute(value: t.Any) -> otel_types.AttributeValue:
    """Pass OTLP-native primitives through; JSON-encode everything else."""
    if isinstance(value, (str, int, bool, float)):
        return value
    return json_dumps(value)
+ return json_dumps(value)