unaiverse 0.1.6__cp314-cp314-musllinux_1_2_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unaiverse might be problematic. Click here for more details.
- unaiverse/__init__.py +19 -0
- unaiverse/agent.py +2008 -0
- unaiverse/agent_basics.py +1846 -0
- unaiverse/clock.py +191 -0
- unaiverse/dataprops.py +1209 -0
- unaiverse/hsm.py +1880 -0
- unaiverse/modules/__init__.py +18 -0
- unaiverse/modules/cnu/__init__.py +17 -0
- unaiverse/modules/cnu/cnus.py +536 -0
- unaiverse/modules/cnu/layers.py +261 -0
- unaiverse/modules/cnu/psi.py +60 -0
- unaiverse/modules/hl/__init__.py +15 -0
- unaiverse/modules/hl/hl_utils.py +411 -0
- unaiverse/modules/networks.py +1509 -0
- unaiverse/modules/utils.py +680 -0
- unaiverse/networking/__init__.py +16 -0
- unaiverse/networking/node/__init__.py +18 -0
- unaiverse/networking/node/connpool.py +1261 -0
- unaiverse/networking/node/node.py +2223 -0
- unaiverse/networking/node/profile.py +446 -0
- unaiverse/networking/node/tokens.py +79 -0
- unaiverse/networking/p2p/__init__.py +198 -0
- unaiverse/networking/p2p/go.mod +127 -0
- unaiverse/networking/p2p/go.sum +548 -0
- unaiverse/networking/p2p/golibp2p.py +18 -0
- unaiverse/networking/p2p/golibp2p.pyi +135 -0
- unaiverse/networking/p2p/lib.go +2714 -0
- unaiverse/networking/p2p/lib.go.sha256 +1 -0
- unaiverse/networking/p2p/lib_types.py +312 -0
- unaiverse/networking/p2p/message_pb2.py +63 -0
- unaiverse/networking/p2p/messages.py +265 -0
- unaiverse/networking/p2p/mylogger.py +77 -0
- unaiverse/networking/p2p/p2p.py +929 -0
- unaiverse/networking/p2p/proto-go/message.pb.go +616 -0
- unaiverse/networking/p2p/unailib.cpython-314-aarch64-linux-musl.so +0 -0
- unaiverse/streamlib/__init__.py +15 -0
- unaiverse/streamlib/streamlib.py +210 -0
- unaiverse/streams.py +770 -0
- unaiverse/utils/__init__.py +16 -0
- unaiverse/utils/ask_lone_wolf.json +27 -0
- unaiverse/utils/lone_wolf.json +19 -0
- unaiverse/utils/misc.py +305 -0
- unaiverse/utils/sandbox.py +293 -0
- unaiverse/utils/server.py +435 -0
- unaiverse/world.py +175 -0
- unaiverse-0.1.6.dist-info/METADATA +365 -0
- unaiverse-0.1.6.dist-info/RECORD +50 -0
- unaiverse-0.1.6.dist-info/WHEEL +5 -0
- unaiverse-0.1.6.dist-info/licenses/LICENSE +43 -0
- unaiverse-0.1.6.dist-info/top_level.txt +1 -0
unaiverse/dataprops.py
ADDED
|
@@ -0,0 +1,1209 @@
|
|
|
1
|
+
"""
|
|
2
|
+
█████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
|
|
3
|
+
░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
|
|
4
|
+
░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
|
|
5
|
+
░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
|
|
6
|
+
░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
|
|
7
|
+
░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
|
|
8
|
+
░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
|
|
9
|
+
░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
|
|
10
|
+
A Collectionless AI Project (https://collectionless.ai)
|
|
11
|
+
Registration/Login: https://unaiverse.io
|
|
12
|
+
Code Repositories: https://github.com/collectionlessai/
|
|
13
|
+
Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
|
|
14
|
+
"""
|
|
15
|
+
import torch
|
|
16
|
+
from PIL import Image
|
|
17
|
+
from typing import Callable, Any
|
|
18
|
+
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Data4Proc:

    def __init__(self, *args, private_only: bool = False, public_only: bool = False, **kwargs):
        """Container for one or two `DataProps` instances: one describing a
        private stream and (optionally) one describing a public stream.

        Args:
            *args: Positional arguments forwarded to each `DataProps` constructor.
            private_only: If True, only the private-stream `DataProps` is created.
            public_only: If True, only the public-stream `DataProps` is created.
            **kwargs: Keyword arguments forwarded to each `DataProps` constructor.

        Raises:
            ValueError: If both `private_only` and `public_only` are True, or if
                'public' is passed directly (it is managed by this container).
        """
        self.props = []
        if public_only and private_only:
            raise ValueError("Cannot set both private_only and public_only to True (it does not make any sense)")
        if 'public' in kwargs:
            raise ValueError("Invalid argument was provided to Data4Proc: 'public' (it is an argument of DataProps)")

        # BUG FIX: the private DataProps used to be created unconditionally, so
        # public_only=True still produced a private stream, contradicting the
        # documented contract. The guard below restores the intended behavior.
        if not public_only:
            kwargs['public'] = False
            self.props.append(DataProps(*args, **kwargs))
        if not private_only:
            kwargs['public'] = True
            self.props.append(DataProps(*args, **kwargs))

    def to_list_of_dicts(self):
        """Convert the contained `DataProps` objects into a list of dictionaries.

        Returns:
            A list of dictionaries, one per contained `DataProps`.
        """
        return [props.to_dict() for props in self.props]

    def to_dict(self):
        """Unsupported on the container: serialization applies to a single
        `DataProps`, while this object may hold several.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError("This method can only be called on a DataProps object and not on Data4Proc")

    def from_dict(self):
        """Unsupported on the container: deserialization applies to a single
        `DataProps`, while this object may hold several.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError("This method can only be called on a DataProps object and not on Data4Proc")

    def clone(self):
        """Create a deep copy of this container (each `DataProps` is cloned).

        Returns:
            A new `Data4Proc` object equivalent to this one.
        """
        ret = Data4Proc()
        ret.props = [p.clone() for p in self.props]
        return ret

    def is_public(self):
        """Unsupported on the container: query the individual `DataProps` instead.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError("This method can only be called on a DataProps object and not on Data4Proc")

    def __str__(self):
        """Return a human-readable summary listing every contained `DataProps`."""
        s = f"[Data4Proc] Number of DataProps: {len(self.props)}"
        for p in self.props:
            z = str(p).replace("\n", "\n\t")
            s += "\t" + z
        return s

    def __getattr__(self, method_or_attribute_name):
        """Forward unknown attribute access to the contained `DataProps`.

        Names starting with 'set_' return a function that applies the setter to
        every contained `DataProps`; any other name is read from the first one.

        Args:
            method_or_attribute_name: The attribute or method name being accessed.

        Returns:
            A callable (for 'set_*' names) or the attribute of the first
            `DataProps` in the list.
        """
        if method_or_attribute_name.startswith('set_'):
            def apply_set_method_to_all_props(*args, **kwargs):
                # Broadcast the setter to every contained DataProps
                for prop in self.props:
                    getattr(prop, method_or_attribute_name)(*args, **kwargs)

            return apply_set_method_to_all_props
        else:
            return getattr(self.props[0], method_or_attribute_name)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class DataProps:
|
|
131
|
+
"""
|
|
132
|
+
A class for handling the properties and transformations of data, including labels.
|
|
133
|
+
It supports different data types: 'tensor', 'img', 'text', and the wildcard 'all'.
|
|
134
|
+
|
|
135
|
+
Attributes:
|
|
136
|
+
VALID_DATA_TYPES (tuple): Tuple of valid data types ('tensor', 'img', 'text', 'all').
|
|
137
|
+
"""
|
|
138
|
+
VALID_DATA_TYPES = ('tensor', 'img', 'text', 'all')
|
|
139
|
+
|
|
140
|
+
def __init__(self,
|
|
141
|
+
name: str = "unk",
|
|
142
|
+
group: str = "none",
|
|
143
|
+
data_type: str = "text", # Do not set tensor as default
|
|
144
|
+
data_desc: str = "unk",
|
|
145
|
+
tensor_shape: tuple[int | None, ...] | None = None,
|
|
146
|
+
tensor_labels: list[str] | str | None = None,
|
|
147
|
+
tensor_dtype: torch.dtype | str | None = None,
|
|
148
|
+
tensor_labeling_rule: str = "max",
|
|
149
|
+
stream_to_proc_transforms: Callable[..., Any] | PreTrainedTokenizerBase | str | dict | tuple[
|
|
150
|
+
dict | Callable[..., Any] | PreTrainedTokenizerBase | str | None,
|
|
151
|
+
dict | Callable[..., Any] | PreTrainedTokenizerBase | str | None] | None = None,
|
|
152
|
+
proc_to_stream_transforms: Callable[..., Any] | PreTrainedTokenizerBase | str | list | None = None,
|
|
153
|
+
delta: float = -1,
|
|
154
|
+
pubsub: bool = False,
|
|
155
|
+
public: bool = False):
|
|
156
|
+
"""Initializes a DataProps instance.
|
|
157
|
+
|
|
158
|
+
Args:
|
|
159
|
+
name (str): Name of the data (default is "unk").
|
|
160
|
+
group (str): Name of the group to which this DataProps belong (default: "none").
|
|
161
|
+
data_type (str): The type of data ('tensor', 'img', or 'text').
|
|
162
|
+
data_desc (str): Description of the data (default is "unk").
|
|
163
|
+
tensor_shape (tuple[int | None] or None): Shape of the tensor data (e.g., (3, 224, 224) or (3, None, None)
|
|
164
|
+
for tensors that are variable size images). It is None for non-tensor data.
|
|
165
|
+
tensor_labels (list[str] or "AutoTokenizer:<tokenizer_model_id>" or None):
|
|
166
|
+
List of labels for the components (features) of tensor data. It can be a string representing the ID of a
|
|
167
|
+
tokenizer which is valid in AutoTokenizer (with prefix "AutoTokenizer:").
|
|
168
|
+
tensor_dtype (torch.dtype or str or None): The string representing the Pytorch dtype of the tensor data.
|
|
169
|
+
tensor_labeling_rule (str): The labeling rule for tensor data ('max' or 'geqX' where X is a number).
|
|
170
|
+
stream_to_proc_transforms (callable or PreTrainedTokenizerBase or str or dict or a list* or None):
|
|
171
|
+
A callable stream format to tensor format conversion fcn (any callable thing, torchvision transforms, a
|
|
172
|
+
pretrained tokenizer, or the model ID from which it can be downloaded (it must have prefix
|
|
173
|
+
"AutoTokenizer:"), or a vocabulary str->int). It is None for already-tensorial data.
|
|
174
|
+
*If you need to distinguish the transform applied to the inputs and to targets, you can pass a list of
|
|
175
|
+
two elements like the just described ones - one for input, one for targets, respectively.
|
|
176
|
+
proc_to_stream_transforms (callable or PreTrainedTokenizerBase or str or list or None): A callable tensor
|
|
177
|
+
to stream format function (any callable thing, torchvision transforms, a Pretrained tokenizer (HF),
|
|
178
|
+
or the model ID from which it can be downloaded, (it must have prefix "AutoTokenizer:"), or a
|
|
179
|
+
vocabulary int->str). It is None for non text data.
|
|
180
|
+
delta (float): Time interval between consecutive data samples (<= 0 for real-time data).
|
|
181
|
+
pubsub (bool): If the stream is supposed to be sent to/received from a Pub/Sub topic.
|
|
182
|
+
public (bool): If the stream is supposed to be accessed through the public net or through the private one.
|
|
183
|
+
|
|
184
|
+
Returns:
|
|
185
|
+
None
|
|
186
|
+
"""
|
|
187
|
+
|
|
188
|
+
# Checking data type
|
|
189
|
+
assert data_type in DataProps.VALID_DATA_TYPES, "Invalid data type"
|
|
190
|
+
assert isinstance(data_desc, str), "Invalid data description"
|
|
191
|
+
|
|
192
|
+
# Checking transformations
|
|
193
|
+
assert (stream_to_proc_transforms is None or
|
|
194
|
+
isinstance(stream_to_proc_transforms, str) or
|
|
195
|
+
isinstance(stream_to_proc_transforms, PreTrainedTokenizerBase) or
|
|
196
|
+
callable(stream_to_proc_transforms) or
|
|
197
|
+
isinstance(stream_to_proc_transforms, dict) or
|
|
198
|
+
isinstance(stream_to_proc_transforms, tuple) or
|
|
199
|
+
isinstance(stream_to_proc_transforms, list)), \
|
|
200
|
+
"Invalid stream to processor transforms"
|
|
201
|
+
|
|
202
|
+
if stream_to_proc_transforms is not None:
|
|
203
|
+
if not isinstance(stream_to_proc_transforms, list) and not isinstance(stream_to_proc_transforms, tuple):
|
|
204
|
+
self.stream_to_proc_transforms = [stream_to_proc_transforms, stream_to_proc_transforms]
|
|
205
|
+
else:
|
|
206
|
+
assert len(stream_to_proc_transforms) == 2, \
|
|
207
|
+
"Expected a list with two sets of transforms (input, target)"
|
|
208
|
+
self.stream_to_proc_transforms = stream_to_proc_transforms
|
|
209
|
+
self.__original_stream_to_proc_transforms = stream_to_proc_transforms
|
|
210
|
+
else:
|
|
211
|
+
self.stream_to_proc_transforms = None
|
|
212
|
+
self.__original_stream_to_proc_transforms = None
|
|
213
|
+
|
|
214
|
+
assert (proc_to_stream_transforms is None or
|
|
215
|
+
isinstance(proc_to_stream_transforms, str) or
|
|
216
|
+
isinstance(proc_to_stream_transforms, PreTrainedTokenizerBase) or
|
|
217
|
+
callable(proc_to_stream_transforms) or
|
|
218
|
+
isinstance(proc_to_stream_transforms, list)), \
|
|
219
|
+
"Invalid stream to processor transforms"
|
|
220
|
+
|
|
221
|
+
self.proc_to_stream_transforms = proc_to_stream_transforms
|
|
222
|
+
self.__original_proc_to_stream_transforms = proc_to_stream_transforms
|
|
223
|
+
|
|
224
|
+
# Setting data type and description
|
|
225
|
+
self.data_type = data_type
|
|
226
|
+
self.data_desc = data_desc
|
|
227
|
+
|
|
228
|
+
# Setting empty attributes
|
|
229
|
+
self.tensor_shape = None
|
|
230
|
+
self.tensor_dtype = None
|
|
231
|
+
self.tensor_labels = None
|
|
232
|
+
|
|
233
|
+
# Checking data in function of its type
|
|
234
|
+
if self.is_tensor():
|
|
235
|
+
|
|
236
|
+
# Checking shape
|
|
237
|
+
assert (tensor_shape is not None and
|
|
238
|
+
isinstance(tensor_shape, (tuple, list))), f"Invalid shape for DataProps: {tensor_shape}"
|
|
239
|
+
assert all(x is None or isinstance(x, int) for x in tensor_shape), \
|
|
240
|
+
f"Invalid shape for DataProps: {tensor_shape}"
|
|
241
|
+
|
|
242
|
+
# Setting shape
|
|
243
|
+
self.tensor_shape = tuple(tensor_shape) # Forcing (important)
|
|
244
|
+
|
|
245
|
+
# Checking dtype
|
|
246
|
+
assert (tensor_dtype is not None and
|
|
247
|
+
(isinstance(tensor_dtype, torch.dtype) or isinstance(tensor_dtype, str)
|
|
248
|
+
and tensor_dtype.startswith("torch."))), \
|
|
249
|
+
f"Invalid tensor type: {tensor_dtype}"
|
|
250
|
+
|
|
251
|
+
# Setting dtype
|
|
252
|
+
self.tensor_dtype = tensor_dtype if isinstance(tensor_dtype, torch.dtype) else eval(tensor_dtype)
|
|
253
|
+
|
|
254
|
+
# Checking labels
|
|
255
|
+
assert tensor_labels is None or (isinstance(tensor_labels, list) or
|
|
256
|
+
(isinstance(tensor_shape, str) and
|
|
257
|
+
tensor_labels.startswith("AutoTokenizer:"))), \
|
|
258
|
+
f"Invalid tensor labels: {tensor_labels}"
|
|
259
|
+
|
|
260
|
+
# Setting labels
|
|
261
|
+
if tensor_labels is not None:
|
|
262
|
+
if not (isinstance(tensor_labels, str) and tensor_labels.startswith("AutoTokenizer:")):
|
|
263
|
+
self.tensor_labels = TensorLabels(self, labels=tensor_labels, labeling_rule=tensor_labeling_rule)
|
|
264
|
+
else:
|
|
265
|
+
self.set_tensor_labels_from_auto_tokenizer(tensor_labels.split[:][1])
|
|
266
|
+
|
|
267
|
+
elif self.is_img():
|
|
268
|
+
|
|
269
|
+
# Ensuring other type-related tools are not set
|
|
270
|
+
assert tensor_shape is None and tensor_labels is None and tensor_dtype is None, \
|
|
271
|
+
f"Tensor-related arguments must be None when using a DataProps of type {data_type}"
|
|
272
|
+
assert (self.stream_to_proc_transforms is None or (not isinstance(self.stream_to_proc_transforms, str)
|
|
273
|
+
and not isinstance(self.stream_to_proc_transforms,
|
|
274
|
+
PreTrainedTokenizerBase))), \
|
|
275
|
+
"Non-image-related transforms were selected"
|
|
276
|
+
assert (self.proc_to_stream_transforms is None or (not isinstance(self.proc_to_stream_transforms, str)
|
|
277
|
+
and not isinstance(self.proc_to_stream_transforms,
|
|
278
|
+
PreTrainedTokenizerBase)
|
|
279
|
+
and not isinstance(self.proc_to_stream_transforms,
|
|
280
|
+
list))), \
|
|
281
|
+
"Non-image-related transforms were selected"
|
|
282
|
+
|
|
283
|
+
elif self.is_text():
|
|
284
|
+
|
|
285
|
+
# Ensuring other type-related tools are not set
|
|
286
|
+
assert tensor_shape is None and tensor_labels is None and tensor_dtype is None, \
|
|
287
|
+
f"Tensor/image-related arguments must be None when using a DataProps of type {data_type}"
|
|
288
|
+
|
|
289
|
+
# Setting text to tensor transform (tokenizer in encode mode) (if given)
|
|
290
|
+
if self.stream_to_proc_transforms is not None:
|
|
291
|
+
for j, _tttt in enumerate(self.stream_to_proc_transforms):
|
|
292
|
+
assert ((isinstance(_tttt, str) and _tttt.startswith("AutoTokenizer:")) or
|
|
293
|
+
isinstance(_tttt, PreTrainedTokenizerBase) or
|
|
294
|
+
isinstance(_tttt, dict) or
|
|
295
|
+
callable(_tttt)), \
|
|
296
|
+
("Invalid text tokenizer: expected object of type PreTrainedTokenizerBase or a "
|
|
297
|
+
"string starting with 'AutoTokenizer:' or a callable object or a dictionary "
|
|
298
|
+
"(vocabulary str->int)")
|
|
299
|
+
if isinstance(_tttt, str) and _tttt.startswith("AutoTokenizer:"):
|
|
300
|
+
self.stream_to_proc_transforms[j] = AutoTokenizer.from_pretrained(_tttt.split(":")[1])
|
|
301
|
+
|
|
302
|
+
# Setting tensor to text transform (tokenizer in decode mode OR a given vocabulary int->str) (if given)
|
|
303
|
+
if self.proc_to_stream_transforms is not None:
|
|
304
|
+
assert ((isinstance(self.proc_to_stream_transforms, str) and
|
|
305
|
+
self.proc_to_stream_transforms.startswith("AutoTokenizer:")) or
|
|
306
|
+
isinstance(self.proc_to_stream_transforms, PreTrainedTokenizerBase) or
|
|
307
|
+
isinstance(self.proc_to_stream_transforms, list) or
|
|
308
|
+
callable(self.proc_to_stream_transforms)), \
|
|
309
|
+
("Invalid text tokenizer: expected object of type PreTrainedTokenizerBase or a "
|
|
310
|
+
"string starting with 'AutoTokenizer:' or a callable object or a dictionary "
|
|
311
|
+
"(vocabulary int->str)")
|
|
312
|
+
if (isinstance(self.proc_to_stream_transforms, str) and
|
|
313
|
+
self.proc_to_stream_transforms.startswith("AutoTokenizer:")):
|
|
314
|
+
self.proc_to_stream_transforms = (
|
|
315
|
+
AutoTokenizer.from_pretrained(self.proc_to_stream_transforms.split(":")[1]))
|
|
316
|
+
|
|
317
|
+
# Checking name and group
|
|
318
|
+
assert "~" not in name, "Invalid chars in stream name"
|
|
319
|
+
assert "~" not in group, "Invalid chars in group name"
|
|
320
|
+
|
|
321
|
+
# Initialize properties
|
|
322
|
+
self.name = name
|
|
323
|
+
self.group = group
|
|
324
|
+
self.delta = delta
|
|
325
|
+
self.pubsub = pubsub
|
|
326
|
+
self.public = public
|
|
327
|
+
|
|
328
|
+
def to_dict(self):
|
|
329
|
+
"""Serializes the `DataProps` object into a dictionary, making it suitable for transmission or storage.
|
|
330
|
+
It converts complex types like `torch.dtype` and `TensorLabels` into simple, serializable formats.
|
|
331
|
+
|
|
332
|
+
Returns:
|
|
333
|
+
A dictionary representation of the object's properties.
|
|
334
|
+
"""
|
|
335
|
+
return {
|
|
336
|
+
'name': self.name,
|
|
337
|
+
'group': self.group,
|
|
338
|
+
'data_type': self.data_type,
|
|
339
|
+
'data_desc': self.data_desc,
|
|
340
|
+
'tensor_shape': self.tensor_shape,
|
|
341
|
+
'tensor_dtype': str(self.tensor_dtype) if self.tensor_dtype is not None else None,
|
|
342
|
+
'tensor_labels': self.tensor_labels.to_dict() if self.tensor_labels is not None else None,
|
|
343
|
+
'delta': self.delta,
|
|
344
|
+
'pubsub': self.pubsub,
|
|
345
|
+
'public': self.public
|
|
346
|
+
}
|
|
347
|
+
|
|
348
|
+
@staticmethod
|
|
349
|
+
def from_dict(d_props):
|
|
350
|
+
"""A static method that deserializes a dictionary into a `DataProps` object.
|
|
351
|
+
|
|
352
|
+
Args:
|
|
353
|
+
d_props: The dictionary containing the object's properties.
|
|
354
|
+
|
|
355
|
+
Returns:
|
|
356
|
+
A new `DataProps` object.
|
|
357
|
+
"""
|
|
358
|
+
d_labels = d_props['tensor_labels']
|
|
359
|
+
return DataProps(name=d_props['name'],
|
|
360
|
+
group=d_props['group'],
|
|
361
|
+
data_type=d_props['data_type'],
|
|
362
|
+
data_desc=d_props['data_desc'],
|
|
363
|
+
tensor_shape=d_props['tensor_shape'],
|
|
364
|
+
tensor_dtype=d_props['tensor_dtype'],
|
|
365
|
+
tensor_labels=d_labels['labels'] if d_labels is not None else None,
|
|
366
|
+
tensor_labeling_rule=d_labels['labeling_rule'] if d_labels is not None else "max",
|
|
367
|
+
delta=d_props['delta'],
|
|
368
|
+
pubsub=d_props['pubsub'],
|
|
369
|
+
public=d_props['public'])
|
|
370
|
+
|
|
371
|
+
def clone(self):
|
|
372
|
+
"""Creates and returns a deep copy of the `DataProps` instance.
|
|
373
|
+
It preserves the original transformation objects rather than re-evaluating them.
|
|
374
|
+
|
|
375
|
+
Returns:
|
|
376
|
+
A new `DataProps` object that is a clone of the original.
|
|
377
|
+
"""
|
|
378
|
+
return DataProps(name=self.name,
|
|
379
|
+
group=self.group,
|
|
380
|
+
data_type=self.data_type,
|
|
381
|
+
data_desc=self.data_desc,
|
|
382
|
+
tensor_shape=self.tensor_shape,
|
|
383
|
+
tensor_dtype=self.tensor_dtype,
|
|
384
|
+
tensor_labels=self.tensor_labels.labels if self.tensor_labels is not None else None,
|
|
385
|
+
tensor_labeling_rule=self.tensor_labels.original_labeling_rule
|
|
386
|
+
if self.tensor_labels is not None else "max",
|
|
387
|
+
stream_to_proc_transforms=self.__original_stream_to_proc_transforms,
|
|
388
|
+
proc_to_stream_transforms=self.__original_proc_to_stream_transforms,
|
|
389
|
+
delta=self.delta,
|
|
390
|
+
pubsub=self.pubsub,
|
|
391
|
+
public=self.public)
|
|
392
|
+
|
|
393
|
+
def get_name(self):
|
|
394
|
+
"""Retrieves the name of the data stream.
|
|
395
|
+
|
|
396
|
+
Returns:
|
|
397
|
+
A string representing the stream's name.
|
|
398
|
+
"""
|
|
399
|
+
|
|
400
|
+
return self.name
|
|
401
|
+
|
|
402
|
+
def get_group(self):
|
|
403
|
+
"""Retrieves the name of the group to which the data stream belongs.
|
|
404
|
+
|
|
405
|
+
Returns:
|
|
406
|
+
A string representing the group name.
|
|
407
|
+
"""
|
|
408
|
+
|
|
409
|
+
return self.group
|
|
410
|
+
|
|
411
|
+
def get_description(self):
|
|
412
|
+
"""Retrieves the description of the data.
|
|
413
|
+
|
|
414
|
+
Returns:
|
|
415
|
+
A string with the data description.
|
|
416
|
+
"""
|
|
417
|
+
return self.data_desc
|
|
418
|
+
|
|
419
|
+
def get_tensor_labels(self):
|
|
420
|
+
"""Retrieves the list of tensor labels, if they exist.
|
|
421
|
+
|
|
422
|
+
Returns:
|
|
423
|
+
A list of strings or None.
|
|
424
|
+
"""
|
|
425
|
+
return self.tensor_labels.labels if self.tensor_labels is not None else None
|
|
426
|
+
|
|
427
|
+
def set_name(self, name: str):
|
|
428
|
+
"""Sets a new name for the stream, with a check for invalid characters.
|
|
429
|
+
|
|
430
|
+
Args:
|
|
431
|
+
name: The new name as a string.
|
|
432
|
+
"""
|
|
433
|
+
assert "~" not in name, "Invalid chars in stream name"
|
|
434
|
+
self.name = name
|
|
435
|
+
|
|
436
|
+
def set_group(self, group: str):
|
|
437
|
+
"""Sets a new group name for the stream, with a check for invalid characters.
|
|
438
|
+
|
|
439
|
+
Args:
|
|
440
|
+
group: The new group name as a string.
|
|
441
|
+
"""
|
|
442
|
+
assert "~" not in group, "Invalid chars in group name"
|
|
443
|
+
self.group = group
|
|
444
|
+
|
|
445
|
+
def set_description(self, desc: str):
|
|
446
|
+
"""Sets a new description for the data.
|
|
447
|
+
|
|
448
|
+
Args:
|
|
449
|
+
desc: The new description as a string.
|
|
450
|
+
"""
|
|
451
|
+
self.data_desc = desc
|
|
452
|
+
|
|
453
|
+
def set_public(self, public: bool):
|
|
454
|
+
"""Sets whether the stream is public or not.
|
|
455
|
+
|
|
456
|
+
Args:
|
|
457
|
+
public: A boolean value.
|
|
458
|
+
"""
|
|
459
|
+
self.public = public
|
|
460
|
+
|
|
461
|
+
def set_pubsub(self, pubsub: bool):
|
|
462
|
+
"""Sets whether the stream uses Pub/Sub or direct messaging.
|
|
463
|
+
|
|
464
|
+
Args:
|
|
465
|
+
pubsub: A boolean value.
|
|
466
|
+
"""
|
|
467
|
+
self.pubsub = pubsub
|
|
468
|
+
|
|
469
|
+
def is_tensor(self):
|
|
470
|
+
"""Checks if the data type is 'tensor'.
|
|
471
|
+
|
|
472
|
+
Returns:
|
|
473
|
+
True if the type is 'tensor', False otherwise.
|
|
474
|
+
"""
|
|
475
|
+
return self.data_type == "tensor"
|
|
476
|
+
|
|
477
|
+
def is_img(self):
|
|
478
|
+
"""Checks if the data type is 'img'.
|
|
479
|
+
|
|
480
|
+
Returns:
|
|
481
|
+
True if the type is 'img', False otherwise.
|
|
482
|
+
"""
|
|
483
|
+
return self.data_type == "img"
|
|
484
|
+
|
|
485
|
+
def is_text(self):
|
|
486
|
+
"""Checks if the data type is 'text'.
|
|
487
|
+
|
|
488
|
+
Returns:
|
|
489
|
+
True if the type is 'text', False otherwise.
|
|
490
|
+
"""
|
|
491
|
+
return self.data_type == "text"
|
|
492
|
+
|
|
493
|
+
def is_tensor_long(self):
|
|
494
|
+
"""Checks if the tensor's data type is `torch.long`.
|
|
495
|
+
|
|
496
|
+
Returns:
|
|
497
|
+
True if the dtype is `torch.long`, False otherwise.
|
|
498
|
+
"""
|
|
499
|
+
return self.tensor_dtype == torch.long if self.tensor_dtype is not None else False
|
|
500
|
+
|
|
501
|
+
def is_tensor_float(self):
|
|
502
|
+
"""Checks if the tensor's data type is a float type (e.g., `torch.float32`).
|
|
503
|
+
|
|
504
|
+
Returns:
|
|
505
|
+
True if the dtype is a float type, False otherwise.
|
|
506
|
+
"""
|
|
507
|
+
return str(self.tensor_dtype).startswith("torch.float") if self.tensor_dtype is not None else False
|
|
508
|
+
|
|
509
|
+
def is_tensor_img(self):
|
|
510
|
+
"""Checks if the tensor's shape corresponds to a typical image format (4D, with 1 or 3 channels).
|
|
511
|
+
|
|
512
|
+
Returns:
|
|
513
|
+
True if the shape matches, False otherwise.
|
|
514
|
+
"""
|
|
515
|
+
return len(self.tensor_shape) == 4 and (self.tensor_shape[1] == 1 or self.tensor_shape[1] == 3) \
|
|
516
|
+
if self.tensor_shape is not None else False
|
|
517
|
+
|
|
518
|
+
def is_tensor_token_ids(self):
|
|
519
|
+
"""Checks if the tensor represents token IDs, which is indicated by `torch.long` data type and a 2D shape
|
|
520
|
+
suitable for sequences.
|
|
521
|
+
|
|
522
|
+
Returns:
|
|
523
|
+
True if it matches, False otherwise.
|
|
524
|
+
"""
|
|
525
|
+
return (self.tensor_dtype == torch.long and
|
|
526
|
+
len(self.tensor_shape) == 2 and (self.tensor_shape[1] >= 1 or self.tensor_shape[1] is None)) \
|
|
527
|
+
if self.tensor_shape is not None else False
|
|
528
|
+
|
|
529
|
+
def is_tensor_target_id(self):
|
|
530
|
+
"""Checks if the tensor represents a single target ID, indicated by `torch.long` data type and a 1D shape.
|
|
531
|
+
|
|
532
|
+
Returns:
|
|
533
|
+
True if it matches, False otherwise.
|
|
534
|
+
"""
|
|
535
|
+
return (self.tensor_dtype == torch.long and
|
|
536
|
+
len(self.tensor_shape) == 1) \
|
|
537
|
+
if self.tensor_shape is not None else False
|
|
538
|
+
|
|
539
|
+
def is_all(self):
|
|
540
|
+
"""Checks if the data type is 'all', which is a wildcard type.
|
|
541
|
+
|
|
542
|
+
Returns:
|
|
543
|
+
True if the type is 'all', False otherwise.
|
|
544
|
+
"""
|
|
545
|
+
return self.data_type == "all"
|
|
546
|
+
|
|
547
|
+
def net_hash(self, prefix: str):
|
|
548
|
+
"""Generates a unique network hash for the stream using a provided prefix, Pub/Sub status, and name/group.
|
|
549
|
+
|
|
550
|
+
Args:
|
|
551
|
+
prefix: The prefix, typically the peer ID.
|
|
552
|
+
|
|
553
|
+
Returns:
|
|
554
|
+
A string representing the network hash.
|
|
555
|
+
"""
|
|
556
|
+
return DataProps.build_net_hash(prefix, self.pubsub, self.name_or_group())
|
|
557
|
+
|
|
558
|
+
@staticmethod
|
|
559
|
+
def peer_id_from_net_hash(net_hash):
|
|
560
|
+
"""A static method to extract the peer ID from a network hash.
|
|
561
|
+
|
|
562
|
+
Args:
|
|
563
|
+
net_hash: The network hash string.
|
|
564
|
+
|
|
565
|
+
Returns:
|
|
566
|
+
A string representing the peer ID.
|
|
567
|
+
"""
|
|
568
|
+
return net_hash.split("::")[0]
|
|
569
|
+
|
|
570
|
+
@staticmethod
|
|
571
|
+
def name_or_group_from_net_hash(net_hash):
|
|
572
|
+
"""A static method to extract the name or group from a network hash.
|
|
573
|
+
|
|
574
|
+
Args:
|
|
575
|
+
net_hash: The network hash string.
|
|
576
|
+
|
|
577
|
+
Returns:
|
|
578
|
+
A string representing the name or group.
|
|
579
|
+
"""
|
|
580
|
+
return net_hash.split("::ps:")[1] if DataProps.is_pubsub_from_net_hash(net_hash) else net_hash.split("::dm:")[1]
|
|
581
|
+
|
|
582
|
+
@staticmethod
|
|
583
|
+
def is_pubsub_from_net_hash(net_hash):
|
|
584
|
+
"""A static method to check if a network hash belongs to a Pub/Sub stream.
|
|
585
|
+
|
|
586
|
+
Args:
|
|
587
|
+
net_hash: The network hash string.
|
|
588
|
+
|
|
589
|
+
Returns:
|
|
590
|
+
True if the hash is for a Pub/Sub stream, False otherwise.
|
|
591
|
+
"""
|
|
592
|
+
return "::ps:" in net_hash
|
|
593
|
+
|
|
594
|
+
def name_or_group(self):
|
|
595
|
+
"""Retrieves the group name if it's set, otherwise defaults to the stream name.
|
|
596
|
+
|
|
597
|
+
Returns:
|
|
598
|
+
A string representing the name or group.
|
|
599
|
+
"""
|
|
600
|
+
group = self.get_group()
|
|
601
|
+
return group if group != 'none' else self.get_name()
|
|
602
|
+
|
|
603
|
+
@staticmethod
|
|
604
|
+
def build_net_hash(prefix: str, pubsub: bool, name_or_group: str):
|
|
605
|
+
"""A static method to construct a complete network hash from a prefix, Pub/Sub status, and name/group.
|
|
606
|
+
|
|
607
|
+
Args:
|
|
608
|
+
prefix: The peer ID prefix.
|
|
609
|
+
pubsub: The Pub/Sub status.
|
|
610
|
+
name_or_group: The name or group of the stream.
|
|
611
|
+
|
|
612
|
+
Returns:
|
|
613
|
+
The constructed network hash string.
|
|
614
|
+
"""
|
|
615
|
+
if pubsub:
|
|
616
|
+
return f"{prefix}::ps:{name_or_group}"
|
|
617
|
+
else:
|
|
618
|
+
return f"{prefix}::dm:{name_or_group}"
|
|
619
|
+
|
|
620
|
+
@staticmethod
def normalize_net_hash(not_normalized_net_hash: str):
    """A static method that rewrites a network hash into its canonical form.

    Pub/Sub hashes are already canonical and returned unchanged. Direct-message
    hashes keep their "::dm:" prefix and have the suffix reduced to the part
    after "~" (when present) or after the first "-" otherwise.

    Args:
        not_normalized_net_hash: The network hash to normalize.

    Returns:
        The normalized network hash string.
    """
    if DataProps.is_pubsub_from_net_hash(not_normalized_net_hash):
        return not_normalized_net_hash
    prefix = not_normalized_net_hash.split("::dm:")[0]
    if "~" in not_normalized_net_hash:
        suffix = not_normalized_net_hash.split("~")[1]
    else:
        # Assumes the dm suffix contains at least one "-"; an IndexError is
        # raised otherwise (same as the original behavior)
        suffix = not_normalized_net_hash.split("::dm:")[1].split("-")[1]
    return prefix + "::dm:" + suffix
|
|
639
|
+
|
|
640
|
+
def is_pubsub(self):
    """Report whether this stream uses Pub/Sub delivery.

    Returns:
        True for a Pub/Sub stream, False otherwise.
    """
    return self.pubsub
|
|
647
|
+
|
|
648
|
+
def is_public(self):
    """Report whether this stream is publicly visible.

    Returns:
        True for a public stream, False otherwise.
    """
    return self.public
|
|
655
|
+
|
|
656
|
+
def set_tensor_labels_from_auto_tokenizer(self, model_id):
    """Populate the tensor labels with the vocabulary of a Hugging Face tokenizer.

    The labels list is indexed by token ID: position i holds the token string
    for ID i, as reported by the tokenizer.

    Args:
        model_id: The Hugging Face model ID used to load the `AutoTokenizer`.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    id_to_token: list[str | None] = [tokenizer.convert_ids_to_tokens(token_id)
                                     for token_id in range(len(tokenizer.vocab))]
    self.set_tensor_labels(id_to_token)
|
|
669
|
+
|
|
670
|
+
def set_tensor_labels(self, labels: list[str] | None, labeling_rule: str = "max"):
    """Attach a fresh `TensorLabels` instance describing the data components.

    Args:
        labels (list[str] or None): Labels to associate with the data components.
        labeling_rule (str): Rule used to turn tensor values into labels.

    Returns:
        None
    """
    self.tensor_labels = TensorLabels(self, labels=labels, labeling_rule=labeling_rule)
|
|
681
|
+
|
|
682
|
+
def adapt_tensor_to_tensor_labels(self, data: torch.Tensor) -> torch.Tensor:
    """Scatter a tensor into the positions dictated by the current super-set labels.

    When the stream's labels were interleaved with a super-set, incoming data
    has fewer columns than the super-set; it is copied into a zero-padded
    tensor at the recorded indices. In every other case the input is returned
    untouched.

    Args:
        data (torch.Tensor): The (batch, features) tensor to adapt.

    Returns:
        torch.Tensor: The adapted (possibly zero-padded) tensor.
    """
    if not self.is_tensor():
        return data  # Non-tensor streams are never adapted
    labels = self.tensor_labels
    total = len(labels) if labels is not None else 0
    if total == 0 or data.shape[1] >= total or labels.indices is None:
        return data  # Nothing to adapt
    padded = torch.zeros((data.shape[0], total), device=data.device, dtype=data.dtype)
    padded[:, labels.indices] = data
    return padded
|
|
701
|
+
|
|
702
|
+
def clear_label_adaptation(self, data: torch.Tensor):
    """Undo the super-set padding, selecting back the original columns.

    Args:
        data: A tensor previously adapted to the super-set labels.

    Returns:
        The tensor restricted to the recorded label indices, or the input
        unchanged when no adaptation is active.
    """
    indices = self.tensor_labels.indices
    if indices is None:
        return data
    return data[:, indices]
|
|
712
|
+
|
|
713
|
+
def is_flat_tensor_with_labels(self):
    """Check whether the stream carries 2D (batch, features) tensors with labels.

    Returns:
        True when the stream is a labeled 2D tensor stream, False otherwise.
    """
    return (self.is_tensor()
            and len(self.tensor_shape) == 2
            and self.has_tensor_labels())
|
|
720
|
+
|
|
721
|
+
def has_tensor_labels(self):
    """Check whether any tensor labels are attached to this stream.

    Returns:
        True when a non-empty label set exists, False otherwise.
    """
    labels = self.tensor_labels
    return labels is not None and len(labels) > 0
|
|
728
|
+
|
|
729
|
+
def to_text(self, data: torch.Tensor | str):
    """Converts the tensor data into a text-based representation exploiting the given labels and the labeling rule.

    Args:
        data (torch.Tensor or str): The data tensor to convert into text (if a string, then pass-through only).

    Returns:
        str or None: The corresponding text representation of the data, or None when no conversion is possible.

    Raises:
        ValueError: If the data type is not supported for conversion.
    """
    # Strings pass through untouched; anything that is neither a string nor a tensor cannot be converted
    if isinstance(data, str):
        return data
    elif not isinstance(data, torch.Tensor):
        return None
    elif len(data.shape) > 2:  # Can only print 1d data (recall that 1d data has 2 dimensions, due to batch size)
        return None

    if data.shape[0] != 1:
        return None  # Code designed for a batch of only 1 element

    if self.is_tensor():
        # Tensor streams can only be verbalized through their labels
        if not self.has_tensor_labels():
            return None

        if self.is_tensor_token_ids():

            # This is the case in which we assume to have a vector of token IDs:
            # each ID indexes a label, joined with spaces
            text = ""
            for i in range(0, data.shape[1]):
                if i > 0:
                    text += " "
                text += self.tensor_labels[data[0][i].item()]
            return text

        elif self.is_tensor_float():

            # This is the generic case of a 1d tensor: pick label(s) according to the labeling rule
            if self.tensor_labels.labeling_rule == "max":
                j = torch.argmax(data, dim=1)
                return self.tensor_labels[j.item()]
            elif self.tensor_labels.labeling_rule == "geq":

                # Warning: does not work for mini-batches
                # All components >= threshold contribute their label (comma-separated)
                jj = torch.where(data >= self.tensor_labels.labeling_rule_thres)[1]
                return ", ".join(self.tensor_labels[j] for j in jj.tolist())
            else:
                return None

    elif self.is_text():
        # Text streams decode the tensor back to a string via proc_to_stream_transforms,
        # which may be a tokenizer, a dict, or a custom callable
        if self.proc_to_stream_transforms is None:
            return None
        if isinstance(self.proc_to_stream_transforms, PreTrainedTokenizerBase):
            return self.proc_to_stream_transforms.decode(data[0])
        elif isinstance(self.proc_to_stream_transforms, dict):
            if data.dtype != torch.long:

                # This is the case of probabilities
                j = torch.argmax(data, dim=1)  # Warning: does not work for mini-batches
                # NOTE(review): assumes the dict maps int token IDs to strings — confirm against callers
                return self.proc_to_stream_transforms[j.item()]
            else:

                # This is the case in which we assume to have a vector of token IDs
                text = ""
                for i in range(0, data.shape[1]):
                    if i > 0:
                        text += " "
                    text += self.proc_to_stream_transforms[data[0][i].item()]
                return text
        else:
            return self.proc_to_stream_transforms(data)
    else:
        return None
|
|
803
|
+
|
|
804
|
+
def check_and_preprocess(self, data: str | Image.Image | torch.Tensor,
                         allow_class_ids: bool = False, targets: bool = False,
                         device: torch.device = torch.device("cpu")):
    """Prepares incoming data for a processor by validating its type and applying necessary transformations.
    It handles different data types, including tensors, text (strings), and images, raising `ValueError` if
    the data type is unexpected or incompatible with the stream's properties. For text and images, it can apply a
    pre-configured transformation (like a tokenizer or a standard image transform) to convert the data into a
    tensor format suitable for processing. For tensors, it performs validation on shape and data type.

    Args:
        data: The data sample to check and preprocess.
        allow_class_ids: A boolean to allow single-element long tensors, typically for class IDs.
        targets: A boolean to indicate if the data is a target (used to select the correct transformation in a
            dual-transform setup).
        device: The PyTorch device (e.g., 'cpu' or 'cuda') to which the tensor should be moved.

    Returns:
        The preprocessed data, typically a tensor on the specified device.

    Raises:
        ValueError: If the data does not match the stream's declared type, dtype, shape, or label count.
    """
    if self.is_tensor():
        if isinstance(data, torch.Tensor):

            # Skipping all checks, it is enough to know it is a tensor
            # (a 1d long tensor is accepted as-is when class IDs are allowed)
            if allow_class_ids and data.dtype == torch.long and len(data.shape) == 1:
                return data.to(device)

            # Checking dtype
            if self.tensor_dtype != data.dtype:
                raise ValueError(f"Expected data of type {self.tensor_dtype}, got {data.dtype} ("
                                 f"shape {data.shape})")

            # Checking shape (a None in tensor_shape acts as a wildcard dimension)
            if len(self.tensor_shape) != len(data.shape):
                raise ValueError(f"Expected data with shape {self.tensor_shape}, got {data.shape}")
            for i, s in enumerate(self.tensor_shape):
                if s is not None:
                    if s != data.shape[i]:
                        raise ValueError(f"Expected data with shape {self.tensor_shape}, got {data.shape}")

            # Checking labels (token-ID tensors are exempt from the column-count check)
            if self.has_tensor_labels():
                if data.ndim != 2:
                    raise ValueError("Only 2d tensors are expected for "
                                     "labeled attributes (1st dimension is batch dim)")
                if not (self.is_tensor_token_ids() or data.shape[1] == self.tensor_labels.num_labels):
                    raise ValueError(f"Expected data with {self.tensor_labels.num_labels} "
                                     f"components (ignoring the 1st dimension), "
                                     f"got {data[0].numel()}")

            return data.to(device)
        else:
            raise ValueError(f"Expecting tensor data, got {type(data)}")
    elif self.is_text():
        if isinstance(data, str):
            # stream_to_proc_transforms is indexed by int(targets): [0] for inputs, [1] for targets
            if self.stream_to_proc_transforms is not None:
                text_to_tensor_transform = self.stream_to_proc_transforms[int(targets)]
                if text_to_tensor_transform is not None:
                    if isinstance(text_to_tensor_transform, PreTrainedTokenizerBase):
                        return text_to_tensor_transform(data, return_tensors='pt')['input_ids'].to(device)  # Tok
                    elif isinstance(text_to_tensor_transform, dict):
                        # Unknown strings map to len(dict), i.e. one past the last known ID
                        return torch.tensor(text_to_tensor_transform[data]
                                            if data in text_to_tensor_transform else len(text_to_tensor_transform),
                                            dtype=torch.long, device=device).view(1, -1)  # Warning batch size 1
                    else:
                        return text_to_tensor_transform(data).to(device)  # Custom callable function
                else:
                    return data
            else:
                return data
        else:
            raise ValueError(f"Expecting text (string) data, got {type(data)}")
    elif self.is_img():
        if isinstance(data, Image.Image):
            if self.stream_to_proc_transforms is not None:
                img_to_tensor_transform = self.stream_to_proc_transforms[int(targets)]
                if img_to_tensor_transform is not None:
                    return img_to_tensor_transform(data).to(device)
                else:
                    return data
            else:
                return data
        else:
            raise ValueError(f"Expecting image (PIL.Image) data, got {type(data)}")
    elif self.is_all():
        # "all" streams accept anything, untouched
        return data
    else:
        raise ValueError(f"Unexpected data type, {self.data_type}")
|
|
891
|
+
|
|
892
|
+
def check_and_postprocess(self, data: str | Image.Image | torch.Tensor):
    """Takes a processor's output and validates it before converting it back into a stream-compatible format.
    It handles `torch.Tensor` data, applying a `proc_to_stream_transform` (if one exists) to convert the tensor
    into an appropriate format for the stream, such as a string for text or a PIL `Image` for images. It performs
    a final check on the data's format (shape, dtype, etc.) to ensure consistency with the stream's properties.

    Args:
        data: The output from the processor, typically a tensor.

    Returns:
        The post-processed data, in a stream-compatible format (e.g., a string, image, or CPU tensor).

    Raises:
        ValueError: If the data does not match the stream's declared type, dtype, shape, or label count,
            or if a tensor must be decoded but no inverse transform is configured.
    """
    if self.is_tensor():
        if isinstance(data, torch.Tensor):
            # Optional tensor-to-tensor post transform, then move to CPU before validating
            if self.proc_to_stream_transforms is not None:
                data = self.proc_to_stream_transforms(data)
            data = data.cpu()

            # Checking dtype
            if self.tensor_dtype != data.dtype:
                raise ValueError(f"Expected data of type {self.tensor_dtype}, got {data.dtype}")

            # Checking shape (a None in tensor_shape acts as a wildcard dimension)
            if len(self.tensor_shape) != len(data.shape):
                raise ValueError(f"Expected data with shape {self.tensor_shape}, got {data.shape}")
            for i, s in enumerate(self.tensor_shape):
                if s is not None:
                    if s != data.shape[i]:
                        raise ValueError(f"Expected data with shape {self.tensor_shape}, got {data.shape}")

            # Checking labels (token-ID tensors are exempt from the column-count check)
            if self.has_tensor_labels():
                if data.ndim != 2:
                    raise ValueError("Only 2d tensors are expected for "
                                     "labeled attributes (1st dimension is batch dim)")
                if not (self.is_tensor_token_ids() or data.shape[1] == self.tensor_labels.num_labels):
                    raise ValueError(f"Expected data with {self.tensor_labels.num_labels} "
                                     f"components (ignoring the 1st dimension), "
                                     f"got {data[0].numel()}")

            return data
        else:
            raise ValueError(f"Expecting tensor data, got {type(data)}")
    elif self.is_text():
        if isinstance(data, str):
            return data
        elif isinstance(data, torch.Tensor):
            data = data.cpu()
            # proc_to_stream_transforms may be a tokenizer, a list (ID -> string), or a custom callable
            if self.proc_to_stream_transforms is not None:
                assert data.shape[0] == 1, f"Code designed for a batch of only 1 element, got {data.shape[0]}"
                if isinstance(self.proc_to_stream_transforms, PreTrainedTokenizerBase):
                    return self.proc_to_stream_transforms.decode(data[0])  # Tokenizer
                elif isinstance(self.proc_to_stream_transforms, list):
                    if data.dtype != torch.long:

                        # This is the case of probabilities
                        j = torch.argmax(data, dim=1)  # Warning: does not work for mini-batches
                        return self.proc_to_stream_transforms[j.item()]
                    else:

                        # This is the case in which we assume to have a vector of token IDs
                        text = ""
                        for i in range(0, data.shape[1]):
                            if i > 0:
                                text += " "
                            text += self.proc_to_stream_transforms[data[0][i].item()]
                        return text
                else:
                    return self.proc_to_stream_transforms(data)  # Custom callable function
            else:
                raise ValueError(f"Cannot decode torch.Tensor to text, since text_to_tensor_inv_transform is None")
        else:
            raise ValueError(f"Expecting text (string) or tensor data, got {type(data)}")
    elif self.is_img():
        if isinstance(data, Image.Image):
            return data
        elif isinstance(data, torch.Tensor):
            data = data.cpu()
            if self.proc_to_stream_transforms is not None:
                return self.proc_to_stream_transforms(data)
            else:
                raise ValueError(f"Cannot convert a tensor to PIL.Image, since img_to_tensor_inv_transform is None")
        else:
            raise ValueError(f"Expecting image (PIL.Image) data or torch.Tensor, got {type(data)}")
    elif self.is_all():
        # "all" streams accept anything, untouched
        return data
    else:
        raise ValueError(f"Unexpected data type, {self.data_type}")
|
|
980
|
+
|
|
981
|
+
def is_compatible(self, props_to_compare: 'DataProps') -> bool:
    """Check whether this DataProps instance is compatible with another one.

    Compatibility covers the data type (with "all" acting as a wildcard on
    this side), the tensor shape (with None as a wildcard dimension), and the
    tensor labels (only compared when both sides define them).

    Args:
        props_to_compare (DataProps): The instance to compare against.

    Returns:
        bool: True when compatible, False otherwise.
    """
    # Data types must agree, unless this side accepts anything
    if self.data_type != "all" and self.data_type != props_to_compare.data_type:
        return False

    # Non-tensor streams have nothing more to check
    if not self.is_tensor():
        return True

    # Shapes must have the same rank; None dimensions match anything
    own_shape = self.tensor_shape
    other_shape = props_to_compare.tensor_shape
    if len(own_shape) != len(other_shape):
        return False
    if any(a is not None and b is not None and a != b
           for a, b in zip(own_shape, other_shape)):
        return False

    # Labels are only comparable when both sides carry them
    if self.has_tensor_labels() and props_to_compare.has_tensor_labels():
        return self.tensor_labels == props_to_compare.tensor_labels
    return True
|
|
1014
|
+
|
|
1015
|
+
def __str__(self):
    """Render this DataProps instance as a readable string.

    Returns:
        str: A "[DataProps]" banner followed by the dictionary form of the instance.
    """
    return "[DataProps]\n" + str(self.to_dict())
|
|
1022
|
+
|
|
1023
|
+
|
|
1024
|
+
class TensorLabels:
|
|
1025
|
+
"""
|
|
1026
|
+
A class to manage labels associated with data and perform operations on them.
|
|
1027
|
+
|
|
1028
|
+
Attributes:
|
|
1029
|
+
VALID_LABELING_RULES (tuple): Tuple of valid labeling rules ('max', 'geq').
|
|
1030
|
+
"""
|
|
1031
|
+
|
|
1032
|
+
VALID_LABELING_RULES = ('max', 'geq')
|
|
1033
|
+
|
|
1034
|
+
def __init__(self, data_props: DataProps, labels: list[str] | None, labeling_rule: str = "max"):
    """Initializes the TensorLabels instance.

    Args:
        data_props (DataProps): The DataProps instance that owns these labels.
        labels (list[str] or None): List of labels.
        labeling_rule (str): The rule for labeling (either 'max' or 'geqX', where X is a number).

    Returns:
        None

    Raises:
        AssertionError: If the labels or labeling_rule are invalid.
    """
    assert data_props.is_tensor(), "Tensor labels can only be attached to tensor data properties"
    num_labels = len(labels) if labels is not None else 0
    assert num_labels == 0 or (data_props.is_tensor() and len(data_props.tensor_shape) == 2), \
        "Data attribute labels can only be specified for 2d arrays (batch size + data features)"
    assert len(labeling_rule) >= 3 and labeling_rule[0:3] in TensorLabels.VALID_LABELING_RULES, \
        "Invalid labeling rule"

    # Split the rule into its 3-char name ("max"/"geq") and optional numeric threshold (e.g. "geq0.5")
    original_labeling_rule = labeling_rule
    labeling_rule_thres = None
    if len(labeling_rule) > 3:
        try:
            labeling_rule_thres = float(labeling_rule[3:])
        except ValueError:
            # Fix: the original used `assert False`, which is stripped under `python -O`
            # and would leave labeling_rule_thres unbound; raise explicitly instead
            # (same AssertionError type as before in normal mode)
            raise AssertionError("Invalid labeling rule") from None
        labeling_rule = labeling_rule[0:3]

    # Basic attributes
    self.data_props = data_props
    self.labels = labels
    self.labeling_rule = labeling_rule
    self.labeling_rule_thres = labeling_rule_thres
    self.original_labeling_rule = original_labeling_rule

    # These are mostly operational stuff, similar to private info (but it could be useful to expose them)
    self.num_labels = num_labels
    self.indices = None
|
|
1074
|
+
|
|
1075
|
+
def to_dict(self):
    """Serialize the labels and the original labeling rule into a dictionary.

    Returns:
        A dict holding the 'labels' list and the original 'labeling_rule' string.
    """
    return dict(labels=self.labels, labeling_rule=self.original_labeling_rule)
|
|
1086
|
+
|
|
1087
|
+
def clear_indices(self):
    """Drop any label-interleaving indices, reverting to the non-interleaved state."""
    self.indices = None
|
|
1092
|
+
|
|
1093
|
+
def __getitem__(self, idx):
    """Return the label stored at position `idx`.

    Args:
        idx (int): Position of the label to fetch.

    Returns:
        str: The label at that position.

    Raises:
        ValueError: When no labels exist or the index is out of range.
    """
    if self.labels is None:
        raise ValueError(f"Cannot retrieve any labels, since they are not there at all (None)")
    if not (0 <= idx < self.num_labels):
        raise ValueError(f"Invalid index {idx} for attribute labels of size {self.num_labels}")
    return self.labels[idx]
|
|
1110
|
+
|
|
1111
|
+
def __len__(self):
    """Return how many labels are tracked.

    Returns:
        int: The label count.
    """
    return self.num_labels
|
|
1118
|
+
|
|
1119
|
+
def __iter__(self):
    """Iterate over the labels, yielding nothing when no labels are set.

    Returns:
        iterator: An iterator over the labels (empty when labels is None).
    """
    labels = self.labels
    if labels is None:
        return iter(())
    return iter(labels)
|
|
1126
|
+
|
|
1127
|
+
def __str__(self):
    """Render this TensorLabels instance as a readable string.

    Returns:
        str: A "[TensorLabels]" banner with labels, rule, and superset indices.
    """
    # NOTE(review): the trailing ')' in the output has no matching '(' — looks
    # like a typo in the original, kept to preserve byte-identical output
    return f"[TensorLabels] labels: {self.labels}, labeling_rule: {self.labeling_rule}, indices_in_superset: {self.indices})"
|
|
1136
|
+
|
|
1137
|
+
def __eq__(self, other):
    """Compare two `TensorLabels` instances for equality: they are equal when
    they have the same number of labels and the labels match pairwise in order.

    Args:
        other: The object to compare with.

    Returns:
        True when equal, False when not, or NotImplemented when `other` is not
        a TensorLabels instance (letting Python fall back to its default).
    """
    if not isinstance(other, TensorLabels):
        # Fix: the original *returned* a ValueError instance (truthy!) instead
        # of raising or signaling an unsupported comparison; NotImplemented is
        # the correct rich-comparison protocol response
        return NotImplemented

    if self.num_labels != other.num_labels:
        return False
    if self.num_labels == 0:
        return True
    return all(a == b for a, b in zip(self.labels, other.labels))
|
|
1160
|
+
|
|
1161
|
+
def interleave_with(self, superset_labels: list[str]):
|
|
1162
|
+
"""Interleaves the current labels with a super-set of labels, determining how to index them.
|
|
1163
|
+
|
|
1164
|
+
Args:
|
|
1165
|
+
superset_labels (list[str]): The super-set of labels to interleave with.
|
|
1166
|
+
|
|
1167
|
+
Returns:
|
|
1168
|
+
None
|
|
1169
|
+
|
|
1170
|
+
Raises:
|
|
1171
|
+
AssertionError: If the super-set of labels is not compatible.
|
|
1172
|
+
"""
|
|
1173
|
+
assert superset_labels is not None and self.labels is not None, \
|
|
1174
|
+
f"Can only interleave non-empty sets of attribute labels"
|
|
1175
|
+
assert len(superset_labels) >= len(self), f"You must provide a super-set of attribute labels"
|
|
1176
|
+
|
|
1177
|
+
# Ensuring it is a super-set of the current labels and finding its position
|
|
1178
|
+
if self.indices is not None:
|
|
1179
|
+
labels = []
|
|
1180
|
+
indices_list = self.indices.tolist()
|
|
1181
|
+
for i in indices_list:
|
|
1182
|
+
labels.append(self.labels[i])
|
|
1183
|
+
else:
|
|
1184
|
+
labels = self.labels
|
|
1185
|
+
|
|
1186
|
+
indices = []
|
|
1187
|
+
for label in labels:
|
|
1188
|
+
assert label in superset_labels, \
|
|
1189
|
+
f"Cannot find attribute label {label} in (expected) super-set {superset_labels}"
|
|
1190
|
+
indices.append(superset_labels.index(label))
|
|
1191
|
+
|
|
1192
|
+
if len(indices) == len(superset_labels):
|
|
1193
|
+
same_labels_and_order = True
|
|
1194
|
+
for j, i in enumerate(indices):
|
|
1195
|
+
if j != i:
|
|
1196
|
+
same_labels_and_order = False
|
|
1197
|
+
break
|
|
1198
|
+
else:
|
|
1199
|
+
same_labels_and_order = False
|
|
1200
|
+
|
|
1201
|
+
if not same_labels_and_order:
|
|
1202
|
+
self.labels = superset_labels
|
|
1203
|
+
self.num_labels = len(self.labels)
|
|
1204
|
+
self.indices = torch.tensor(indices, dtype=torch.long)
|
|
1205
|
+
|
|
1206
|
+
# Altering shape
|
|
1207
|
+
self.data_props.tensor_shape = (self.data_props.tensor_shape[0], self.num_labels)
|
|
1208
|
+
else:
|
|
1209
|
+
self.indices = None
|