wandb 0.19.10__py3-none-musllinux_1_2_aarch64.whl → 0.19.11__py3-none-musllinux_1_2_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +3 -3
  3. wandb/_pydantic/__init__.py +2 -3
  4. wandb/_pydantic/base.py +11 -31
  5. wandb/_pydantic/utils.py +8 -1
  6. wandb/_pydantic/v1_compat.py +3 -3
  7. wandb/apis/public/api.py +590 -22
  8. wandb/apis/public/artifacts.py +13 -5
  9. wandb/apis/public/automations.py +1 -1
  10. wandb/apis/public/integrations.py +22 -10
  11. wandb/apis/public/registries/__init__.py +0 -0
  12. wandb/apis/public/registries/_freezable_list.py +179 -0
  13. wandb/apis/public/{registries.py → registries/registries_search.py} +22 -129
  14. wandb/apis/public/registries/registry.py +357 -0
  15. wandb/apis/public/registries/utils.py +140 -0
  16. wandb/apis/public/runs.py +58 -56
  17. wandb/automations/__init__.py +16 -24
  18. wandb/automations/_filters/expressions.py +12 -10
  19. wandb/automations/_filters/operators.py +10 -19
  20. wandb/automations/_filters/run_metrics.py +231 -82
  21. wandb/automations/_generated/__init__.py +27 -34
  22. wandb/automations/_generated/create_automation.py +17 -0
  23. wandb/automations/_generated/delete_automation.py +17 -0
  24. wandb/automations/_generated/fragments.py +40 -25
  25. wandb/automations/_generated/{get_triggers.py → get_automations.py} +5 -5
  26. wandb/automations/_generated/{get_triggers_by_entity.py → get_automations_by_entity.py} +7 -5
  27. wandb/automations/_generated/operations.py +35 -98
  28. wandb/automations/_generated/update_automation.py +17 -0
  29. wandb/automations/_utils.py +178 -64
  30. wandb/automations/_validators.py +94 -2
  31. wandb/automations/actions.py +113 -98
  32. wandb/automations/automations.py +47 -69
  33. wandb/automations/events.py +139 -87
  34. wandb/automations/integrations.py +23 -4
  35. wandb/automations/scopes.py +22 -20
  36. wandb/bin/gpu_stats +0 -0
  37. wandb/bin/wandb-core +0 -0
  38. wandb/env.py +11 -0
  39. wandb/old/settings.py +4 -1
  40. wandb/proto/v3/wandb_internal_pb2.py +240 -236
  41. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  42. wandb/proto/v4/wandb_internal_pb2.py +236 -236
  43. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  44. wandb/proto/v5/wandb_internal_pb2.py +236 -236
  45. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  46. wandb/proto/v6/wandb_internal_pb2.py +236 -236
  47. wandb/proto/v6/wandb_telemetry_pb2.py +10 -10
  48. wandb/sdk/artifacts/_generated/__init__.py +42 -1
  49. wandb/sdk/artifacts/_generated/add_aliases.py +21 -0
  50. wandb/sdk/artifacts/_generated/delete_aliases.py +21 -0
  51. wandb/sdk/artifacts/_generated/fetch_linked_artifacts.py +67 -0
  52. wandb/sdk/artifacts/_generated/fragments.py +35 -0
  53. wandb/sdk/artifacts/_generated/input_types.py +12 -0
  54. wandb/sdk/artifacts/_generated/operations.py +101 -0
  55. wandb/sdk/artifacts/_generated/update_artifact.py +26 -0
  56. wandb/sdk/artifacts/_graphql_fragments.py +1 -0
  57. wandb/sdk/artifacts/_validators.py +120 -1
  58. wandb/sdk/artifacts/artifact.py +380 -203
  59. wandb/sdk/artifacts/artifact_file_cache.py +4 -6
  60. wandb/sdk/artifacts/artifact_manifest_entry.py +11 -2
  61. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +182 -1
  62. wandb/sdk/artifacts/storage_policy.py +3 -0
  63. wandb/sdk/data_types/video.py +46 -32
  64. wandb/sdk/interface/interface.py +2 -3
  65. wandb/sdk/internal/internal_api.py +21 -31
  66. wandb/sdk/internal/sender.py +5 -2
  67. wandb/sdk/launch/sweeps/utils.py +8 -0
  68. wandb/sdk/projects/_generated/__init__.py +47 -0
  69. wandb/sdk/projects/_generated/delete_project.py +22 -0
  70. wandb/sdk/projects/_generated/enums.py +4 -0
  71. wandb/sdk/projects/_generated/fetch_registry.py +22 -0
  72. wandb/sdk/projects/_generated/fragments.py +41 -0
  73. wandb/sdk/projects/_generated/input_types.py +13 -0
  74. wandb/sdk/projects/_generated/operations.py +88 -0
  75. wandb/sdk/projects/_generated/rename_project.py +27 -0
  76. wandb/sdk/projects/_generated/upsert_registry_project.py +27 -0
  77. wandb/sdk/service/service.py +9 -1
  78. wandb/sdk/wandb_init.py +32 -5
  79. wandb/sdk/wandb_run.py +37 -9
  80. wandb/sdk/wandb_settings.py +6 -7
  81. wandb/sdk/wandb_setup.py +12 -0
  82. wandb/util.py +7 -3
  83. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/METADATA +1 -1
  84. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/RECORD +87 -70
  85. wandb/automations/_generated/create_filter_trigger.py +0 -21
  86. wandb/automations/_generated/delete_trigger.py +0 -19
  87. wandb/automations/_generated/update_filter_trigger.py +0 -21
  88. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/WHEEL +0 -0
  89. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/entry_points.txt +0 -0
  90. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/licenses/LICENSE +0 -0
@@ -1,12 +1,13 @@
1
+ import wandb
1
2
  from wandb._pydantic import IS_PYDANTIC_V2
2
3
 
3
- from . import _filters as filters
4
- from . import actions, automations, events, scopes
5
- from .actions import ActionType, DoNothing, DoNotification, DoWebhook
6
- from .automations import Automation, NewAutomation, PreparedAutomation
4
+ from .actions import ActionType, DoNothing, SendNotification, SendWebhook
5
+ from .automations import Automation, NewAutomation
7
6
  from .events import (
8
7
  ArtifactEvent,
9
8
  EventType,
9
+ MetricChangeFilter,
10
+ MetricThresholdFilter,
10
11
  OnAddArtifactAlias,
11
12
  OnCreateArtifact,
12
13
  OnLinkArtifact,
@@ -26,33 +27,23 @@ if not IS_PYDANTIC_V2:
26
27
  # - Drop support for Pydantic v1
27
28
  # - Are able to implement (limited) Pydantic v1 support
28
29
  raise ImportError(
29
- "The W&B Automations API is not supported in Pydantic v1 at this time. "
30
- "If at all possible, we currently recommend upgrading to Pydantic v2 to use this feature.",
30
+ "The W&B Automations API requires Pydantic v2. "
31
+ "We recommend upgrading `pydantic` to use this feature."
31
32
  )
32
33
 
33
34
  else:
34
35
  # If Pydantic v2 is available, we can use the full Automations API
35
36
  # but communicate to users that the API is still experimental and
36
37
  # may change rapidly.
37
- import warnings
38
-
39
- warnings.warn(
40
- "The W&B Automations API is currently experimental. Although we'll communicate "
41
- "breaking changes in release notes and attempt to minimize them in general, "
42
- "please know that such changes may occur between release versions without notice. "
43
- "We strongly recommend pinning your `wandb` version when using the Automations API "
44
- "to avoid unexpected breakages.",
45
- FutureWarning,
46
- stacklevel=1,
38
+ wandb.termwarn(
39
+ "The W&B Automations API is experimental and the implementation is subject to change."
40
+ "Review the release notes before upgrading. We recommend pinning your "
41
+ f"package version to `{wandb.__package__}=={wandb.__version__}` to reduce the risk of disruption.",
42
+ repeat=False,
47
43
  )
48
44
  # ----------------------------------------------------------------------------
49
45
 
50
46
  __all__ = [
51
- "filters",
52
- "scopes",
53
- "events",
54
- "actions",
55
- "automations",
56
47
  # Scopes
57
48
  "ScopeType",
58
49
  "ArtifactCollectionScope",
@@ -65,15 +56,16 @@ __all__ = [
65
56
  "OnRunMetric",
66
57
  "ArtifactEvent",
67
58
  "RunEvent",
59
+ "MetricThresholdFilter",
60
+ "MetricChangeFilter",
68
61
  # Actions
69
62
  "ActionType",
70
- "DoNotification",
71
- "DoWebhook",
63
+ "SendNotification",
64
+ "SendWebhook",
72
65
  "DoNothing",
73
66
  # Automations
74
67
  "Automation",
75
68
  "NewAutomation",
76
- "PreparedAutomation",
77
69
  # Integrations
78
70
  "Integration",
79
71
  "SlackIntegration",
@@ -6,7 +6,7 @@ from collections.abc import Iterable
6
6
  from typing import Any, Union
7
7
 
8
8
  from pydantic import ConfigDict, model_serializer
9
- from typing_extensions import Self, TypeAlias
9
+ from typing_extensions import Self, TypeAlias, get_args
10
10
 
11
11
  from wandb._pydantic import CompatBaseModel, model_validator
12
12
 
@@ -98,22 +98,22 @@ class FilterableField:
98
98
  # Override the default behavior of comparison operators: <, >=, ==, etc
99
99
  def __lt__(self, other: Any) -> FilterExpr:
100
100
  if isinstance(other, ScalarTypes):
101
- return self.lt(other)
101
+ return self.lt(other) # type: ignore[arg-type]
102
102
  raise TypeError(f"Invalid operand type in filter expression: {type(other)!r}")
103
103
 
104
104
  def __gt__(self, other: Any) -> FilterExpr:
105
105
  if isinstance(other, ScalarTypes):
106
- return self.gt(other)
106
+ return self.gt(other) # type: ignore[arg-type]
107
107
  raise TypeError(f"Invalid operand type in filter expression: {type(other)!r}")
108
108
 
109
109
  def __le__(self, other: Any) -> FilterExpr:
110
110
  if isinstance(other, ScalarTypes):
111
- return self.lte(other)
111
+ return self.lte(other) # type: ignore[arg-type]
112
112
  raise TypeError(f"Invalid operand type in filter expression: {type(other)!r}")
113
113
 
114
114
  def __ge__(self, other: Any) -> FilterExpr:
115
115
  if isinstance(other, ScalarTypes):
116
- return self.gte(other)
116
+ return self.gte(other) # type: ignore[arg-type]
117
117
  raise TypeError(f"Invalid operand type in filter expression: {type(other)!r}")
118
118
 
119
119
  # Operator behavior is intentionally overridden to allow defining
@@ -124,12 +124,12 @@ class FilterableField:
124
124
  # https://github.com/sqlalchemy/sqlalchemy/blob/f21ae633486380a26dc0b67b70ae1c0efc6b4dc4/lib/sqlalchemy/orm/descriptor_props.py#L808-L812
125
125
  def __eq__(self, other: Any) -> FilterExpr:
126
126
  if isinstance(other, ScalarTypes):
127
- return self.eq(other)
127
+ return self.eq(other) # type: ignore[arg-type]
128
128
  raise TypeError(f"Invalid operand type in filter expression: {type(other)!r}")
129
129
 
130
130
  def __ne__(self, other: Any) -> FilterExpr:
131
131
  if isinstance(other, ScalarTypes):
132
- return self.ne(other)
132
+ return self.ne(other) # type: ignore[arg-type]
133
133
  raise TypeError(f"Invalid operand type in filter expression: {type(other)!r}")
134
134
 
135
135
 
@@ -145,7 +145,7 @@ class FilterExpr(CompatBaseModel, SupportsLogicalOpSyntax):
145
145
  op: Op
146
146
 
147
147
  def __repr__(self) -> str:
148
- return f"{type(self).__name__}({self.field!s}={self.op!r})"
148
+ return f"{type(self).__name__}({self.field!s}: {self.op!r})"
149
149
 
150
150
  def __rich_repr__(self) -> RichReprResult: # type: ignore[override]
151
151
  # https://rich.readthedocs.io/en/stable/pretty.html
@@ -172,8 +172,10 @@ class FilterExpr(CompatBaseModel, SupportsLogicalOpSyntax):
172
172
  """Return a MongoDB dict representation of the expression."""
173
173
  from pydantic_core import to_jsonable_python # Only valid in pydantic v2
174
174
 
175
- op_dict = to_jsonable_python(self.op, by_alias=True, round_trip=True)
176
- return {self.field: op_dict}
175
+ return {self.field: to_jsonable_python(self.op, by_alias=True, round_trip=True)}
177
176
 
178
177
 
178
+ # for type annotations
179
179
  MongoLikeFilter: TypeAlias = Union[Op, FilterExpr]
180
+ # for runtime type checks
181
+ MongoLikeFilterTypes: tuple[type, ...] = get_args(MongoLikeFilter)
@@ -2,22 +2,17 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING, Any, Dict, Iterable, Tuple, TypeVar, Union, overload
5
+ from typing import Any, Dict, Iterable, Tuple, TypeVar, Union
6
6
 
7
7
  from pydantic import ConfigDict, Field, StrictBool, StrictFloat, StrictInt, StrictStr
8
8
  from typing_extensions import TypeAlias, get_args
9
9
 
10
- from wandb._pydantic import Base
11
-
12
- if TYPE_CHECKING:
13
- from wandb.automations._filters.run_metrics import MetricThresholdFilter
14
- from wandb.automations.events import RunMetricFilter
15
-
10
+ from wandb._pydantic import GQLBase
16
11
 
17
12
  # for type annotations
18
13
  Scalar = Union[StrictStr, StrictInt, StrictFloat, StrictBool]
19
14
  # for runtime type checks
20
- ScalarTypes = get_args(Scalar)
15
+ ScalarTypes: tuple[type, ...] = tuple(t.__origin__ for t in get_args(Scalar))
21
16
 
22
17
  # See: https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol
23
18
  RichReprResult: TypeAlias = Iterable[
@@ -44,18 +39,13 @@ class SupportsLogicalOpSyntax:
44
39
  """Syntactic sugar for: `a | b` -> `Or(a, b)`."""
45
40
  return Or(or_=[self, other])
46
41
 
47
- @overload
48
- def __and__(self, other: MetricThresholdFilter) -> RunMetricFilter: ...
49
- @overload
50
- def __and__(self, other: Any) -> And: ...
51
- def __and__(self, other: Any) -> Any:
42
+ def __and__(self, other: Any) -> And:
52
43
  """Syntactic sugar for: `a & b` -> `And(a, b)`."""
53
- from wandb.automations._filters.run_metrics import MetricThresholdFilter
44
+ from .expressions import FilterExpr
54
45
 
55
- # Special handling `run_filter & metric_filter`
56
- if isinstance(other, MetricThresholdFilter):
57
- return other.__and__(self)
58
- return And(and_=[self, other])
46
+ if isinstance(other, (BaseOp, FilterExpr)):
47
+ return And(and_=[self, other])
48
+ return NotImplemented
59
49
 
60
50
  def __invert__(self) -> Not:
61
51
  """Syntactic sugar for: `~a` -> `Not(a)`."""
@@ -63,8 +53,9 @@ class SupportsLogicalOpSyntax:
63
53
 
64
54
 
65
55
  # Base class for parsed MongoDB filter/query operators, e.g. `{"$and": [...]}`.
66
- class BaseOp(Base, SupportsLogicalOpSyntax):
56
+ class BaseOp(GQLBase, SupportsLogicalOpSyntax):
67
57
  model_config = ConfigDict(
58
+ extra="forbid",
68
59
  frozen=True, # Make pseudo-immutable for easier comparison and hashing
69
60
  )
70
61
 
@@ -2,13 +2,21 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from enum import Enum
5
+ from abc import ABC, abstractmethod
6
6
  from typing import TYPE_CHECKING, Any, Final, Literal, Optional, Union, overload
7
7
 
8
- from pydantic import Field, PositiveInt, StrictFloat, StrictInt, field_validator
9
- from typing_extensions import Self, override
8
+ from pydantic import (
9
+ Field,
10
+ PositiveFloat,
11
+ PositiveInt,
12
+ StrictFloat,
13
+ StrictInt,
14
+ field_validator,
15
+ )
16
+ from typing_extensions import Annotated, TypeAlias, override
10
17
 
11
- from wandb._pydantic.base import Base, GQLBase
18
+ from wandb._pydantic import GQLBase
19
+ from wandb.automations._validators import LenientStrEnum
12
20
 
13
21
  from .expressions import FilterExpr
14
22
  from .operators import BaseOp, RichReprResult
@@ -28,83 +36,101 @@ MONGO2PY_OPS: Final[dict[str, str]] = {
28
36
  # Reverse mapping from Python literal (str) -> MongoDB operator key
29
37
  PY2MONGO_OPS: Final[dict[str, str]] = {v: k for k, v in MONGO2PY_OPS.items()}
30
38
 
39
+ # Type hint for positive numbers (int or float)
40
+ PosNum: TypeAlias = Union[PositiveInt, PositiveFloat]
31
41
 
32
- class Agg(str, Enum): # from `Aggregation`
42
+
43
+ class Agg(LenientStrEnum): # from: Aggregation
33
44
  """Supported run metric aggregation operations."""
34
45
 
35
46
  MAX = "MAX"
36
47
  MIN = "MIN"
37
48
  AVERAGE = "AVERAGE"
38
49
 
50
+ # Shorter aliases for convenience
51
+ AVG = AVERAGE
52
+
39
53
 
40
- class ChangeType(str, Enum): # from `RunMetricChangeType`
54
+ class ChangeType(LenientStrEnum): # from: RunMetricChangeType
41
55
  """Describes the metric change as absolute (arithmetic difference) or relative (decimal percentage)."""
42
56
 
43
57
  ABSOLUTE = "ABSOLUTE"
44
58
  RELATIVE = "RELATIVE"
45
59
 
60
+ # Shorter aliases for convenience
61
+ ABS = ABSOLUTE
62
+ REL = RELATIVE
63
+
46
64
 
47
- class ChangeDirection(str, Enum): # from `RunMetricChangeDirection`
65
+ class ChangeDir(LenientStrEnum): # from: RunMetricChangeDirection
48
66
  """Describes the direction of the metric change."""
49
67
 
50
68
  INCREASE = "INCREASE"
51
69
  DECREASE = "DECREASE"
52
70
  ANY = "ANY"
53
71
 
72
+ # Shorter aliases for convenience
73
+ INC = INCREASE
74
+ DEC = DECREASE
54
75
 
55
- class _BaseMetricFilter(GQLBase):
76
+
77
+ class BaseMetricFilter(GQLBase, ABC, extra="forbid"):
56
78
  name: str
57
79
  """Name of the observed metric."""
58
80
 
59
81
  agg: Optional[Agg]
60
- """Aggregation operation, if any, to apply over the window size."""
82
+ """Aggregate operation, if any, to apply over the window size."""
61
83
 
62
84
  window: PositiveInt
63
- """Size of the window over which the metric is aggregated."""
85
+ """Size of the window over which the metric is aggregated (ignored if `agg is None`)."""
64
86
 
65
87
  # ------------------------------------------------------------------------------
88
+ cmp: Optional[str]
89
+ """Comparison between the metric expression (left) vs. the threshold or target value (right)."""
66
90
 
91
+ # ------------------------------------------------------------------------------
67
92
  threshold: Union[StrictInt, StrictFloat]
68
93
  """Threshold value to compare against."""
69
94
 
70
- @field_validator("agg", mode="before")
71
- @classmethod
72
- def _validate_agg(cls, v: Any) -> Any:
73
- # Be helpful: e.g. "min" -> "MIN"
74
- return v.strip().upper() if isinstance(v, str) else v
75
-
76
- @overload
77
- def __and__(self, other: BaseOp | FilterExpr) -> RunMetricFilter: ...
78
- @overload
79
- def __and__(self, other: Any) -> Any: ...
80
- def __and__(self, other: BaseOp | FilterExpr | Any) -> RunMetricFilter | Any:
81
- """Supports syntactic sugar for defining a triggering RunMetricEvent from `run_metric_filter & run_filter`."""
82
- from wandb.automations.events import RunMetricFilter, _InnerRunMetricFilter
95
+ def __and__(self, other: Any) -> RunMetricFilter:
96
+ """Implements `(metric_filter & run_filter) -> RunMetricFilter`."""
97
+ from wandb.automations.events import RunMetricFilter
83
98
 
84
99
  if isinstance(run_filter := other, (BaseOp, FilterExpr)):
85
100
  # Assume `other` is a run filter, and we are building a RunMetricEvent.
86
101
  # For the metric filter, delegate to the inner validator(s) to further wrap/nest as appropriate.
87
- metric_filter = _InnerRunMetricFilter.model_validate(self)
88
- return RunMetricFilter(
89
- run_metric_filter=metric_filter, run_filter=run_filter
90
- )
91
- return other.__and__(self) # Try switching the order of operands
102
+ return RunMetricFilter(run=run_filter, metric=self)
103
+ return NotImplemented
92
104
 
105
+ def __rand__(self, other: BaseOp | FilterExpr) -> RunMetricFilter:
106
+ """Ensures `&` is commutative: `(run_filter & metric_filter) == (metric_filter & run_filter)`."""
107
+ return self.__and__(other)
93
108
 
94
- class MetricThresholdFilter(_BaseMetricFilter): # from `RunMetricThresholdFilter`
95
- """For run events, defines a metric filter comparing a metric against a user-defined threshold value."""
109
+ @abstractmethod
110
+ def __repr__(self) -> str:
111
+ """The text representation of the metric filter."""
112
+ raise NotImplementedError
113
+
114
+ @override
115
+ def __rich_repr__(self) -> RichReprResult: # type: ignore[override]
116
+ """The representation of the metric filter when using `rich` for pretty-printing."""
117
+ # See: https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol
118
+ yield None, repr(self)
119
+
120
+
121
+ class MetricThresholdFilter(BaseMetricFilter): # from: RunMetricThresholdFilter
122
+ """Defines a filter that compares a run metric against a user-defined threshold value."""
96
123
 
97
124
  name: str
98
- agg: Optional[Agg] = Field(default=None, alias="agg_op")
99
- window: PositiveInt = Field(default=1, alias="window_size")
125
+ agg: Annotated[Optional[Agg], Field(alias="agg_op")] = None
126
+ window: Annotated[PositiveInt, Field(alias="window_size")] = 1
100
127
 
101
- cmp: Literal["$gte", "$gt", "$lt", "$lte"] = Field(alias="cmp_op")
128
+ cmp: Annotated[Literal["$gte", "$gt", "$lt", "$lte"], Field(alias="cmp_op")]
102
129
  """Comparison operator used to compare the metric value (left) vs. the threshold value (right)."""
103
130
 
104
131
  threshold: Union[StrictInt, StrictFloat]
105
132
 
106
133
  @field_validator("cmp", mode="before")
107
- @classmethod
108
134
  def _validate_cmp(cls, v: Any) -> Any:
109
135
  # Be helpful: e.g. ">" -> "$gt"
110
136
  return PY2MONGO_OPS.get(v.strip(), v) if isinstance(v, str) else v
@@ -112,72 +138,195 @@ class MetricThresholdFilter(_BaseMetricFilter): # from `RunMetricThresholdFilte
112
138
  def __repr__(self) -> str:
113
139
  metric = f"{self.agg.value}({self.name})" if self.agg else self.name
114
140
  op = MONGO2PY_OPS.get(self.cmp, self.cmp)
115
- expr = rf"{metric} {op} {self.threshold}"
116
- return repr(expr)
141
+ return repr(rf"{metric} {op} {self.threshold}")
117
142
 
118
- @override
119
- def __rich_repr__(self) -> RichReprResult: # type: ignore[override]
120
- yield None, repr(self)
121
143
 
144
+ class MetricChangeFilter(BaseMetricFilter): # from: RunMetricChangeFilter
145
+ """Defines a filter that compares a change in a run metric against a user-defined threshold.
122
146
 
123
- class MetricChangeFilter(_BaseMetricFilter): # from `RunMetricChangeFilter`
124
- # FIXME:
125
- # - `prior_window` should be optional and default to `window` if not provided.
126
- # - implement declarative syntax for `MetricChangeFilter` similar to `MetricThresholdFilter`.
127
- # - split this into tagged union of relative/absolute change filters.
147
+ The change is calculated over "tumbling" windows, i.e. the difference
148
+ between the current window and the non-overlapping prior window.
149
+ """
128
150
 
129
151
  name: str
130
- agg: Optional[Agg] = Field(default=None, alias="agg_op")
152
+ agg: Annotated[Optional[Agg], Field(alias="agg_op")] = None
153
+ window: Annotated[PositiveInt, Field(alias="current_window_size")] = 1
131
154
 
132
- # FIXME: Set the `prior_window` to `window` if it's not provided, for convenience.
133
- window: PositiveInt = Field(alias="current_window_size")
134
- prior_window: PositiveInt = Field(alias="prior_window_size")
135
- """Size of the preceding window over which the metric is aggregated."""
155
+ # `prior_window` is only for `RUN_METRIC_CHANGE` events
156
+ prior_window: Annotated[
157
+ PositiveInt,
158
+ # By default, set `window -> prior_window` if the latter wasn't provided.
159
+ Field(alias="prior_window_size", default_factory=lambda data: data["window"]),
160
+ ]
161
+ """Size of the prior window over which the metric is aggregated (ignored if `agg is None`).
136
162
 
137
- # NOTE: `cmp_op` isn't a field here. In the backend, it's effectively `cmp_op` = "$gte"
163
+ If omitted, defaults to the size of the current window.
164
+ """
138
165
 
139
- change_type: ChangeType = Field(alias="change_type")
140
- change_direction: ChangeDirection = Field(alias="change_dir")
166
+ # ------------------------------------------------------------------------------
167
+ # NOTE:
168
+ # - The "comparison" operator isn't actually part of the backend schema,
169
+ # but it's defined here for consistency -- and ignored otherwise.
170
+ # - In the backend, it's effectively "$gte" or "$lte", depending on the sign
171
+ # (change_dir), though again, this is not explicit in the schema.
172
+ cmp: Annotated[None, Field(frozen=True, exclude=True, repr=False)] = None
173
+ """Ignored."""
141
174
 
142
- threshold: Union[StrictInt, StrictFloat] = Field(alias="change_amount")
175
+ # ------------------------------------------------------------------------------
176
+ change_type: Annotated[ChangeType, Field(alias="change_type")]
177
+ change_dir: Annotated[ChangeDir, Field(alias="change_dir")]
178
+ threshold: Annotated[PosNum, Field(alias="change_amount")]
143
179
 
180
+ def __repr__(self) -> str:
181
+ metric = f"{self.agg.value}({self.name})" if self.agg else self.name
182
+ verb = (
183
+ "changes"
184
+ if (self.change_dir is ChangeDir.ANY)
185
+ else f"{self.change_dir.value.lower()}s"
186
+ )
187
+
188
+ fmt_spec = ".2%" if (self.change_type is ChangeType.REL) else ""
189
+ amt = f"{self.threshold:{fmt_spec}}"
190
+ return repr(rf"{metric} {verb} {amt}")
191
+
192
+
193
+ class BaseMetricOperand(GQLBase, extra="forbid"):
194
+ def gt(self, value: int | float, /) -> MetricThresholdFilter:
195
+ """Defines a `MetricThresholdFilter` that observes for `metric_expr > threshold`."""
196
+ return self > value
197
+
198
+ def lt(self, value: int | float, /) -> MetricThresholdFilter:
199
+ """Defines a `MetricThresholdFilter` that observes for `metric_expr < threshold`."""
200
+ return self < value
201
+
202
+ def gte(self, value: int | float, /) -> MetricThresholdFilter:
203
+ """Defines a `MetricThresholdFilter` that observes for `metric_expr >= threshold`."""
204
+ return self >= value
205
+
206
+ def lte(self, value: int | float, /) -> MetricThresholdFilter:
207
+ """Defines a `MetricThresholdFilter` that observes for `metric_expr <= threshold`."""
208
+ return self <= value
209
+
210
+ # Overloads to implement:
211
+ # - `(metric_operand > threshold) -> MetricThresholdFilter`
212
+ # - `(metric_operand < threshold) -> MetricThresholdFilter`
213
+ # - `(metric_operand >= threshold) -> MetricThresholdFilter`
214
+ # - `(metric_operand <= threshold) -> MetricThresholdFilter`
215
+ def __gt__(self, other: Any) -> MetricThresholdFilter:
216
+ if isinstance(other, (int, float)):
217
+ return MetricThresholdFilter(**dict(self), cmp="$gt", threshold=other)
218
+ return NotImplemented
219
+
220
+ def __lt__(self, other: Any) -> MetricThresholdFilter:
221
+ if isinstance(other, (int, float)):
222
+ return MetricThresholdFilter(**dict(self), cmp="$lt", threshold=other)
223
+ return NotImplemented
224
+
225
+ def __ge__(self, other: Any) -> MetricThresholdFilter:
226
+ if isinstance(other, (int, float)):
227
+ return MetricThresholdFilter(**dict(self), cmp="$gte", threshold=other)
228
+ return NotImplemented
229
+
230
+ def __le__(self, other: Any) -> MetricThresholdFilter:
231
+ if isinstance(other, (int, float)):
232
+ return MetricThresholdFilter(**dict(self), cmp="$lte", threshold=other)
233
+ return NotImplemented
144
234
 
145
- class MetricOperand(Base):
146
- name: str
147
- agg: Optional[Agg] = Field(default=None, alias="agg_op")
148
- window: PositiveInt = Field(default=1, alias="window_size")
235
+ @overload
236
+ def changes_by(self, *, diff: PosNum, frac: None) -> MetricChangeFilter: ...
237
+ @overload
238
+ def changes_by(self, *, diff: None, frac: PosNum) -> MetricChangeFilter: ...
239
+ @overload # NOTE: This overload is for internal use only.
240
+ def changes_by(
241
+ self, *, diff: PosNum | None, frac: PosNum | None, _dir: ChangeDir
242
+ ) -> MetricChangeFilter: ...
243
+ def changes_by(
244
+ self,
245
+ *,
246
+ diff: PosNum | None = None,
247
+ frac: PosNum | None = None,
248
+ _dir: ChangeDir = ChangeDir.ANY,
249
+ ) -> MetricChangeFilter:
250
+ """Defines a filter that observes for any change (increase OR decrease) in a run metric.
251
+
252
+ Exactly one of the keyword arguments `frac` or `diff` must be provided.
253
+
254
+ Args:
255
+ diff:
256
+ If given, the arithmetic difference that must be observed
257
+ in the metric. Must be a positive number.
258
+ frac:
259
+ If given, the fractional (relative) change that must be observed
260
+ in the metric. Must be a positive number. E.g. `frac=0.1`
261
+ denotes a 10% relative increase OR decrease.
262
+ """
263
+ # Enforce mutually exclusive keyword args
264
+ if (frac is None) is (diff is None):
265
+ raise ValueError("Must provide exactly one of `frac` or `diff`")
266
+
267
+ # Enforce positive values
268
+ if (frac is not None) and (frac <= 0):
269
+ raise ValueError(f"Expected positive quantity, got: {frac=}")
270
+ if (diff is not None) and (diff <= 0):
271
+ raise ValueError(f"Expected positive quantity, got: {diff=}")
272
+
273
+ if diff is None:
274
+ change_kws = dict(change_type=ChangeType.REL, threshold=frac)
275
+ return MetricChangeFilter(**dict(self), change_dir=_dir, **change_kws)
276
+ else:
277
+ change_kws = dict(change_type=ChangeType.ABS, threshold=diff)
278
+ return MetricChangeFilter(**dict(self), change_dir=_dir, **change_kws)
279
+
280
+ @overload
281
+ def increases_by(self, *, diff: PosNum, frac: None) -> MetricChangeFilter: ...
282
+ @overload
283
+ def increases_by(self, *, diff: None, frac: PosNum) -> MetricChangeFilter: ...
284
+ def increases_by(
285
+ self, *, diff: PosNum | None = None, frac: PosNum | None = None
286
+ ) -> MetricChangeFilter:
287
+ """Defines a filter that observes for an increase in the numerical value of a run metric.
149
288
 
150
- def _agg(self, op: Agg, window: int) -> Self:
151
- if self.agg is None: # Prevent overwriting an existing aggregation operator
152
- return self.model_copy(update={"agg": op, "window": window})
153
- raise ValueError(f"Aggregation operator already set as: {self.agg!r}")
289
+ Arguments are the same as for `.changes_by()`.
290
+ """
291
+ return self.changes_by(diff=diff, frac=frac, _dir=ChangeDir.INC)
154
292
 
155
- def max(self, window: int) -> Self:
156
- return self._agg(Agg.MAX, window)
293
+ @overload
294
+ def decreases_by(self, *, diff: PosNum, frac: None) -> MetricChangeFilter: ...
295
+ @overload
296
+ def decreases_by(self, *, diff: None, frac: PosNum) -> MetricChangeFilter: ...
297
+ def decreases_by(
298
+ self, *, diff: PosNum | None = None, frac: PosNum | None = None
299
+ ) -> MetricChangeFilter:
300
+ """Defines a filter that observes for a decrease in the numerical value of a run metric.
157
301
 
158
- def min(self, window: int) -> Self:
159
- return self._agg(Agg.MIN, window)
302
+ Arguments are the same as for `.changes_by()`.
303
+ """
304
+ return self.changes_by(diff=diff, frac=frac, _dir=ChangeDir.DEC)
160
305
 
161
- def average(self, window: int) -> Self:
162
- return self._agg(Agg.AVERAGE, window)
163
306
 
164
- # Aliased method for users familiar with e.g. torch/tf/numpy/pandas/polars/etc.
165
- def mean(self, window: int) -> Self:
166
- return self.average(window=window)
307
+ class MetricVal(BaseMetricOperand):
308
+ """Represents a single, unaggregated metric value when defining a metric filter."""
167
309
 
168
- def gt(self, other: int | float) -> MetricThresholdFilter:
169
- return MetricThresholdFilter(**dict(self), cmp="$gt", threshold=other)
310
+ name: str
170
311
 
171
- def lt(self, other: int | float) -> MetricThresholdFilter:
172
- return MetricThresholdFilter(**dict(self), cmp="$lt", threshold=other)
312
+ # Allow users to convert this single-value metric into an aggregated metric expression.
313
+ def max(self, window: int) -> MetricAgg:
314
+ return MetricAgg(name=self.name, agg=Agg.MAX, window=window)
173
315
 
174
- def gte(self, other: int | float) -> MetricThresholdFilter:
175
- return MetricThresholdFilter(**dict(self), cmp="$gte", threshold=other)
316
+ def min(self, window: int) -> MetricAgg:
317
+ return MetricAgg(name=self.name, agg=Agg.MIN, window=window)
176
318
 
177
- def lte(self, other: int | float) -> MetricThresholdFilter:
178
- return MetricThresholdFilter(**dict(self), cmp="$lte", threshold=other)
319
+ def avg(self, window: int) -> MetricAgg:
320
+ return MetricAgg(name=self.name, agg=Agg.AVG, window=window)
179
321
 
180
- __gt__ = gt
181
- __lt__ = lt
182
- __ge__ = gte
183
- __le__ = lte
322
+ # Aliased method for users familiar with e.g. torch/tf/numpy/pandas/polars/etc.
323
+ def mean(self, window: int) -> MetricAgg:
324
+ return self.avg(window=window)
325
+
326
+
327
+ class MetricAgg(BaseMetricOperand):
328
+ """Represents an aggregated metric value when defining a metric filter."""
329
+
330
+ name: str
331
+ agg: Annotated[Agg, Field(alias="agg_op")]
332
+ window: Annotated[PositiveInt, Field(alias="window_size")]