kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic. Click here for more details.
- kumoai/__init__.py +300 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +223 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +471 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +207 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1796 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +210 -0
- kumoai/experimental/rfm/authenticate.py +432 -0
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +736 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +19 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +1184 -0
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/task_table.py +231 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +641 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +121 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +11 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +343 -0
- kumoai/utils/sql.py +3 -0
- kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
- kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
- kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
- kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1184 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import time
|
|
3
|
+
import warnings
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from collections.abc import Generator, Iterator
|
|
6
|
+
from contextlib import contextmanager
|
|
7
|
+
from dataclasses import dataclass, replace
|
|
8
|
+
from typing import Any, Literal, overload
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
from kumoapi.model_plan import RunMode
|
|
13
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
14
|
+
from kumoapi.pquery.AST import (
|
|
15
|
+
Aggregation,
|
|
16
|
+
Column,
|
|
17
|
+
Condition,
|
|
18
|
+
Join,
|
|
19
|
+
LogicalOperation,
|
|
20
|
+
)
|
|
21
|
+
from kumoapi.rfm import Context
|
|
22
|
+
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
23
|
+
from kumoapi.rfm import (
|
|
24
|
+
RFMEvaluateRequest,
|
|
25
|
+
RFMParseQueryRequest,
|
|
26
|
+
RFMPredictRequest,
|
|
27
|
+
)
|
|
28
|
+
from kumoapi.task import TaskType
|
|
29
|
+
from kumoapi.typing import AggregationType, Stype
|
|
30
|
+
|
|
31
|
+
from kumoai.client.rfm import RFMAPI
|
|
32
|
+
from kumoai.exceptions import HTTPException
|
|
33
|
+
from kumoai.experimental.rfm import Graph
|
|
34
|
+
from kumoai.experimental.rfm.base import DataBackend, Sampler
|
|
35
|
+
from kumoai.mixin import CastMixin
|
|
36
|
+
from kumoai.utils import ProgressLogger, display
|
|
37
|
+
|
|
38
|
+
# Default seed for pseudo-random sampling in the public entry points.
_RANDOM_SEED = 42

# Batch size used by ``KumoRFM.batch_mode(batch_size='max')``, looked up per
# task type. Task types without an explicit entry fall back to 1,000.
_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(
    lambda: 1_000,
    {TaskType.TEMPORAL_LINK_PREDICTION: 200},
)

# Per-run-mode context size limits (consumed outside this view).
_MAX_CONTEXT_SIZE = {
    RunMode.DEBUG: 100,
    RunMode.FAST: 1_000,
    RunMode.NORMAL: 5_000,
    RunMode.BEST: 10_000,
}

# Per-run-mode test set size limits. All non-debug run modes share the same
# test set size so that their results can be compared fairly.
_MAX_TEST_SIZE = {
    RunMode.DEBUG: 100,
    **dict.fromkeys((RunMode.FAST, RunMode.NORMAL, RunMode.BEST), 2_000),
}

# Upper bound (in bytes) for a serialized request sent to the server.
_MAX_SIZE = 30 * 1024**2

# Error message raised when a serialized request exceeds ``_MAX_SIZE``. The
# ``{stats}`` placeholder is filled via ``str.format``.
_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                   "reduce either the number of tables in the graph, their "
                   "number of columns (e.g., large text columns), "
                   "neighborhood configuration, or the run mode. If none of "
                   "this is possible, please create a feature request at "
                   "'https://github.com/kumo-ai/kumo-rfm' if you must go "
                   "beyond this for your use-case.")
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass(repr=False)
class ExplainConfig(CastMixin):
    r"""Options that control how :meth:`KumoRFM.predict` explains a
    prediction.

    Args:
        skip_summary: If ``True``, no human-readable summary of the
            explanation is generated.
    """
    # Forwarded as ``skip_summary`` to the explain API call.
    skip_summary: bool = False
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@dataclass(repr=False)
class Explanation:
    r"""Result of an explained prediction, as returned by
    :meth:`KumoRFM.predict` when ``explain`` is set.

    The object behaves like the 2-tuple ``(prediction, summary)``. It can be
    unpacked, indexed (including with negative indices) and has a length of
    two. The structured explanation is available via :attr:`details`.

    Args:
        prediction: The prediction as a :class:`pandas.DataFrame`.
        summary: A human-readable summary of the explanation.
        details: The detailed explanation returned by the explain API.
    """
    prediction: pd.DataFrame
    summary: str
    details: ExplanationConfig

    @overload
    def __getitem__(self, index: Literal[0, -2]) -> pd.DataFrame:
        pass

    @overload
    def __getitem__(self, index: Literal[1, -1]) -> str:
        pass

    def __getitem__(self, index: int) -> pd.DataFrame | str:
        # Mirror tuple indexing semantics of ``(prediction, summary)``:
        if index in (0, -2):
            return self.prediction
        if index in (1, -1):
            return self.summary
        raise IndexError("Index out of range")

    def __len__(self) -> int:
        # Consistent with ``__iter__``, which yields exactly two items.
        return 2

    def __iter__(self) -> Iterator[pd.DataFrame | str]:
        return iter((self.prediction, self.summary))

    def __repr__(self) -> str:
        return str((self.prediction, self.summary))

    def print(self) -> None:
        r"""Prints the explanation."""
        display.dataframe(self.prediction)
        display.message(self.summary)

    def _ipython_display_(self) -> None:
        # Rich display hook used by IPython/Jupyter.
        self.print()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class KumoRFM:
|
|
114
|
+
r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
|
|
115
|
+
Foundation Model for In-Context Learning on Relational Data
|
|
116
|
+
<https://kumo.ai/research/kumo_relational_foundation_model.pdf>`_ paper.
|
|
117
|
+
|
|
118
|
+
:class:`KumoRFM` is a foundation model to generate predictions for any
|
|
119
|
+
relational dataset without training.
|
|
120
|
+
The model is pre-trained and the class provides an interface to query the
|
|
121
|
+
model from a :class:`Graph` object.
|
|
122
|
+
|
|
123
|
+
.. code-block:: python
|
|
124
|
+
|
|
125
|
+
from kumoai.experimental.rfm import Graph, KumoRFM
|
|
126
|
+
|
|
127
|
+
df_users = pd.DataFrame(...)
|
|
128
|
+
df_items = pd.DataFrame(...)
|
|
129
|
+
df_orders = pd.DataFrame(...)
|
|
130
|
+
|
|
131
|
+
graph = Graph.from_data({
|
|
132
|
+
'users': df_users,
|
|
133
|
+
'items': df_items,
|
|
134
|
+
'orders': df_orders,
|
|
135
|
+
})
|
|
136
|
+
|
|
137
|
+
rfm = KumoRFM(graph)
|
|
138
|
+
|
|
139
|
+
query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
|
|
140
|
+
"FOR users.user_id=1")
|
|
141
|
+
result = rfm.predict(query)
|
|
142
|
+
|
|
143
|
+
print(result) # user_id COUNT(transactions.*, 0, 30, days) > 0
|
|
144
|
+
# 1 0.85
|
|
145
|
+
|
|
146
|
+
Args:
|
|
147
|
+
graph: The graph.
|
|
148
|
+
verbose: Whether to print verbose output.
|
|
149
|
+
optimize: If set to ``True``, will optimize the underlying data backend
|
|
150
|
+
for optimal querying. For example, for transactional database
|
|
151
|
+
backends, will create any missing indices. Requires write-access to
|
|
152
|
+
the data backend.
|
|
153
|
+
"""
|
|
154
|
+
def __init__(
    self,
    graph: Graph,
    verbose: bool | ProgressLogger = True,
    optimize: bool = False,
) -> None:
    r"""Validates ``graph`` and creates the sampler for its data backend.

    Raises:
        NotImplementedError: If the graph's data backend is not supported.
    """
    graph = graph.validate()
    self._graph_def = graph._to_api_graph_definition()

    # Select the sampler implementation for the graph's data backend.
    # Backend modules are imported lazily, presumably so that
    # backend-specific dependencies are only needed when actually used.
    backend = graph.backend
    if backend == DataBackend.LOCAL:
        from kumoai.experimental.rfm.backend.local import LocalSampler
        self._sampler: Sampler = LocalSampler(graph, verbose)
    elif backend == DataBackend.SQLITE:
        from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler

        # Only the SQLite sampler makes use of `optimize`:
        self._sampler = SQLiteSampler(graph, verbose, optimize)
    elif backend == DataBackend.SNOWFLAKE:
        from kumoai.experimental.rfm.backend.snow import SnowSampler
        self._sampler = SnowSampler(graph, verbose)
    else:
        raise NotImplementedError(
            f"KumoRFM does not support the data backend '{backend}'")

    # API client, created lazily on first use (see `_api_client`):
    self._client: RFMAPI | None = None

    # Batching configuration, overridden temporarily by `batch_mode()`:
    self._batch_size: int | Literal['max'] | None = None
    self.num_retries: int = 0
|
|
179
|
+
|
|
180
|
+
@property
def _api_client(self) -> RFMAPI:
    r"""The :class:`RFMAPI` client.

    It is built from ``global_state.client`` on first access and cached for
    all later calls.
    """
    if self._client is None:
        from kumoai.experimental.rfm import global_state
        self._client = RFMAPI(global_state.client)
    return self._client
|
|
188
|
+
|
|
189
|
+
def __repr__(self) -> str:
    # The representation consists of the bare class name only.
    return self.__class__.__name__ + '()'
|
|
191
|
+
|
|
192
|
+
@contextmanager
|
|
193
|
+
def batch_mode(
|
|
194
|
+
self,
|
|
195
|
+
batch_size: int | Literal['max'] = 'max',
|
|
196
|
+
num_retries: int = 1,
|
|
197
|
+
) -> Generator[None, None, None]:
|
|
198
|
+
"""Context manager to predict in batches.
|
|
199
|
+
|
|
200
|
+
.. code-block:: python
|
|
201
|
+
|
|
202
|
+
with model.batch_mode(batch_size='max', num_retries=1):
|
|
203
|
+
df = model.predict(query, indices=...)
|
|
204
|
+
|
|
205
|
+
Args:
|
|
206
|
+
batch_size: The batch size. If set to ``"max"``, will use the
|
|
207
|
+
maximum applicable batch size for the given task.
|
|
208
|
+
num_retries: The maximum number of retries for failed queries due
|
|
209
|
+
to unexpected server issues.
|
|
210
|
+
"""
|
|
211
|
+
if batch_size != 'max' and batch_size <= 0:
|
|
212
|
+
raise ValueError(f"'batch_size' must be greater than zero "
|
|
213
|
+
f"(got {batch_size})")
|
|
214
|
+
|
|
215
|
+
if num_retries < 0:
|
|
216
|
+
raise ValueError(f"'num_retries' must be greater than or equal to "
|
|
217
|
+
f"zero (got {num_retries})")
|
|
218
|
+
|
|
219
|
+
self._batch_size = batch_size
|
|
220
|
+
self.num_retries = num_retries
|
|
221
|
+
yield
|
|
222
|
+
self._batch_size = None
|
|
223
|
+
self.num_retries = 0
|
|
224
|
+
|
|
225
|
+
@overload
def predict(
    self,
    query: str,
    indices: list[str] | list[float] | list[int] | None = None,
    *,
    explain: Literal[False] = False,
    anchor_time: pd.Timestamp | Literal['entity'] | None = None,
    context_anchor_time: pd.Timestamp | None = None,
    run_mode: RunMode | str = RunMode.FAST,
    num_neighbors: list[int] | None = None,
    num_hops: int = 2,
    max_pq_iterations: int = 10,
    random_seed: int | None = _RANDOM_SEED,
    verbose: bool | ProgressLogger = True,
    use_prediction_time: bool = False,
) -> pd.DataFrame:
    # Typing-only signature: without `explain`, a plain DataFrame results.
    ...
|
|
243
|
+
|
|
244
|
+
@overload
def predict(
    self,
    query: str,
    indices: list[str] | list[float] | list[int] | None = None,
    *,
    explain: Literal[True] | ExplainConfig | dict[str, Any],
    anchor_time: pd.Timestamp | Literal['entity'] | None = None,
    context_anchor_time: pd.Timestamp | None = None,
    run_mode: RunMode | str = RunMode.FAST,
    num_neighbors: list[int] | None = None,
    num_hops: int = 2,
    max_pq_iterations: int = 10,
    random_seed: int | None = _RANDOM_SEED,
    verbose: bool | ProgressLogger = True,
    use_prediction_time: bool = False,
) -> Explanation:
    # Typing-only signature: with `explain`, an `Explanation` results.
    ...
|
|
262
|
+
|
|
263
|
+
def predict(
    self,
    query: str,
    indices: list[str] | list[float] | list[int] | None = None,
    *,
    explain: bool | ExplainConfig | dict[str, Any] = False,
    anchor_time: pd.Timestamp | Literal['entity'] | None = None,
    context_anchor_time: pd.Timestamp | None = None,
    run_mode: RunMode | str = RunMode.FAST,
    num_neighbors: list[int] | None = None,
    num_hops: int = 2,
    max_pq_iterations: int = 10,
    random_seed: int | None = _RANDOM_SEED,
    verbose: bool | ProgressLogger = True,
    use_prediction_time: bool = False,
) -> pd.DataFrame | Explanation:
    """Returns predictions for a predictive query.

    Args:
        query: The predictive query.
        indices: The entity primary keys to predict for. Will override the
            indices given as part of the predictive query. Predictions will
            be generated for all indices, independent of whether they
            fulfill entity filter constraints. To pre-filter entities, use
            :meth:`~KumoRFM.is_valid_entity`.
        explain: Configuration for explainability.
            If set to ``True``, will additionally explain the prediction.
            Passing in an :class:`ExplainConfig` instance provides control
            over which parts of explanation are generated.
            Explainability is currently only supported for single entity
            predictions with ``run_mode="FAST"``. The ``"NORMAL"`` and
            ``"BEST"`` run modes are reset to ``"FAST"`` with a warning.
        anchor_time: The anchor timestamp for the prediction. If set to
            ``None``, will use the maximum timestamp in the data.
            If set to ``"entity"``, will use the timestamp of the entity.
        context_anchor_time: The maximum anchor timestamp for context
            examples. If set to ``None``, ``anchor_time`` will
            determine the anchor time for context examples.
        run_mode: The :class:`RunMode` for the query.
        num_neighbors: The number of neighbors to sample for each hop.
            If specified, the ``num_hops`` option will be ignored.
        num_hops: The number of hops to sample when generating the context.
        max_pq_iterations: The maximum number of iterations to perform to
            collect valid labels. It is advised to increase the number of
            iterations in case the predictive query has strict entity
            filters, in which case, :class:`KumoRFM` needs to sample more
            entities to find valid labels.
        random_seed: A manual seed for generating pseudo-random numbers.
        verbose: Whether to print verbose output.
        use_prediction_time: Whether to use the anchor timestamp as an
            additional feature during prediction. This is typically
            beneficial for time series forecasting tasks.

    Returns:
        The predictions as a :class:`pandas.DataFrame`.
        If ``explain`` is provided, returns an :class:`Explanation` object
        containing the prediction, summary, and details.

    Raises:
        ValueError: If no entities to predict for are given, if more than
            a single entity should be explained, or if the serialized
            context exceeds the size limit.
        RuntimeError: If the server request still fails after
            ``num_retries`` retries (see :meth:`batch_mode`).
    """
    explain_config: ExplainConfig | None = None
    if explain is True:
        explain_config = ExplainConfig()
    elif explain is not False:
        explain_config = ExplainConfig._cast(explain)

    query_def = self._parse_query(query)
    query_str = query_def.to_string()

    if num_hops != 2 and num_neighbors is not None:
        warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                      f"custom 'num_hops={num_hops}' option")

    # Resolve the run mode once, so that string inputs are handled
    # correctly in the membership test below and in all requests:
    resolved_run_mode = RunMode(run_mode)
    if explain_config is not None and resolved_run_mode in {
            RunMode.NORMAL, RunMode.BEST
    }:
        warnings.warn(f"Explainability is currently only supported for "
                      f"run mode 'FAST' (got '{run_mode}'). Provided run "
                      f"mode has been reset. Please lower the run mode to "
                      f"suppress this warning.")
        # Actually perform the reset announced in the warning:
        resolved_run_mode = RunMode.FAST

    if indices is None:
        if query_def.rfm_entity_ids is None:
            raise ValueError("Cannot find entities to predict for. Please "
                             "pass them via `predict(query, indices=...)`")
        indices = query_def.get_rfm_entity_id_list()
    else:
        # Explicit indices take precedence over those in the query:
        query_def = replace(query_def, rfm_entity_ids=None)

    if len(indices) == 0:
        raise ValueError("At least one entity is required")

    if explain_config is not None and len(indices) > 1:
        raise ValueError(
            f"Cannot explain predictions for more than a single entity "
            f"(got {len(indices)})")

    query_repr = query_def.to_string(rich=True, exclude_predict=True)
    if explain_config is not None:
        msg = f'[bold]EXPLAIN[/bold] {query_repr}'
    else:
        msg = f'[bold]PREDICT[/bold] {query_repr}'

    if not isinstance(verbose, ProgressLogger):
        verbose = ProgressLogger.default(msg=msg, verbose=verbose)

    with verbose as logger:

        batch_size: int | None = None
        if self._batch_size == 'max':
            task_type = self._get_task_type(
                query=query_def,
                edge_types=self._sampler.edge_types,
            )
            batch_size = _MAX_PRED_SIZE[task_type]
        else:
            batch_size = self._batch_size

        if batch_size is not None:
            offsets = range(0, len(indices), batch_size)
            batches = [indices[step:step + batch_size] for step in offsets]
        else:
            batches = [indices]

        if len(batches) > 1:
            logger.log(f"Splitting {len(indices):,} entities into "
                       f"{len(batches):,} batches of size {batch_size:,}")

        predictions: list[pd.DataFrame] = []
        summary: str | None = None
        details: ExplanationConfig | None = None
        for i, batch in enumerate(batches):
            # TODO Re-use the context for subsequent predictions.
            context = self._get_context(
                query=query_def,
                indices=batch,
                anchor_time=anchor_time,
                context_anchor_time=context_anchor_time,
                run_mode=resolved_run_mode,
                num_neighbors=num_neighbors,
                num_hops=num_hops,
                max_pq_iterations=max_pq_iterations,
                evaluate=False,
                random_seed=random_seed,
                logger=logger if i == 0 else None,
            )
            request = RFMPredictRequest(
                context=context,
                run_mode=resolved_run_mode,
                query=query_str,
                use_prediction_time=use_prediction_time,
            )
            with warnings.catch_warnings():
                # `message` is a regex matched against the *start* of the
                # warning text, so the full "Protobuf gencode" prefix is
                # required for the filter to take effect:
                warnings.filterwarnings('ignore', message='Protobuf gencode')
                request_msg = request.to_protobuf()
            request_bytes = request_msg.SerializeToString()
            if i == 0:
                logger.log(f"Generated context of size "
                           f"{len(request_bytes) / (1024*1024):.2f}MB")

            if len(request_bytes) > _MAX_SIZE:
                stats = Context.get_memory_stats(request_msg.context)
                raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))

            if i == 0 and len(batches) > 1:
                verbose.init_progress(
                    total=len(batches),
                    description='Predicting',
                )

            # Send the request, retrying with exponential back-off:
            for attempt in range(self.num_retries + 1):
                try:
                    if explain_config is not None:
                        resp = self._api_client.explain(
                            request=request_bytes,
                            skip_summary=explain_config.skip_summary,
                        )
                    else:
                        resp = self._api_client.predict(request_bytes)
                    break
                except HTTPException as e:
                    if attempt == self.num_retries:
                        try:
                            detail = json.loads(e.detail)['detail']
                        except Exception:
                            detail = e.detail
                        raise RuntimeError(
                            f"An unexpected exception occurred. Please "
                            f"create an issue at "
                            f"'https://github.com/kumo-ai/kumo-rfm'. {detail}"
                        ) from None

                    time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...

            if explain_config is not None:
                summary = resp.summary
                details = resp.details

            df = pd.DataFrame(**resp.prediction)

            # Cast 'ENTITY' to the dtype of the entity table's primary key:
            if 'ENTITY' in df:
                table = context.subgraph.table_dict[query_def.entity_table]
                pkey = table.df[table.primary_key]
                df['ENTITY'] = df['ENTITY'].astype(pkey.dtype)

            # Cast 'ANCHOR_TIMESTAMP' to a datetime dtype. String values are
            # parsed as-is, other values are interpreted as milliseconds:
            if 'ANCHOR_TIMESTAMP' in df:
                ser = df['ANCHOR_TIMESTAMP']
                if not pd.api.types.is_datetime64_any_dtype(ser):
                    is_str = len(ser) > 0 and isinstance(ser.iloc[0], str)
                    unit = None if is_str else 'ms'
                    df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
                        ser, errors='coerce', unit=unit)

            predictions.append(df)

            if len(batches) > 1:
                verbose.step()

    if len(predictions) == 1:
        prediction = predictions[0]
    else:
        prediction = pd.concat(predictions, ignore_index=True)

    if explain_config is not None:
        assert len(predictions) == 1
        assert summary is not None
        assert details is not None
        return Explanation(
            prediction=prediction,
            summary=summary,
            details=details,
        )

    return prediction
|
|
497
|
+
|
|
498
|
+
def is_valid_entity(
|
|
499
|
+
self,
|
|
500
|
+
query: str,
|
|
501
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
502
|
+
*,
|
|
503
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
504
|
+
) -> np.ndarray:
|
|
505
|
+
r"""Returns a mask that denotes which entities are valid for the
|
|
506
|
+
given predictive query, *i.e.*, which entities fulfill (temporal)
|
|
507
|
+
entity filter constraints.
|
|
508
|
+
|
|
509
|
+
Args:
|
|
510
|
+
query: The predictive query.
|
|
511
|
+
indices: The entity primary keys to predict for. Will override the
|
|
512
|
+
indices given as part of the predictive query.
|
|
513
|
+
anchor_time: The anchor timestamp for the prediction. If set to
|
|
514
|
+
``None``, will use the maximum timestamp in the data.
|
|
515
|
+
If set to ``"entity"``, will use the timestamp of the entity.
|
|
516
|
+
"""
|
|
517
|
+
query_def = self._parse_query(query)
|
|
518
|
+
|
|
519
|
+
if indices is None:
|
|
520
|
+
if query_def.rfm_entity_ids is None:
|
|
521
|
+
raise ValueError("Cannot find entities to predict for. Please "
|
|
522
|
+
"pass them via "
|
|
523
|
+
"`is_valid_entity(query, indices=...)`")
|
|
524
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
525
|
+
|
|
526
|
+
if len(indices) == 0:
|
|
527
|
+
raise ValueError("At least one entity is required")
|
|
528
|
+
|
|
529
|
+
if anchor_time is None:
|
|
530
|
+
anchor_time = self._get_default_anchor_time(query_def)
|
|
531
|
+
|
|
532
|
+
if isinstance(anchor_time, pd.Timestamp):
|
|
533
|
+
self._validate_time(query_def, anchor_time, None, False)
|
|
534
|
+
else:
|
|
535
|
+
assert anchor_time == 'entity'
|
|
536
|
+
if query_def.entity_table not in self._sampler.time_column_dict:
|
|
537
|
+
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
538
|
+
f"table '{query_def.entity_table}' "
|
|
539
|
+
f"to have a time column.")
|
|
540
|
+
|
|
541
|
+
raise NotImplementedError
|
|
542
|
+
|
|
543
|
+
def evaluate(
    self,
    query: str,
    *,
    metrics: list[str] | None = None,
    anchor_time: pd.Timestamp | Literal['entity'] | None = None,
    context_anchor_time: pd.Timestamp | None = None,
    run_mode: RunMode | str = RunMode.FAST,
    num_neighbors: list[int] | None = None,
    num_hops: int = 2,
    max_pq_iterations: int = 10,
    random_seed: int | None = _RANDOM_SEED,
    verbose: bool | ProgressLogger = True,
    use_prediction_time: bool = False,
) -> pd.DataFrame:
    """Computes evaluation metrics for a predictive query.

    Entity IDs specified as part of the query are dropped. The context is
    generated with ``indices=None`` and ``evaluate=True``.

    Args:
        query: The predictive query.
        metrics: The metrics to compute. Duplicates are removed while
            keeping the first occurrence.
        anchor_time: The anchor timestamp for the prediction. ``None``
            selects the maximum timestamp in the data, and ``"entity"``
            uses the timestamp of the entity.
        context_anchor_time: The maximum anchor timestamp for context
            examples. When ``None``, ``anchor_time`` determines the anchor
            time of context examples.
        run_mode: The :class:`RunMode` for the query.
        num_neighbors: The number of neighbors to sample per hop. When
            given, ``num_hops`` is ignored.
        num_hops: The number of hops to sample when generating the context.
        max_pq_iterations: The maximum number of iterations spent collecting
            valid labels. Consider raising it for predictive queries with
            strict entity filters, since :class:`KumoRFM` then needs to
            sample more entities to find valid labels.
        random_seed: A manual seed for generating pseudo-random numbers.
        verbose: Whether to print verbose output.
        use_prediction_time: Whether to feed the anchor timestamp as an
            additional feature during prediction, which typically helps
            time series forecasting tasks.

    Returns:
        A :class:`pandas.DataFrame` with one row per metric and the columns
        ``metric`` and ``value``.

    Raises:
        ValueError: If the serialized context exceeds the size limit (and
            presumably for invalid ``metrics`` via ``_validate_metrics``).
        RuntimeError: If the evaluation request to the server fails.
    """
    query_def = self._parse_query(query)

    if num_neighbors is not None and num_hops != 2:
        warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                      f"custom 'num_hops={num_hops}' option")

    if query_def.rfm_entity_ids is not None:
        query_def = replace(query_def, rfm_entity_ids=None)

    header = (f'[bold]EVALUATE[/bold] '
              f'{query_def.to_string(rich=True, exclude_predict=True)}')

    if not isinstance(verbose, ProgressLogger):
        verbose = ProgressLogger.default(msg=header, verbose=verbose)

    with verbose as logger:
        mode = RunMode(run_mode)
        context = self._get_context(
            query=query_def,
            indices=None,
            anchor_time=anchor_time,
            context_anchor_time=context_anchor_time,
            run_mode=mode,
            num_neighbors=num_neighbors,
            num_hops=num_hops,
            max_pq_iterations=max_pq_iterations,
            evaluate=True,
            random_seed=random_seed,
            logger=logger if verbose else None,
        )

        if metrics is not None and len(metrics) > 0:
            self._validate_metrics(metrics, context.task_type)
            metrics = list(dict.fromkeys(metrics))

        request = RFMEvaluateRequest(
            context=context,
            run_mode=mode,
            metrics=metrics,
            use_prediction_time=use_prediction_time,
        )
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message='Protobuf gencode')
            request_msg = request.to_protobuf()
        payload = request_msg.SerializeToString()
        logger.log(f"Generated context of size "
                   f"{len(payload) / (1024*1024):.2f}MB")

        if len(payload) > _MAX_SIZE:
            raise ValueError(_SIZE_LIMIT_MSG.format(
                stats=Context.get_memory_stats(request_msg.context)))

        try:
            resp = self._api_client.evaluate(payload)
        except HTTPException as e:
            try:
                detail = json.loads(e.detail)['detail']
            except Exception:
                detail = e.detail
            raise RuntimeError(f"An unexpected exception occurred. "
                               f"Please create an issue at "
                               f"'https://github.com/kumo-ai/kumo-rfm'. "
                               f"{detail}") from None

    metrics_df = pd.DataFrame.from_dict(
        resp.metrics,
        orient='index',
        columns=['value'],
    )
    return metrics_df.reset_index(names='metric')
|
|
656
|
+
|
|
657
|
+
def get_train_table(
    self,
    query: str,
    size: int,
    *,
    anchor_time: pd.Timestamp | Literal['entity'] | None = None,
    random_seed: int | None = _RANDOM_SEED,
    max_iterations: int = 10,
) -> pd.DataFrame:
    """Returns the labels of a predictive query for a specified anchor
    time.

    Args:
        query: The predictive query.
        size: The maximum number of entities to generate labels for.
        anchor_time: The anchor timestamp for the query. If set to
            :obj:`None`, will use the maximum timestamp in the data,
            shifted back by the query's full forecast horizon (if the
            target defines a date offset range). If set to
            :obj:`"entity"`, will use the timestamp of the entity.
        random_seed: A manual seed for generating pseudo-random numbers.
        max_iterations: The number of steps to run before aborting.

    Returns:
        The labels as a :class:`pandas.DataFrame` with the columns
        ``ENTITY``, ``ANCHOR_TIMESTAMP`` and ``TARGET``.
    """
    parsed = self._parse_query(query)

    if anchor_time is None:
        anchor_time = self._get_default_anchor_time(parsed)
        date_range = parsed.target_ast.date_offset_range
        if date_range is not None:
            # Step back by the whole horizon (one offset per forecast) so the
            # `evaluate=True` time validation below accepts the anchor:
            horizon = date_range.end_date_offset * parsed.num_forecasts
            anchor_time = anchor_time - horizon

    assert anchor_time is not None
    if not isinstance(anchor_time, pd.Timestamp):
        assert anchor_time == 'entity'
        if parsed.entity_table not in self._sampler.time_column_dict:
            raise ValueError(f"Anchor time 'entity' requires the entity "
                             f"table '{parsed.entity_table}' "
                             f"to have a time column")
    else:
        self._validate_time(parsed, anchor_time, None, evaluate=True)

    # Only test examples are requested; the train split is always empty.
    _, labels = self._sampler.sample_target(
        query=parsed,
        num_train_examples=0,
        train_anchor_time=anchor_time,
        num_train_trials=0,
        num_test_examples=size,
        test_anchor_time=anchor_time,
        num_test_trials=max_iterations * size,
        random_seed=random_seed,
    )

    columns = {
        'ENTITY': labels.entity_pkey,
        'ANCHOR_TIMESTAMP': labels.anchor_time,
        'TARGET': labels.target,
    }
    return pd.DataFrame(columns)
|
|
716
|
+
|
|
717
|
+
# Helpers #################################################################
|
|
718
|
+
|
|
719
|
+
def _parse_query(
    self,
    query: str | ValidatedPredictiveQuery,
) -> ValidatedPredictiveQuery:
    """Parses and validates a predictive query via the RFM API.

    Already-validated queries are returned unchanged. Validation warnings
    reported by the API are re-emitted as a single Python warning.

    Args:
        query: The predictive query string, or an already-validated
            :class:`ValidatedPredictiveQuery`.

    Returns:
        The validated predictive query.

    Raises:
        ValueError: If ``query`` is an ``EVALUATE PREDICT ...`` query, or
            if the API rejects the query.
    """
    if isinstance(query, ValidatedPredictiveQuery):
        return query

    # Inspect the first whitespace-delimited token so that any whitespace
    # after the keyword (newline, tab, several spaces) is detected, not only
    # a single space:
    tokens = query.split(maxsplit=1) if isinstance(query, str) else []
    if len(tokens) > 0 and tokens[0].lower() == 'evaluate':
        raise ValueError("'EVALUATE PREDICT ...' queries are not "
                         "supported in the SDK. Instead, use either "
                         "`predict()` or `evaluate()` methods to perform "
                         "predictions or evaluations.")

    request = RFMParseQueryRequest(
        query=query,
        graph_definition=self._graph_def,
    )

    try:
        resp = self._api_client.parse_query(request)
    except HTTPException as e:
        # The server usually wraps its message as JSON `{"detail": ...}`;
        # fall back to the raw detail otherwise.
        try:
            msg = json.loads(e.detail)['detail']
        except Exception:
            msg = e.detail
        raise ValueError(f"Failed to parse query '{query}'. "
                         f"{msg}") from None

    if len(resp.validation_response.warnings) > 0:
        msg = '\n'.join(
            f'{i}. {warning.title}: {warning.message}'
            for i, warning in enumerate(resp.validation_response.warnings,
                                        start=1))
        warnings.warn(f"Encountered the following warnings during "
                      f"parsing:\n{msg}")

    return resp.query
|
|
753
|
+
|
|
754
|
+
@staticmethod
def _get_task_type(
    query: ValidatedPredictiveQuery,
    edge_types: list[tuple[str, str, str]],
) -> TaskType:
    """Infers the task type from the target of a validated query.

    * Conditions / logical operations -> binary classification.
    * ``LIST_DISTINCT`` aggregations over a registered foreign key ->
      temporal link prediction.
    * Any other aggregation -> regression.
    * Plain columns: ID/categorical -> multi-class classification,
      numerical -> regression.

    Args:
        query: The validated predictive query.
        edge_types: The graph's edge types. Entries are matched on
            ``(table, column)`` via their first two elements (presumably
            ``(src_table, fkey_column, dst_table)`` -- confirm).

    Returns:
        The inferred :class:`TaskType`.

    Raises:
        NotImplementedError: For ``LIST_DISTINCT`` targets that do not map
            to exactly one foreign-key edge, or unsupported column stypes.
    """
    if isinstance(query.target_ast, (Condition, LogicalOperation)):
        return TaskType.BINARY_CLASSIFICATION

    target = query.target_ast
    if isinstance(target, Join):
        target = target.rhs_target

    if isinstance(target, Aggregation):
        if target.aggr == AggregationType.LIST_DISTINCT:
            target_column = target._get_target_column_name()
            # Compare fully-qualified '<table>.<column>' names instead of
            # unpacking `split('.')`, which crashes with an opaque unpacking
            # error whenever the name does not contain exactly one '.':
            target_edge_types = [
                edge_type for edge_type in edge_types
                if f'{edge_type[0]}.{edge_type[1]}' == target_column
            ]
            if len(target_edge_types) != 1:
                col_name = target_column.split('.', 1)[-1]
                raise NotImplementedError(
                    f"Multilabel-classification queries based on "
                    f"'LIST_DISTINCT' are not supported yet. If you "
                    f"planned to write a link prediction query instead, "
                    f"make sure to register '{col_name}' as a "
                    f"foreign key.")
            return TaskType.TEMPORAL_LINK_PREDICTION

        return TaskType.REGRESSION

    assert isinstance(target, Column)

    if target.stype in {Stype.ID, Stype.categorical}:
        return TaskType.MULTICLASS_CLASSIFICATION

    if target.stype in {Stype.numerical}:
        return TaskType.REGRESSION

    raise NotImplementedError("Task type not yet supported")
|
|
793
|
+
|
|
794
|
+
def _get_default_anchor_time(
    self,
    query: ValidatedPredictiveQuery,
) -> pd.Timestamp:
    """Derives the default anchor timestamp of a predictive query.

    Static queries use the sampler's overall maximum timestamp. Temporal
    queries restrict the maximum to the tables referenced by their target
    aggregations (the part of each target column name before the first
    ``'.'``).
    """
    if query.query_type != QueryType.TEMPORAL:
        assert query.query_type == QueryType.STATIC
        return self._sampler.get_max_time()

    table_names = []
    for target_aggr in query.get_all_target_aggregations():
        qualified_name = target_aggr._get_target_column_name()
        table_names.append(qualified_name.split('.')[0])
    return self._sampler.get_max_time(table_names)
|
|
807
|
+
|
|
808
|
+
def _validate_time(
    self,
    query: ValidatedPredictiveQuery,
    anchor_time: pd.Timestamp,
    context_anchor_time: pd.Timestamp | None,
    evaluate: bool,
) -> None:
    """Checks anchor timestamps against the time range covered by the data.

    Hard violations raise; suspicious but possibly intended settings only
    warn. Graphs without any time column are not checked.

    Args:
        query: The validated predictive query.
        anchor_time: The prediction (or evaluation) anchor timestamp.
        context_anchor_time: The anchor timestamp of in-context examples,
            or :obj:`None` to skip all context-related checks.
        evaluate: Whether ``anchor_time`` is used for evaluation. If so,
            the full forecast horizon after it must end before the latest
            timestamp in the data.

    Raises:
        ValueError: If ``anchor_time`` or ``context_anchor_time`` lies
            before the earliest timestamp, or if an evaluation anchor
            leaves no room for the forecast horizon.
    """
    sampler = self._sampler
    if len(sampler.time_column_dict) == 0:
        return  # Graph without timestamps

    earliest = sampler.get_min_time()
    latest = sampler.get_max_time()

    if anchor_time < earliest:
        raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
                         f"the earliest timestamp '{earliest}' in the "
                         f"data.")

    if context_anchor_time is not None and context_anchor_time < earliest:
        raise ValueError(f"Context anchor timestamp is too early or "
                         f"aggregation time range is too large. To make "
                         f"this prediction, we would need data back to "
                         f"'{context_anchor_time}', however, your data "
                         f"only contains data back to '{earliest}'.")

    date_range = query.target_ast.date_offset_range
    step = (pd.DateOffset(0)
            if date_range is None else date_range.end_date_offset)
    # Total horizon spanned by all forecast steps:
    horizon = step * query.num_forecasts

    if context_anchor_time is not None:
        if context_anchor_time > anchor_time:
            warnings.warn(f"Context anchor timestamp "
                          f"(got '{context_anchor_time}') is set to a later "
                          f"date than the prediction anchor timestamp "
                          f"(got '{anchor_time}'). Please make sure this is "
                          f"intended.")
        elif (query.query_type == QueryType.TEMPORAL
              and context_anchor_time + horizon > anchor_time):
            warnings.warn(f"Aggregation for context examples at timestamp "
                          f"'{context_anchor_time}' will leak information "
                          f"from the prediction anchor timestamp "
                          f"'{anchor_time}'. Please make sure this is "
                          f"intended.")
        elif context_anchor_time - horizon < earliest:
            lookback = context_anchor_time - horizon
            warnings.warn(f"Context anchor timestamp is too early or "
                          f"aggregation time range is too large. To form "
                          f"proper input data, we would need data back to "
                          f"'{lookback}', however, your data only contains "
                          f"data back to '{earliest}'.")

    if evaluate:
        cutoff = latest - horizon
        if anchor_time > cutoff:
            raise ValueError(
                f"Anchor timestamp for evaluation is after the latest "
                f"supported timestamp '{cutoff}'.")
    elif anchor_time > latest + pd.DateOffset(days=1):
        warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
                      f"latest timestamp '{latest}' in the data. Please "
                      f"make sure this is intended.")
|
|
874
|
+
|
|
875
|
+
def _get_context(
    self,
    query: ValidatedPredictiveQuery,
    indices: list[str] | list[float] | list[int] | None,
    anchor_time: pd.Timestamp | Literal['entity'] | None,
    context_anchor_time: pd.Timestamp | None,
    run_mode: RunMode,
    num_neighbors: list[int] | None,
    num_hops: int,
    max_pq_iterations: int,
    evaluate: bool,
    random_seed: int | None = _RANDOM_SEED,
    logger: ProgressLogger | None = None,
) -> Context:
    """Samples labeled examples and their subgraph into a :class:`Context`.

    Args:
        query: The validated predictive query.
        indices: The entity primary keys to predict for. Required when
            ``evaluate=False``; unused when evaluating, since test examples
            are sampled instead.
        anchor_time: The prediction anchor timestamp, ``'entity'`` to use
            each entity's own timestamp, or :obj:`None` to derive it from
            the data (shifted back by the forecast horizon when
            evaluating).
        context_anchor_time: The anchor timestamp of in-context examples.
            Defaults to ``anchor_time`` minus the forecast horizon.
        run_mode: Selects default fan-outs and context/test sizes.
        num_neighbors: Per-hop neighbor counts. If given, its length
            overrides ``num_hops``.
        num_hops: The number of hops to sample (between 0 and 6).
        max_pq_iterations: Multiplier on the number of requested examples
            that bounds the number of sampling trials.
        evaluate: Whether to additionally sample labeled test examples.
        random_seed: A manual seed for generating pseudo-random numbers.
        logger: Optional progress logger.

    Returns:
        The assembled :class:`Context`.

    Raises:
        ValueError: On invalid hop counts, inconsistent anchor times, too
            many requested entities, or subgraphs with too many tables.
    """
    if num_neighbors is not None:
        num_hops = len(num_neighbors)

    if num_hops < 0:
        raise ValueError(f"'num_hops' must be non-negative "
                         f"(got {num_hops})")
    if num_hops > 6:
        raise ValueError(f"Cannot predict on subgraphs with more than 6 "
                         f"hops (got {num_hops}). Please reduce the "
                         f"number of hops and try again. Please create a "
                         f"feature request at "
                         f"'https://github.com/kumo-ai/kumo-rfm' if you "
                         f"must go beyond this for your use-case.")

    task_type = self._get_task_type(
        query=query,
        edge_types=self._sampler.edge_types,
    )

    if logger is not None:
        if task_type == TaskType.BINARY_CLASSIFICATION:
            task_type_repr = 'binary classification'
        elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
            task_type_repr = 'multi-class classification'
        elif task_type == TaskType.REGRESSION:
            task_type_repr = 'regression'
        elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
            task_type_repr = 'link prediction'
        else:
            task_type_repr = str(task_type)
        logger.log(f"Identified {query.query_type} {task_type_repr} task")

    if task_type.is_link_pred and num_hops < 2:
        raise ValueError(f"Cannot perform link prediction on subgraphs "
                         f"with less than 2 hops (got {num_hops}) since "
                         f"historical target entities need to be part of "
                         f"the context. Please increase the number of "
                         f"hops and try again.")

    if num_neighbors is None:
        if run_mode == RunMode.DEBUG:
            num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
        elif run_mode == RunMode.FAST or task_type.is_link_pred:
            num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
        else:
            num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

    if query.target_ast.date_offset_range is None:
        step_offset = pd.DateOffset(0)
    else:
        step_offset = query.target_ast.date_offset_range.end_date_offset
    end_offset = step_offset * query.num_forecasts

    if anchor_time is None:
        anchor_time = self._get_default_anchor_time(query)

        if evaluate:
            # Leave room for the forecast horizon so test labels are
            # observable within the data:
            anchor_time = anchor_time - end_offset

        if logger is not None:
            assert isinstance(anchor_time, pd.Timestamp)
            if anchor_time == pd.Timestamp.min:
                pass  # Static graph
            elif (anchor_time.hour == 0 and anchor_time.minute == 0
                  and anchor_time.second == 0
                  and anchor_time.microsecond == 0):
                logger.log(f"Derived anchor time {anchor_time.date()}")
            else:
                logger.log(f"Derived anchor time {anchor_time}")

    assert anchor_time is not None
    if isinstance(anchor_time, pd.Timestamp):
        if context_anchor_time == 'entity':
            raise ValueError("Anchor time 'entity' needs to be shared "
                             "for context and prediction examples")
        if context_anchor_time is None:
            context_anchor_time = anchor_time - end_offset
        self._validate_time(query, anchor_time, context_anchor_time,
                            evaluate)
    else:
        assert anchor_time == 'entity'
        if query.query_type != QueryType.STATIC:
            raise ValueError("Anchor time 'entity' is only valid for "
                             "static predictive queries")
        if query.entity_table not in self._sampler.time_column_dict:
            raise ValueError(f"Anchor time 'entity' requires the entity "
                             f"table '{query.entity_table}' to "
                             f"have a time column")
        if isinstance(context_anchor_time, pd.Timestamp):
            raise ValueError("Anchor time 'entity' needs to be shared "
                             "for context and prediction examples")
        context_anchor_time = 'entity'

    num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
    if evaluate:
        num_test_examples = _MAX_TEST_SIZE[run_mode]
        if task_type.is_link_pred:
            num_test_examples = num_test_examples // 5
    else:
        num_test_examples = 0

    train, test = self._sampler.sample_target(
        query=query,
        num_train_examples=num_train_examples,
        train_anchor_time=context_anchor_time,
        num_train_trials=max_pq_iterations * num_train_examples,
        num_test_examples=num_test_examples,
        test_anchor_time=anchor_time,
        num_test_trials=max_pq_iterations * num_test_examples,
        random_seed=random_seed,
    )
    train_pkey, train_time, y_train = train
    test_pkey, test_time, y_test = test

    if evaluate and logger is not None:
        logger.log(self._describe_labels(y_test, task_type, 'test'))

    if not evaluate:
        assert indices is not None
        if len(indices) > _MAX_PRED_SIZE[task_type]:
            raise ValueError(f"Cannot predict for more than "
                             f"{_MAX_PRED_SIZE[task_type]:,} entities at "
                             f"once (got {len(indices):,}). Use "
                             f"`KumoRFM.batch_mode` to process entities "
                             f"in batches")

        test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
        if isinstance(anchor_time, pd.Timestamp):
            test_time = pd.Series([anchor_time]).repeat(
                len(indices)).reset_index(drop=True)
        else:
            train_time = test_time = 'entity'

    if logger is not None:
        logger.log(self._describe_labels(y_train, task_type, 'in-context'))

    entity_table_names: tuple[str, ...]
    if task_type.is_link_pred:
        final_aggr = query.get_final_target_aggregation()
        assert final_aggr is not None
        edge_fkey = final_aggr._get_target_column_name()
        for edge_type in self._sampler.edge_types:
            if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                entity_table_names = (
                    query.entity_table,
                    edge_type[2],
                )
    else:
        entity_table_names = (query.entity_table, )

    # Exclude the entity anchor time from the feature set to prevent
    # running out-of-distribution between in-context and test examples:
    exclude_cols_dict = query.get_exclude_cols_dict()
    if entity_table_names[0] in self._sampler.time_column_dict:
        if entity_table_names[0] not in exclude_cols_dict:
            exclude_cols_dict[entity_table_names[0]] = []
        time_column = self._sampler.time_column_dict[entity_table_names[0]]
        exclude_cols_dict[entity_table_names[0]].append(time_column)

    subgraph = self._sampler.sample_subgraph(
        entity_table_names=entity_table_names,
        entity_pkey=pd.concat(
            [train_pkey, test_pkey],
            axis=0,
            ignore_index=True,
        ),
        anchor_time=pd.concat(
            [train_time, test_time],
            axis=0,
            ignore_index=True,
        ) if isinstance(train_time, pd.Series) else 'entity',
        num_neighbors=num_neighbors,
        exclude_cols_dict=exclude_cols_dict,
    )

    # NOTE: the limit is enforced as `>= 15`; the message states exactly that.
    if len(subgraph.table_dict) >= 15:
        raise ValueError(f"Cannot query from a graph with 15 or more "
                         f"tables (got {len(subgraph.table_dict)}). "
                         f"Please create a feature request at "
                         f"'https://github.com/kumo-ai/kumo-rfm' if you "
                         f"must go beyond this for your use-case.")

    return Context(
        task_type=task_type,
        entity_table_names=entity_table_names,
        subgraph=subgraph,
        y_train=y_train,
        y_test=y_test if evaluate else None,
        top_k=query.top_k,
        step_size=None,
    )

@staticmethod
def _describe_labels(
    y: pd.Series,
    task_type: TaskType,
    split: str,
) -> str:
    """Returns a one-line summary of sampled labels for progress logging.

    Args:
        y: The sampled target values.
        task_type: The task type the labels belong to.
        split: Human-readable name of the example set, e.g. ``'test'`` or
            ``'in-context'``.

    Raises:
        NotImplementedError: If ``task_type`` is not supported.
    """
    num = len(y)
    if num == 0:
        # An empty sample would otherwise raise ZeroDivisionError (binary)
        # or produce NaN min/max (regression) in a pure logging path:
        return f"Collected 0 {split} examples"
    if task_type == TaskType.BINARY_CLASSIFICATION:
        pos = 100 * int((y > 0).sum()) / num
        return (f"Collected {num:,} {split} examples with "
                f"{pos:.2f}% positive cases")
    if task_type == TaskType.MULTICLASS_CLASSIFICATION:
        return (f"Collected {num:,} {split} examples holding "
                f"{y.nunique()} classes")
    if task_type == TaskType.REGRESSION:
        _min, _max = float(y.min()), float(y.max())
        return (f"Collected {num:,} {split} examples with targets "
                f"between {format_value(_min)} and "
                f"{format_value(_max)}")
    if task_type == TaskType.TEMPORAL_LINK_PREDICTION:
        num_rhs = y.explode().nunique()
        return (f"Collected {num:,} {split} examples with "
                f"{num_rhs:,} unique items")
    raise NotImplementedError
|
|
1117
|
+
|
|
1118
|
+
@staticmethod
def _validate_metrics(
    metrics: list[str],
    task_type: TaskType,
) -> None:
    """Validates that all requested metrics are supported for a task type.

    Metrics containing ``'@'`` must have the form ``'<name>@<k>'``, where
    ``k`` is an integer in ``[1, 100]``. They are then validated by their
    ``'<name>@'`` prefix (e.g. ``'map@10'`` -> ``'map@'``).

    Args:
        metrics: The metric names to validate.
        task_type: The task type the metrics are evaluated for.

    Raises:
        ValueError: If a metric is unsupported or has an invalid ``top_k``.
        NotImplementedError: If ``task_type`` is not supported.
    """
    if task_type == TaskType.BINARY_CLASSIFICATION:
        supported_metrics = [
            'acc', 'precision', 'recall', 'f1', 'auroc', 'auprc', 'ap'
        ]
    elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
        supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
    elif task_type == TaskType.REGRESSION:
        supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
    elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
        supported_metrics = [
            'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
            'hit_ratio@'
        ]
    else:
        raise NotImplementedError

    for metric in metrics:
        if '@' in metric:
            metric_split = metric.split('@')
            if len(metric_split) != 2:
                raise ValueError(f"Unsupported metric '{metric}'. "
                                 f"Available metrics "
                                 f"are {supported_metrics}.")

            name, top_k = f'{metric_split[0]}@', metric_split[1]

            # `isdecimal()` instead of `isdigit()`: the latter accepts
            # characters such as '²' that `int()` cannot parse.
            if not top_k.isdecimal():
                raise ValueError(f"Metric '{metric}' does not define a "
                                 f"valid 'top_k' value (got '{top_k}').")

            if int(top_k) <= 0:
                raise ValueError(f"Metric '{metric}' needs to define a "
                                 f"positive 'top_k' value (got '{top_k}')")

            if int(top_k) > 100:
                raise ValueError(f"Metric '{metric}' defines a 'top_k' "
                                 f"value greater than 100 "
                                 f"(got '{top_k}'). Please create a "
                                 f"feature request at "
                                 f"'https://github.com/kumo-ai/kumo-rfm' "
                                 f"if you must go beyond this for your "
                                 f"use-case.")

            metric = name

        if metric not in supported_metrics:
            raise ValueError(f"Unsupported metric '{metric}'. Available "
                             f"metrics are {supported_metrics}. If you "
                             f"feel a metric is missing, please create a "
                             f"feature request at "
                             f"'https://github.com/kumo-ai/kumo-rfm'.")
|
|
1175
|
+
|
|
1176
|
+
|
|
1177
|
+
def format_value(value: int | float) -> str:
|
|
1178
|
+
if value == int(value):
|
|
1179
|
+
return f'{int(value):,}'
|
|
1180
|
+
if abs(value) >= 1000:
|
|
1181
|
+
return f'{value:,.0f}'
|
|
1182
|
+
if abs(value) >= 10:
|
|
1183
|
+
return f'{value:.1f}'
|
|
1184
|
+
return f'{value:.2f}'
|