sdk-seshat-python 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdk_seshat_python-0.3.4.dist-info/LICENSE +42 -0
- sdk_seshat_python-0.3.4.dist-info/METADATA +24 -0
- sdk_seshat_python-0.3.4.dist-info/RECORD +105 -0
- sdk_seshat_python-0.3.4.dist-info/WHEEL +4 -0
- seshat/__init__.py +43 -0
- seshat/__main__.py +11 -0
- seshat/data_class/__init__.py +9 -0
- seshat/data_class/base.py +267 -0
- seshat/data_class/pandas.py +65 -0
- seshat/data_class/pyspark.py +63 -0
- seshat/evaluation/__init__.py +1 -0
- seshat/evaluation/base.py +42 -0
- seshat/evaluation/evaluator/__init__.py +1 -0
- seshat/evaluation/evaluator/base.py +15 -0
- seshat/evaluation/evaluator/general/__init__.py +3 -0
- seshat/evaluation/evaluator/general/classification.py +172 -0
- seshat/evaluation/evaluator/general/clustering.py +20 -0
- seshat/evaluation/evaluator/general/regression.py +112 -0
- seshat/evaluation/evaluator/recommendation/__init__.py +2 -0
- seshat/evaluation/evaluator/recommendation/diversity.py +73 -0
- seshat/evaluation/evaluator/recommendation/ranking.py +95 -0
- seshat/feature_view/__init__.py +0 -0
- seshat/feature_view/base.py +141 -0
- seshat/general/__init__.py +1 -0
- seshat/general/command/__init__.py +1 -0
- seshat/general/command/base.py +98 -0
- seshat/general/config.py +33 -0
- seshat/general/exceptions.py +62 -0
- seshat/general/lazy_config.py +41 -0
- seshat/general/template/README.md-tmpl +0 -0
- seshat/general/template/config.py-tmpl +17 -0
- seshat/general/template/env-templ +3 -0
- seshat/general/template/pyproject._toml-tmpl +15 -0
- seshat/general/template/recommender-jupyter.ipynb-tmpl +52 -0
- seshat/general/template/recommender.py-tmpl +56 -0
- seshat/profiler/__init__.py +2 -0
- seshat/profiler/base.py +222 -0
- seshat/profiler/decorator.py +7 -0
- seshat/source/__init__.py +1 -0
- seshat/source/base.py +46 -0
- seshat/source/database/__init__.py +1 -0
- seshat/source/database/base.py +52 -0
- seshat/source/exceptions.py +16 -0
- seshat/source/flip_side/__init__.py +1 -0
- seshat/source/flip_side/base.py +247 -0
- seshat/source/local/__init__.py +1 -0
- seshat/source/local/base.py +26 -0
- seshat/source/mixins.py +70 -0
- seshat/source/multisource/__init__.py +1 -0
- seshat/source/multisource/base.py +28 -0
- seshat/source/saver/__init__.py +2 -0
- seshat/source/saver/base.py +47 -0
- seshat/source/saver/database.py +197 -0
- seshat/source/saver/utils/__init__.py +1 -0
- seshat/source/saver/utils/postgres.py +22 -0
- seshat/transformer/__init__.py +1 -0
- seshat/transformer/augmenter/__init__.py +0 -0
- seshat/transformer/augmenter/base.py +6 -0
- seshat/transformer/base.py +144 -0
- seshat/transformer/deriver/__init__.py +9 -0
- seshat/transformer/deriver/base.py +997 -0
- seshat/transformer/deriver/from_database.py +66 -0
- seshat/transformer/imputer/__init__.py +0 -0
- seshat/transformer/imputer/base.py +6 -0
- seshat/transformer/merger/__init__.py +1 -0
- seshat/transformer/merger/base.py +221 -0
- seshat/transformer/pipeline/__init__.py +1 -0
- seshat/transformer/pipeline/base.py +60 -0
- seshat/transformer/pipeline/branch.py +162 -0
- seshat/transformer/pipeline/recommendation/__init__.py +0 -0
- seshat/transformer/pipeline/recommendation/address_pipeline.py +79 -0
- seshat/transformer/reducer/__init__.py +0 -0
- seshat/transformer/reducer/base.py +6 -0
- seshat/transformer/scaler/__init__.py +0 -0
- seshat/transformer/scaler/base.py +6 -0
- seshat/transformer/schema/__init__.py +1 -0
- seshat/transformer/schema/base.py +127 -0
- seshat/transformer/splitter/__init__.py +1 -0
- seshat/transformer/splitter/base.py +70 -0
- seshat/transformer/splitter/block/__init__.py +1 -0
- seshat/transformer/splitter/block/base.py +39 -0
- seshat/transformer/splitter/random/__init__.py +1 -0
- seshat/transformer/splitter/random/base.py +28 -0
- seshat/transformer/splitter/time_line/__init__.py +1 -0
- seshat/transformer/splitter/time_line/base.py +8 -0
- seshat/transformer/trimmer/__init__.py +6 -0
- seshat/transformer/trimmer/base.py +321 -0
- seshat/transformer/vectorizer/__init__.py +2 -0
- seshat/transformer/vectorizer/base.py +8 -0
- seshat/transformer/vectorizer/cosine_similarity.py +231 -0
- seshat/transformer/vectorizer/pivot.py +115 -0
- seshat/transformer/vectorizer/utils.py +134 -0
- seshat/utils/__init__.py +0 -0
- seshat/utils/col_to_list.py +18 -0
- seshat/utils/contracts.py +32 -0
- seshat/utils/file.py +9 -0
- seshat/utils/join_columns_to_list.py +29 -0
- seshat/utils/join_str.py +5 -0
- seshat/utils/memory.py +5 -0
- seshat/utils/mixin.py +49 -0
- seshat/utils/pandas_func.py +22 -0
- seshat/utils/patching.py +29 -0
- seshat/utils/pyspark_func.py +16 -0
- seshat/utils/singleton.py +7 -0
- seshat/utils/validation.py +38 -0
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from typing import Dict, Callable
|
|
2
|
+
|
|
3
|
+
from pandas import DataFrame
|
|
4
|
+
from pyspark.sql import DataFrame as PySparkDataFrame
|
|
5
|
+
|
|
6
|
+
from seshat.source.database import SQLDBSource
|
|
7
|
+
from seshat.transformer.deriver.base import SFrameDeriver
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class FromSQLDBDeriver(SFrameDeriver):
    """
    Derive new columns by fetching data from a SQL database source.

    The SQL query is either given statically (``query``) or built per call by
    ``get_query_fn``; the fetched frame is then either merged into the default
    frame on ``base_col`` or returned as a separate ``"result"`` frame.

    Parameters
    ----------
    source : SQLDBSource
        Database source used to execute the query.
    base_col : str, optional
        Join/merge key between the default frame and the query result.
    query : str, optional
        Static SQL query. A truthy value takes precedence over ``get_query_fn``.
    get_query_fn : Callable, optional
        Called as ``get_query_fn(default, *args, **kwargs, **get_query_fn_kwargs)``
        to build the query when ``query`` is falsy.
    get_query_fn_kwargs : dict, optional
        Extra keyword arguments forwarded to ``get_query_fn``.
    filters : dict, optional
        Filters assigned to the source before fetching.
    merge_result : bool, optional
        If True (default), merge the fetched frame into the default frame on
        ``base_col``; otherwise return it under the ``"result"`` group key.
    merge_how : str, optional
        Join strategy (e.g. "left", "inner"). Default is "left".
    group_keys
        Group keys for the parent SFrameDeriver/Transformer class.
    """

    DEFAULT_GROUP_KEYS = {"default": "default", "result": "result"}

    def __init__(
        self,
        source: SQLDBSource,
        base_col: str = None,
        query: str = None,
        get_query_fn: Callable = None,
        get_query_fn_kwargs: dict = None,
        filters: dict = None,
        merge_result: bool = True,
        merge_how: str = "left",
        group_keys=None,
        *args,
        **kwargs,
    ):
        super().__init__(group_keys, *args, **kwargs)
        # Avoid mutable default arguments: normalize None to fresh dicts.
        self.get_query_fn_kwargs = {} if get_query_fn_kwargs is None else get_query_fn_kwargs
        self.filters = {} if filters is None else filters
        self.source = source
        self.get_query_fn = get_query_fn
        self.query = query
        self.base_col = base_col
        self.merge_result = merge_result
        self.merge_how = merge_how

    # Fix: the original annotated the return as Dict["str", DataFrame] — the
    # string literal "str" instead of the type str.
    def derive_df(self, default: DataFrame, *args, **kwargs) -> Dict[str, DataFrame]:
        """Pandas handler: fetch from the source and optionally merge into ``default``."""
        db_result = self.get_from_source(default, *args, **kwargs)
        if self.merge_result:
            default = default.merge(
                right=db_result, on=self.base_col, how=self.merge_how
            )
            return {"default": default}
        return {"default": default, "result": db_result}

    def derive_spf(self, default: PySparkDataFrame, *args, **kwargs):
        """PySpark handler: fetch from the source and optionally join into ``default``."""
        db_result = self.get_from_source(default, *args, **kwargs)
        if self.merge_result:
            default = default.join(db_result, on=self.base_col, how=self.merge_how)
            return {"default": default}
        return {"default": default, "result": db_result}

    def get_from_source(self, default, *args, **kwargs):
        """Configure the source with the query and filters, then fetch raw data."""
        # A static query wins; otherwise build one from the current frame.
        self.source.query = self.query or self.get_query_fn(
            default, *args, **kwargs, **self.get_query_fn_kwargs
        )
        self.source.filters = self.filters
        return self.source().to_raw()
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .base import Merger
|
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
from typing import Iterable, List, Dict
|
|
2
|
+
from typing import Type
|
|
3
|
+
|
|
4
|
+
from pyspark.sql import DataFrame as PySparkDataFrame
|
|
5
|
+
from pyspark.sql.dataframe import DataFrame
|
|
6
|
+
|
|
7
|
+
from seshat.data_class import DFrame, GroupSFrame
|
|
8
|
+
from seshat.data_class import SFrame, SPFrame
|
|
9
|
+
from seshat.general import configs
|
|
10
|
+
from seshat.general.exceptions import InvalidArgumentsError
|
|
11
|
+
from seshat.transformer import Transformer
|
|
12
|
+
from seshat.transformer.schema import Schema
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SFrameMerger(Transformer):
    """Base class for merging transformers.

    ``HANDLER_NAME = "merge"`` presumably routes calls to ``merge_*``
    handlers on subclasses (e.g. ``merge_df`` / ``merge_spf``) via the
    parent ``Transformer`` dispatch — confirm against
    ``seshat.transformer.base``. ``ONLY_GROUP`` marks that the input
    SFrame must be a group.
    """

    HANDLER_NAME = "merge"
    ONLY_GROUP = True
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Merger(SFrameMerger):
    """
    This transformer merges two sframes based on column names.

    Parameters:
    axis: int, optional
        Shows which axis the merging operation should operate on. Default is 1.
    left_on: str, optional
        Column name that is the base of merging for the default sf
    right_on: str, optional
        Column name that is the base of merging for the right sf
    on: str, optional
        Column name that is the base of merging for both default & right sf.
        This is the same as when left_on and right_on are equal.
    right_schema: Schema, optional
        This schema will be run on the right sf before merging happens.
        Useful when keeping only some columns of the right sf in the default sf
        is required.
    inplace: bool, optional
        Shows whether the default sf should be changed or the merged result
        saved on a new sf. Default is False.
    merge_how: str, optional
        Shows how to merge the two sframes. For example: left, right, inner.
        Default is left.
    drop_unmerged: bool, optional
        If true, the two sframes that were used for merging are dropped from
        the result.
    group_keys
        Group keys for the parent Transformer class.
    *args
        Additional positional arguments.
    **kwargs
        Additional keyword arguments.
    """

    DEFAULT_GROUP_KEYS = {
        "default": configs.DEFAULT_SF_KEY,
        "other": configs.OTHER_SF_KEY,
        "merged": configs.MERGED_SF_KEY,
    }
    # Declared types of the merge-configuration attributes set in __init__.
    on: str | Iterable[str]
    axis: int

    def __init__(
        self,
        axis: int = 1,
        left_on: str | List[str] = None,
        right_on: str | List[str] = None,
        on: str | List[str] = None,
        right_schema: Schema = None,
        inplace: bool = False,
        group_keys=None,
        merge_how: str = "left",
        drop_unmerged: bool = False,
        *args,
        **kwargs,
    ):
        # left_on/right_on must have the same shape: both lists or both
        # single column names, and lists must be the same length.
        if isinstance(left_on, list) != isinstance(right_on, list):
            raise InvalidArgumentsError(
                "left_on and right_on must both be either lists or strings and cannot be of different types."
            )
        if isinstance(left_on, list) and len(left_on) != len(right_on):
            raise InvalidArgumentsError("len(right_on) must equal len(left_on)")

        super().__init__(group_keys, *args, **kwargs)
        self.axis = axis
        self.left_on = left_on
        self.right_on = right_on
        self.on = on
        self.right_schema = right_schema
        self.inplace = inplace
        self.merge_how = merge_how
        self.drop_unmerged = drop_unmerged

    def set_raw(self, sf: GroupSFrame, result: Dict[str, object]):
        """Write *result* back into *sf*, first dropping consumed children
        when ``drop_unmerged`` is set."""
        if self.drop_unmerged:
            # Remove every child this merger addressed via its group keys
            # before the parent writes the merged result.
            for _, v in self.group_keys.items():
                sf.children.pop(v, None)
        return super().set_raw(sf, result)

    def merge_df(self, default: DataFrame, other: DataFrame, *args, **kwargs):
        """Pandas handler: merge *other* into *default* and de-duplicate rows."""
        merging_kwargs = self.get_merging_kwargs()
        merged = self.merge(default, other, DFrame, merging_kwargs)
        if self.axis == 1 and "right_on" in merging_kwargs:
            # Drop right-side join columns that merely duplicate left-side ones.
            merged = merged.drop(self.get_drop_cols(), axis=1)
        merged = merged.drop_duplicates().reset_index(drop=True)
        return self.choose_return_result(default, other, merged)

    def merge_spf(
        self, default: PySparkDataFrame, other: PySparkDataFrame, *args, **kwargs
    ):
        """PySpark handler: same contract as merge_df, using Spark operations."""
        merging_kwargs = self.get_merging_kwargs()
        merged = self.merge(default, other, SPFrame, merging_kwargs)
        if self.axis == 1 and "right_on" in merging_kwargs:
            merged: PySparkDataFrame = merged.drop(self.get_drop_cols())
        merged = merged.dropDuplicates()
        return self.choose_return_result(default, other, merged)

    def merge(
        self,
        default: object,
        other: object,
        sf_class: Type[SFrame],
        merging_kwargs: dict,
    ) -> object:
        """Merge the two raw frames via the uniform ``SFrame.extend`` API."""
        # Wrap the raw frames in the concrete SFrame class so the merge can
        # be expressed independently of the backend (pandas vs pyspark).
        default_sf = sf_class.from_raw(default)
        other_sf = sf_class.from_raw(other)
        if self.right_schema:
            # Apply the schema to the right side first, e.g. to keep only a
            # subset of its columns in the merged output.
            other_sf = self.right_schema(other_sf)
        result = default_sf.extend(other_sf.to_raw(), axis=self.axis, **merging_kwargs)
        return result

    def get_merging_kwargs(self):
        """Build the kwargs passed to ``SFrame.extend``; ``on`` wins over the
        left_on/right_on pair."""
        if self.on:
            return {"on": self.on, "how": self.merge_how}
        return {
            "left_on": self.left_on,
            "right_on": self.right_on,
            "how": self.merge_how,
        }

    def choose_return_result(self, default, other, merged):
        """Assemble the result mapping according to drop_unmerged/inplace."""
        if self.drop_unmerged:
            result_kwargs = {}
        else:
            result_kwargs = {"default": default, "other": other}

        if self.inplace:
            # Overwrite the default child with the merged frame.
            result_kwargs["default"] = merged
        else:
            result_kwargs["merged"] = merged
        return result_kwargs

    def get_drop_cols(self):
        """Right-side join columns that are not also left-side join columns."""
        left_on_set = (
            {self.left_on} if isinstance(self.left_on, str) else set(self.left_on)
        )
        right_on_set = (
            {self.right_on} if isinstance(self.right_on, str) else set(self.right_on)
        )
        return right_on_set - left_on_set
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class ListMerger(SFrameMerger):
    """
    Collect a list of sframes (group or non-group) into one GroupSFrame.

    To keep the child keys of the resulting group unique, keys follow this
    scheme: for a group sframe the key is ``prefix + index + "_" + child key``;
    for a plain sframe it is ``prefix + index``, where *index* is the sframe's
    position in the input list.

    Parameters:
    ----------
    sf_prefix: str
        The prefix every child key in the output starts with.
    sf_list: List[SFrame]
        The list of sframes (each may be group or non-group) to collect.

    Examples
    --------
    >>> sf1: SFrame, sf2: GroupSFrame
    >>> list(sf2.keys)
    ['default', 'address']

    >>> merger = ListMerger(sf_prefix="sf_")
    >>> result_sf = merger(sf_list=[sf1, sf2])
    >>> list(result_sf.keys)
    ['sf_0', 'sf_1_default', 'sf_1_address']
    """

    def __init__(self, sf_prefix="sf", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sf_prefix = sf_prefix

    def __call__(self, sf_list: List[SFrame], *args, **kwargs):
        collected = GroupSFrame()
        for position, frame in enumerate(sf_list):
            base_key = f"{self.sf_prefix}{position}"
            if isinstance(frame, GroupSFrame):
                # Flatten the group: one entry per child, suffixed by its key.
                for suffix, child in frame.children.items():
                    collected[f"{base_key}_{suffix}"] = child
            else:
                collected[base_key] = frame
        return collected
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class MultiMerger(SFrameMerger):
    """
    Apply multiple mergers to the input sframe, in order.

    Each merger's output is fed as the input of the next one.
    """

    # Fix: the original *assigned* the typing alias (``mergers = List[Merger]``)
    # as a class attribute instead of declaring an annotation.
    mergers: List[Merger]
    ONLY_GROUP = True

    def __init__(self, mergers, group_keys=None, *args, **kwargs):
        super().__init__(group_keys, *args, **kwargs)
        self.mergers = mergers

    def __call__(
        self, sf_input: GroupSFrame, *args: object, **kwargs: object
    ) -> SFrame:
        # Chain: output of each merger becomes input of the next.
        for merger in self.mergers:
            sf_input = merger(sf_input)
        return sf_input
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .base import Pipeline
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from seshat.transformer import Transformer
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Pipeline(Transformer):
    """
    A sequential data-processing pipeline.

    The input is passed through each transformer in ``pipes`` in order: the
    output of one transformer becomes the input of the next, and the output
    of the last transformer is the pipeline's result. Any extra positional
    and keyword arguments given to the call are forwarded to every stage.

    Parameters
    ----------
    pipes : list of Transformer instances
        The transformers to apply, in order. Each stage must produce output
        the following stage can consume.

    Examples
    --------
    >>> pipeline = Pipeline(pipes=[transformer1, transformer2]
    >>> pipeline(input_data)

    Notes
    -----
    Useful whenever a workflow decomposes into discrete processing steps;
    chaining keeps each step small and reusable. The mutator methods
    (``replace``/``append``/``insert``/``remove``) return ``self`` so they
    can be chained fluently.
    """

    pipes: List[Transformer]

    def __init__(self, pipes: List[Transformer]):
        self.pipes = pipes

    def __call__(self, data, *args, **kwargs):
        current = data
        for stage in self.pipes:
            current = stage(current, *args, **kwargs)
        return current

    def replace(self, transformer: Transformer, index: int):
        """Swap the stage at *index* for *transformer*."""
        self.pipes[index] = transformer
        return self

    def append(self, transformer: Transformer):
        """Add *transformer* as the final stage."""
        self.pipes.append(transformer)
        return self

    def insert(self, transformer: Transformer, index: int):
        """Insert *transformer* before position *index*."""
        self.pipes.insert(index, transformer)
        return self

    def remove(self, index: int):
        """Delete the stage at *index*."""
        self.pipes.pop(index)
        return self
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
from concurrent.futures.process import ProcessPoolExecutor
|
|
2
|
+
from typing import List, Dict
|
|
3
|
+
|
|
4
|
+
from seshat.data_class import SFrame, GroupSFrame
|
|
5
|
+
from seshat.general.config import DEFAULT_SF_KEY
|
|
6
|
+
from seshat.general.exceptions import InvalidArgumentsError
|
|
7
|
+
from seshat.transformer import Transformer
|
|
8
|
+
from seshat.transformer.merger import Merger
|
|
9
|
+
from seshat.transformer.pipeline import Pipeline
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def process_pipeline(sf, sf_key, pipe, index, *args, **kwargs):
    """Run one pipeline against a single child of *sf*.

    Kept at module level so it can be submitted to a ProcessPoolExecutor.
    Returns a ``(key, result)`` tuple; when *index* is not None the key is
    suffixed with ``__{index}`` so results from a list of pipelines on the
    same child do not collide.
    """
    output = pipe(sf.get(sf_key), *args, **kwargs)
    result_key = sf_key if index is None else sf_key + f"__{index}"
    return result_key, output
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Branch(Transformer):
    """
    Branch is a Transformer that runs multiple pipelines simultaneously and merges their outputs.
    It can execute pipelines in parallel using multiple processes.

    Parameters
    ----------
    pipe_map : Dict[str, Pipeline | List[Pipeline]], optional
        A dictionary mapping SFrame child keys to pipelines or lists of pipelines.
    merger : Merger, optional
        An object to merge the pipeline outputs.
    parallel : bool, optional
        Whether to run pipelines in parallel. Default is False.
    max_workers : int, optional
        Maximum number of workers for parallel execution. Default is 16.
    group_keys
        Group keys for the parent Transformer class.
    *args
        Additional positional arguments.
    **kwargs
        Additional keyword arguments.

    Example usage
    ----------
    1. Creating a Branch instance with a list of pipelines:

    >>> sf_input = GroupSFrame(children={"default": DFrame(transaction_df), "address": DFrame(address_df)})
    ... pipe_map = {
    ...     "default": [Pipeline(pipes=[ZeroAddressTrimmer(), LowTransactionTrimmer()]),
    ...                 Pipeline(pipes=[OperationOnColsDeriver(
    ...                     cols=("sent_count", "received_count"),
    ...                     result_col="tx_count",
    ...                     agg_func="sum",
    ...                     is_numeric=True,
    ...                 )])],
    ... }
    ... branch = Branch(pipe_map=pipe_map, merger=Merger())
    ... result = branch(sf_input)

    2. Creating a Branch instance with a single pipeline for each SFrame:

    >>> sf_input = GroupSFrame(children={"default": DFrame(transaction_df), "address": DFrame(address_df)})
    ... pipe_map = {
    ...     "default": Pipeline(pipes=[ZeroAddressTrimmer(), LowTransactionTrimmer()]),
    ...     "address": Pipeline(pipes=[FeatureForAddressDeriver(
    ...         value_col="ez_token_transfers_id",
    ...         result_col="received_count",
    ...         default_index_col="to_address",
    ...         agg_func="nunique",
    ...         is_numeric=False,
    ...     )])
    ... }
    ... branch = Branch(pipe_map=pipe_map, merger=Merger())
    ... result = branch(sf_input)
    """

    merger: Merger
    parallel: bool
    max_workers: int
    # Dispatch name — presumably used by the parent Transformer to locate
    # this class's handler; confirm against seshat.transformer.base.
    HANDLER_NAME = "run"

    def __init__(
        self,
        pipe_map: Dict[str, Pipeline | List[Pipeline]] = None,
        merger: Merger = None,
        parallel: bool = False,
        max_workers: int = 16,
        group_keys=None,
        *args,
        **kwargs,
    ):
        super().__init__(group_keys, *args, **kwargs)

        self.pipe_map = pipe_map
        self.merger = merger
        self.parallel = parallel
        self.max_workers = max_workers

    def __call__(self, sf_input: SFrame, *args, **kwargs) -> SFrame:
        # Run the branch pipelines via the parent protocol, then optionally
        # merge all branch outputs into a single result.
        result = super().__call__(sf_input, *args, **kwargs)
        if self.merger:
            result = self.merger(result)
        return result

    def call_handler(self, sf: SFrame, *args, **kwargs) -> Dict[str, object]:
        # Entry point invoked by the parent Transformer machinery.
        return self.run(sf, *args, **kwargs)

    def run(self, sf: SFrame, *args, **kwargs):
        """Validate/normalize the input, then run the pipelines."""
        if isinstance(sf, GroupSFrame):
            # Every key addressed by pipe_map must exist on the group.
            if set(self.pipe_map.keys()).difference(set(sf.children.keys())):
                raise InvalidArgumentsError(
                    "Pipeline map keys must exist in passed SFrame"
                )
        else:
            # Promote a plain SFrame to a single-child group under the
            # default key so pipe_map lookups work uniformly.
            sf = sf.make_group(DEFAULT_SF_KEY)
        result = (
            self.run_parallel_pipes(sf, *args, **kwargs)
            if self.parallel
            else self.run_pipes(sf, *args, **kwargs)
        )
        return result

    def run_pipes(self, sf: SFrame, *args, **kwargs):
        """Sequential execution of all pipelines in pipe_map."""
        results = {}
        for sf_key, pipes in self.pipe_map.items():
            if isinstance(pipes, list):
                # Suffix result keys with "__<index>" so multiple pipelines
                # on the same child do not collide.
                for i in range(len(pipes)):
                    results[sf_key + f"__{i}"] = pipes[i](
                        sf.get(sf_key), *args, **kwargs
                    )
            else:
                results[sf_key] = pipes(sf.get(sf_key), *args, **kwargs)
        return results

    def run_parallel_pipes(self, sf: SFrame, *args, **kwargs):
        """Parallel execution, one process pool task per pipeline run.

        NOTE(review): *sf*, the pipelines, and all extra args must be
        picklable for ProcessPoolExecutor — confirm for non-trivial
        SFrame backends (e.g. PySpark).
        """
        results = {}

        with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
            futures = []
            for sf_key, pipes in self.pipe_map.items():
                if isinstance(pipes, list):
                    for i in range(len(pipes)):
                        futures.append(
                            executor.submit(
                                process_pipeline,
                                sf,
                                sf_key,
                                pipes[i],
                                i,
                                *args,
                                **kwargs,
                            )
                        )
                else:
                    futures.append(
                        executor.submit(
                            process_pipeline, sf, sf_key, pipes, None, *args, **kwargs
                        )
                    )

            # Collect in submission order; process_pipeline returns the
            # (possibly suffixed) key together with the result.
            for future in futures:
                key, result = future.result()
                results[key] = result

        return results
|
|
File without changes
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from seshat.general import configs
|
|
2
|
+
from seshat.transformer.deriver.base import (
|
|
3
|
+
SFrameFromColsDeriver,
|
|
4
|
+
FeatureForAddressDeriver,
|
|
5
|
+
OperationOnColsDeriver,
|
|
6
|
+
PercentileTransactionValueDeriver,
|
|
7
|
+
InteractedSymbolsToSentenceDeriver,
|
|
8
|
+
)
|
|
9
|
+
from seshat.transformer.pipeline import Pipeline
|
|
10
|
+
|
|
11
|
+
# Pre-built feature pipeline for deriving per-address features from a
# transfer table. Stage order matters: later derivers read columns
# produced by earlier ones (e.g. interacted_avg_amount).
address_pipeline = Pipeline(
    [
        # Build the address frame from the union of from/to address columns.
        SFrameFromColsDeriver(
            cols=(configs.FROM_ADDRESS_COL, configs.TO_ADDRESS_COL),
            result_col="address",
        ),
        # Number of token symbols observed per address.
        FeatureForAddressDeriver(
            value_col="SYMBOL",
            is_numeric=False,
            result_col="unique_tokens_count",
            agg_func="count",
        ),
        # USD amount sent (default address column — presumably FROM; confirm
        # against FeatureForAddressDeriver's default).
        FeatureForAddressDeriver(
            value_col="AMOUNT_USD",
            is_numeric=False,
            result_col="sent_amount",
        ),
        # USD amount received (grouped by the TO address column).
        FeatureForAddressDeriver(
            value_col="AMOUNT_USD",
            is_numeric=False,
            address_col=configs.TO_ADDRESS_COL,
            result_col="received_amount",
        ),
        # Distinct symbols sent per address.
        FeatureForAddressDeriver(
            value_col="SYMBOL",
            is_numeric=False,
            result_col="sent_symbols",
            agg_func="unique",
        ),
        # Distinct symbols received per address.
        FeatureForAddressDeriver(
            value_col="SYMBOL",
            is_numeric=False,
            address_col=configs.TO_ADDRESS_COL,
            result_col="received_symbols",
            agg_func="unique",
        ),
        # Number of unique counterparties the address received from.
        FeatureForAddressDeriver(
            value_col=configs.FROM_ADDRESS_COL,
            is_numeric=False,
            address_col=configs.TO_ADDRESS_COL,
            result_col="received_partners",
            agg_func="nunique",
        ),
        # Number of unique counterparties the address sent to.
        FeatureForAddressDeriver(
            value_col=configs.TO_ADDRESS_COL,
            is_numeric=False,
            result_col="sent_partners",
            agg_func="nunique",
        ),
        # Combine received/sent amounts (default agg of OperationOnColsDeriver).
        OperationOnColsDeriver(
            cols=("received_amount", "sent_amount"),
            result_col="interacted_avg_amount",
        ),
        # Total interacted amount = sum of received and sent amounts,
        # operating on the "address" child frame.
        OperationOnColsDeriver(
            group_keys={"default": "address"},
            cols=("received_amount", "sent_amount"),
            agg_func="sum",
            result_col="interacted_total_amount",
        ),
        # Percentile rank of the average interacted amount.
        PercentileTransactionValueDeriver(
            group_keys={"default": "address"},
            value_col="interacted_avg_amount",
        ),
        # Flatten interacted symbols into a sentence-style string column.
        InteractedSymbolsToSentenceDeriver(
            symbol_col="SYMBOL",
            result_col="sentence_symbols",
        ),
    ]
)
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .base import Schema, Col
|