unaiverse-0.1.6-cp310-cp310-macosx_10_9_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of unaiverse has been flagged as a potentially problematic release.

Files changed (50)
  1. unaiverse/__init__.py +19 -0
  2. unaiverse/agent.py +2008 -0
  3. unaiverse/agent_basics.py +1846 -0
  4. unaiverse/clock.py +191 -0
  5. unaiverse/dataprops.py +1209 -0
  6. unaiverse/hsm.py +1880 -0
  7. unaiverse/modules/__init__.py +18 -0
  8. unaiverse/modules/cnu/__init__.py +17 -0
  9. unaiverse/modules/cnu/cnus.py +536 -0
  10. unaiverse/modules/cnu/layers.py +261 -0
  11. unaiverse/modules/cnu/psi.py +60 -0
  12. unaiverse/modules/hl/__init__.py +15 -0
  13. unaiverse/modules/hl/hl_utils.py +411 -0
  14. unaiverse/modules/networks.py +1509 -0
  15. unaiverse/modules/utils.py +680 -0
  16. unaiverse/networking/__init__.py +16 -0
  17. unaiverse/networking/node/__init__.py +18 -0
  18. unaiverse/networking/node/connpool.py +1261 -0
  19. unaiverse/networking/node/node.py +2223 -0
  20. unaiverse/networking/node/profile.py +446 -0
  21. unaiverse/networking/node/tokens.py +79 -0
  22. unaiverse/networking/p2p/__init__.py +198 -0
  23. unaiverse/networking/p2p/go.mod +127 -0
  24. unaiverse/networking/p2p/go.sum +548 -0
  25. unaiverse/networking/p2p/golibp2p.py +18 -0
  26. unaiverse/networking/p2p/golibp2p.pyi +135 -0
  27. unaiverse/networking/p2p/lib.go +2714 -0
  28. unaiverse/networking/p2p/lib.go.sha256 +1 -0
  29. unaiverse/networking/p2p/lib_types.py +312 -0
  30. unaiverse/networking/p2p/message_pb2.py +63 -0
  31. unaiverse/networking/p2p/messages.py +265 -0
  32. unaiverse/networking/p2p/mylogger.py +77 -0
  33. unaiverse/networking/p2p/p2p.py +929 -0
  34. unaiverse/networking/p2p/proto-go/message.pb.go +616 -0
  35. unaiverse/networking/p2p/unailib.cpython-310-darwin.so +0 -0
  36. unaiverse/streamlib/__init__.py +15 -0
  37. unaiverse/streamlib/streamlib.py +210 -0
  38. unaiverse/streams.py +770 -0
  39. unaiverse/utils/__init__.py +16 -0
  40. unaiverse/utils/ask_lone_wolf.json +27 -0
  41. unaiverse/utils/lone_wolf.json +19 -0
  42. unaiverse/utils/misc.py +305 -0
  43. unaiverse/utils/sandbox.py +293 -0
  44. unaiverse/utils/server.py +435 -0
  45. unaiverse/world.py +175 -0
  46. unaiverse-0.1.6.dist-info/METADATA +365 -0
  47. unaiverse-0.1.6.dist-info/RECORD +50 -0
  48. unaiverse-0.1.6.dist-info/WHEEL +6 -0
  49. unaiverse-0.1.6.dist-info/licenses/LICENSE +43 -0
  50. unaiverse-0.1.6.dist-info/top_level.txt +1 -0
unaiverse/dataprops.py ADDED
@@ -0,0 +1,1209 @@
1
+ """
2
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
3
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
4
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
5
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
6
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
7
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
8
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
9
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
10
+ A Collectionless AI Project (https://collectionless.ai)
11
+ Registration/Login: https://unaiverse.io
12
+ Code Repositories: https://github.com/collectionlessai/
13
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
14
+ """
15
+ import torch
16
+ from PIL import Image
17
+ from typing import Callable, Any
18
+ from transformers import AutoTokenizer, PreTrainedTokenizerBase
19
+
20
+
21
+ class Data4Proc:
22
+ def __init__(self, *args, private_only: bool = False, public_only: bool = False, **kwargs):
23
+ """Initializes a `Data4Proc` object, which is a container for one or two `DataProps` instances. It creates a
24
+ `DataProps` object for a private stream and, optionally, a public stream.
25
+
26
+ Args:
27
+ *args: Variable length argument list to be passed to the `DataProps` constructor.
28
+ private_only: If True, only a private stream's `DataProps` is created.
29
+ public_only: If True, only a public stream's `DataProps` is created.
30
+ **kwargs: Arbitrary keyword arguments passed to the `DataProps` constructor.
31
+
32
+ Raises:
33
+ ValueError: If both `private_only` and `public_only` are set to True, or if the 'public' argument is passed
34
+ directly.
35
+ """
36
+ self.props = []
37
+ if public_only and private_only:
38
+ raise ValueError("Cannot set both private_only and public_only to True (it does not make any sense)")
39
+ if 'public' in kwargs:
40
+ raise ValueError("Invalid argument was provided to Data4Proc: 'public' (it is an argument of DataProps)")
41
+ if not public_only:
+ kwargs['public'] = False
42
+ self.props.append(DataProps(*args, **kwargs))
43
+ if not private_only:
44
+ kwargs['public'] = True
45
+ self.props.append(DataProps(*args, **kwargs))
46
+
47
+ def to_list_of_dicts(self):
48
+ """Converts the contained `DataProps` objects into a list of dictionaries. Each dictionary represents the
49
+ properties of a single data stream.
50
+
51
+ Returns:
52
+ A list of dictionaries, where each dictionary holds the properties of a `DataProps` object.
53
+ """
54
+ return [props.to_dict() for props in self.props]
55
+
56
+ def to_dict(self):
57
+ """Raises a `RuntimeError` because this method is intended for a single `DataProps` object, not for the
58
+ container class `Data4Proc` which can hold multiple properties.
59
+
60
+ Raises:
61
+ RuntimeError: Always, as this method is not supported.
62
+ """
63
+ raise RuntimeError("This method can only be called on a DataProps object and not on Data4Proc")
64
+
65
+ def from_dict(self):
66
+ """Raises a `RuntimeError` because this method is intended for a single `DataProps` object, not for the
67
+ container class `Data4Proc` which can hold multiple properties.
68
+
69
+ Raises:
70
+ RuntimeError: Always, as this method is not supported.
71
+ """
72
+ raise RuntimeError("This method can only be called on a DataProps object and not on Data4Proc")
73
+
74
+ def clone(self):
75
+ """Creates and returns a deep copy of the `Data4Proc` object.
76
+
77
+ Returns:
78
+ A new `Data4Proc` object that is a clone of the original.
79
+ """
80
+ ret = Data4Proc()
81
+ ret.props = []
82
+ for p in self.props:
83
+ ret.props.append(p.clone())
84
+ return ret
85
+
86
+ def is_public(self):
87
+ """Raises a `RuntimeError` because this method is intended for a single `DataProps` object, not for the
88
+ container class `Data4Proc` which can hold multiple properties.
89
+
90
+ Raises:
91
+ RuntimeError: Always, as this method is not supported.
92
+ """
93
+ raise RuntimeError("This method can only be called on a DataProps object and not on Data4Proc")
94
+
95
+ def __str__(self):
96
+ """Provides a formatted string representation of the `Data4Proc` object. It lists the number of `DataProps`
97
+ objects it contains and includes the string representation of each of them.
98
+
99
+ Returns:
100
+ A string detailing the contents of the `Data4Proc` object.
101
+ """
102
+ s = f"[Data4Proc] Number of DataProps: {len(self.props)}"
103
+ for p in self.props:
104
+ z = str(p).replace("\n", "\n\t")
105
+ s += "\t" + z
106
+ return s
107
+
108
+ def __getattr__(self, method_or_attribute_name):
109
+ """Handles dynamic attribute and method access for the `Data4Proc` class. If the requested method name starts
110
+ with 'set_', it creates a new function that applies the corresponding setter method to all contained
111
+ `DataProps` objects. For any other attribute, it returns the attribute from the first `DataProps` object
112
+ in the list.
113
+
114
+ Args:
115
+ method_or_attribute_name: The name of the method or attribute being accessed.
116
+
117
+ Returns:
118
+ Either a callable function or the requested attribute from the first `DataProps` object.
119
+ """
120
+ if method_or_attribute_name.startswith('set_'):
121
+ def apply_set_method_to_all_props(*args, **kwargs):
122
+ for prop in self.props:
123
+ getattr(prop, method_or_attribute_name)(*args, **kwargs)
124
+
125
+ return apply_set_method_to_all_props
126
+ else:
127
+ return getattr(self.props[0], method_or_attribute_name)
128
+
129
+
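The docstrings above describe how Data4Proc wraps one private and one public DataProps and forwards every set_* call to both of them. A minimal usage sketch, assuming the wheel is installed; the stream name, group and shape below are illustrative:

import torch
from unaiverse.dataprops import Data4Proc

# One private and one public DataProps are created from the same arguments
props = Data4Proc(name="features", data_type="tensor",
                  tensor_shape=(1, 4), tensor_dtype=torch.float32)

props.set_group("sensors")          # forwarded to both contained DataProps
print(props.get_group())            # read from the first (private) DataProps
print(props.to_list_of_dicts())     # two dicts, public=False and public=True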
130
+ class DataProps:
131
+ """
132
+ A class for handling the properties and transformations of data, including labels.
133
+ It supports different data types: 'tensor', 'img', 'text', and 'all'.
134
+
135
+ Attributes:
136
+ VALID_DATA_TYPES (tuple): Tuple of valid data types ('tensor', 'img', 'text', 'all').
137
+ """
138
+ VALID_DATA_TYPES = ('tensor', 'img', 'text', 'all')
139
+
140
+ def __init__(self,
141
+ name: str = "unk",
142
+ group: str = "none",
143
+ data_type: str = "text", # Do not set tensor as default
144
+ data_desc: str = "unk",
145
+ tensor_shape: tuple[int | None, ...] | None = None,
146
+ tensor_labels: list[str] | str | None = None,
147
+ tensor_dtype: torch.dtype | str | None = None,
148
+ tensor_labeling_rule: str = "max",
149
+ stream_to_proc_transforms: Callable[..., Any] | PreTrainedTokenizerBase | str | dict | tuple[
150
+ dict | Callable[..., Any] | PreTrainedTokenizerBase | str | None,
151
+ dict | Callable[..., Any] | PreTrainedTokenizerBase | str | None] | None = None,
152
+ proc_to_stream_transforms: Callable[..., Any] | PreTrainedTokenizerBase | str | list | None = None,
153
+ delta: float = -1,
154
+ pubsub: bool = False,
155
+ public: bool = False):
156
+ """Initializes a DataProps instance.
157
+
158
+ Args:
159
+ name (str): Name of the data (default is "unk").
160
+ group (str): Name of the group to which this DataProps belongs (default: "none").
161
+ data_type (str): The type of data ('tensor', 'img', or 'text').
162
+ data_desc (str): Description of the data (default is "unk").
163
+ tensor_shape (tuple[int | None] or None): Shape of the tensor data (e.g., (3, 224, 224) or (3, None, None)
164
+ for tensors that are variable size images). It is None for non-tensor data.
165
+ tensor_labels (list[str] or "AutoTokenizer:<tokenizer_model_id>" or None):
166
+ List of labels for the components (features) of tensor data. It can be a string representing the ID of a
167
+ tokenizer which is valid in AutoTokenizer (with prefix "AutoTokenizer:").
168
+ tensor_dtype (torch.dtype or str or None): The string representing the Pytorch dtype of the tensor data.
169
+ tensor_labeling_rule (str): The labeling rule for tensor data ('max' or 'geqX' where X is a number).
170
+ stream_to_proc_transforms (callable or PreTrainedTokenizerBase or str or dict or a list* or None):
171
+ A callable that converts stream-format data to tensor format (any callable, e.g. torchvision transforms, a
172
+ pretrained tokenizer, or the model ID from which it can be downloaded (it must have the prefix
173
+ "AutoTokenizer:"), or a vocabulary str->int). It is None for already-tensorial data.
174
+ *If you need to distinguish the transform applied to inputs from the one applied to targets, pass a pair
175
+ (tuple or list) of two such transforms - one for inputs, one for targets, respectively.
176
+ proc_to_stream_transforms (callable or PreTrainedTokenizerBase or str or list or None): A callable that
177
+ converts tensor-format data back to stream format (any callable, e.g. torchvision transforms, a pretrained
178
+ tokenizer, or the model ID from which it can be downloaded (it must have the prefix "AutoTokenizer:"), or a
179
+ vocabulary int->str). It is None for non-text data.
180
+ delta (float): Time interval between consecutive data samples (<= 0 for real-time data).
181
+ pubsub (bool): If the stream is supposed to be sent to/received from a Pub/Sub topic.
182
+ public (bool): If the stream is supposed to be accessed through the public net or through the private one.
183
+
184
+ Returns:
185
+ None
186
+ """
187
+
188
+ # Checking data type
189
+ assert data_type in DataProps.VALID_DATA_TYPES, "Invalid data type"
190
+ assert isinstance(data_desc, str), "Invalid data description"
191
+
192
+ # Checking transformations
193
+ assert (stream_to_proc_transforms is None or
194
+ isinstance(stream_to_proc_transforms, str) or
195
+ isinstance(stream_to_proc_transforms, PreTrainedTokenizerBase) or
196
+ callable(stream_to_proc_transforms) or
197
+ isinstance(stream_to_proc_transforms, dict) or
198
+ isinstance(stream_to_proc_transforms, tuple) or
199
+ isinstance(stream_to_proc_transforms, list)), \
200
+ "Invalid stream to processor transforms"
201
+
202
+ if stream_to_proc_transforms is not None:
203
+ if not isinstance(stream_to_proc_transforms, list) and not isinstance(stream_to_proc_transforms, tuple):
204
+ self.stream_to_proc_transforms = [stream_to_proc_transforms, stream_to_proc_transforms]
205
+ else:
206
+ assert len(stream_to_proc_transforms) == 2, \
207
+ "Expected a list with two sets of transforms (input, target)"
208
+ self.stream_to_proc_transforms = stream_to_proc_transforms
209
+ self.__original_stream_to_proc_transforms = stream_to_proc_transforms
210
+ else:
211
+ self.stream_to_proc_transforms = None
212
+ self.__original_stream_to_proc_transforms = None
213
+
214
+ assert (proc_to_stream_transforms is None or
215
+ isinstance(proc_to_stream_transforms, str) or
216
+ isinstance(proc_to_stream_transforms, PreTrainedTokenizerBase) or
217
+ callable(proc_to_stream_transforms) or
218
+ isinstance(proc_to_stream_transforms, list)), \
219
+ "Invalid stream to processor transforms"
220
+
221
+ self.proc_to_stream_transforms = proc_to_stream_transforms
222
+ self.__original_proc_to_stream_transforms = proc_to_stream_transforms
223
+
224
+ # Setting data type and description
225
+ self.data_type = data_type
226
+ self.data_desc = data_desc
227
+
228
+ # Setting empty attributes
229
+ self.tensor_shape = None
230
+ self.tensor_dtype = None
231
+ self.tensor_labels = None
232
+
233
+ # Checking data in function of its type
234
+ if self.is_tensor():
235
+
236
+ # Checking shape
237
+ assert (tensor_shape is not None and
238
+ isinstance(tensor_shape, (tuple, list))), f"Invalid shape for DataProps: {tensor_shape}"
239
+ assert all(x is None or isinstance(x, int) for x in tensor_shape), \
240
+ f"Invalid shape for DataProps: {tensor_shape}"
241
+
242
+ # Setting shape
243
+ self.tensor_shape = tuple(tensor_shape) # Forcing (important)
244
+
245
+ # Checking dtype
246
+ assert (tensor_dtype is not None and
247
+ (isinstance(tensor_dtype, torch.dtype) or isinstance(tensor_dtype, str)
248
+ and tensor_dtype.startswith("torch."))), \
249
+ f"Invalid tensor type: {tensor_dtype}"
250
+
251
+ # Setting dtype
252
+ self.tensor_dtype = tensor_dtype if isinstance(tensor_dtype, torch.dtype) else eval(tensor_dtype)
253
+
254
+ # Checking labels
255
+ assert tensor_labels is None or (isinstance(tensor_labels, list) or
256
+ (isinstance(tensor_labels, str) and
257
+ tensor_labels.startswith("AutoTokenizer:"))), \
258
+ f"Invalid tensor labels: {tensor_labels}"
259
+
260
+ # Setting labels
261
+ if tensor_labels is not None:
262
+ if not (isinstance(tensor_labels, str) and tensor_labels.startswith("AutoTokenizer:")):
263
+ self.tensor_labels = TensorLabels(self, labels=tensor_labels, labeling_rule=tensor_labeling_rule)
264
+ else:
265
+ self.set_tensor_labels_from_auto_tokenizer(tensor_labels.split(":")[1])
266
+
267
+ elif self.is_img():
268
+
269
+ # Ensuring other type-related tools are not set
270
+ assert tensor_shape is None and tensor_labels is None and tensor_dtype is None, \
271
+ f"Tensor-related arguments must be None when using a DataProps of type {data_type}"
272
+ assert (self.stream_to_proc_transforms is None or (not isinstance(self.stream_to_proc_transforms, str)
273
+ and not isinstance(self.stream_to_proc_transforms,
274
+ PreTrainedTokenizerBase))), \
275
+ "Non-image-related transforms were selected"
276
+ assert (self.proc_to_stream_transforms is None or (not isinstance(self.proc_to_stream_transforms, str)
277
+ and not isinstance(self.proc_to_stream_transforms,
278
+ PreTrainedTokenizerBase)
279
+ and not isinstance(self.proc_to_stream_transforms,
280
+ list))), \
281
+ "Non-image-related transforms were selected"
282
+
283
+ elif self.is_text():
284
+
285
+ # Ensuring other type-related tools are not set
286
+ assert tensor_shape is None and tensor_labels is None and tensor_dtype is None, \
287
+ f"Tensor/image-related arguments must be None when using a DataProps of type {data_type}"
288
+
289
+ # Setting text to tensor transform (tokenizer in encode mode) (if given)
290
+ if self.stream_to_proc_transforms is not None:
291
+ for j, _tttt in enumerate(self.stream_to_proc_transforms):
292
+ assert ((isinstance(_tttt, str) and _tttt.startswith("AutoTokenizer:")) or
293
+ isinstance(_tttt, PreTrainedTokenizerBase) or
294
+ isinstance(_tttt, dict) or
295
+ callable(_tttt)), \
296
+ ("Invalid text tokenizer: expected object of type PreTrainedTokenizerBase or a "
297
+ "string starting with 'AutoTokenizer:' or a callable object or a dictionary "
298
+ "(vocabulary str->int)")
299
+ if isinstance(_tttt, str) and _tttt.startswith("AutoTokenizer:"):
300
+ self.stream_to_proc_transforms[j] = AutoTokenizer.from_pretrained(_tttt.split(":")[1])
301
+
302
+ # Setting tensor to text transform (tokenizer in decode mode OR a given vocabulary int->str) (if given)
303
+ if self.proc_to_stream_transforms is not None:
304
+ assert ((isinstance(self.proc_to_stream_transforms, str) and
305
+ self.proc_to_stream_transforms.startswith("AutoTokenizer:")) or
306
+ isinstance(self.proc_to_stream_transforms, PreTrainedTokenizerBase) or
307
+ isinstance(self.proc_to_stream_transforms, list) or
308
+ callable(self.proc_to_stream_transforms)), \
309
+ ("Invalid text tokenizer: expected object of type PreTrainedTokenizerBase or a "
310
+ "string starting with 'AutoTokenizer:' or a callable object or a dictionary "
311
+ "(vocabulary int->str)")
312
+ if (isinstance(self.proc_to_stream_transforms, str) and
313
+ self.proc_to_stream_transforms.startswith("AutoTokenizer:")):
314
+ self.proc_to_stream_transforms = (
315
+ AutoTokenizer.from_pretrained(self.proc_to_stream_transforms.split(":")[1]))
316
+
317
+ # Checking name and group
318
+ assert "~" not in name, "Invalid chars in stream name"
319
+ assert "~" not in group, "Invalid chars in group name"
320
+
321
+ # Initialize properties
322
+ self.name = name
323
+ self.group = group
324
+ self.delta = delta
325
+ self.pubsub = pubsub
326
+ self.public = public
327
+
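As documented above, the constructor behaves differently for tensor, image and text streams. A sketch of two common configurations; the label list, the "geq0.5" rule and the "AutoTokenizer:gpt2" model ID are illustrative (the tokenizer variant downloads the model on first use):

import torch
from unaiverse.dataprops import DataProps

# A 2-d float tensor stream whose components carry class labels
scores = DataProps(name="animal", data_type="tensor",
                   tensor_shape=(1, 3), tensor_dtype=torch.float32,
                   tensor_labels=["cat", "dog", "bird"],
                   tensor_labeling_rule="geq0.5")

# A text stream tokenized with a Hugging Face tokenizer in both directions
chat = DataProps(name="chat", data_type="text",
                 stream_to_proc_transforms="AutoTokenizer:gpt2",
                 proc_to_stream_transforms="AutoTokenizer:gpt2")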
328
+ def to_dict(self):
329
+ """Serializes the `DataProps` object into a dictionary, making it suitable for transmission or storage.
330
+ It converts complex types like `torch.dtype` and `TensorLabels` into simple, serializable formats.
331
+
332
+ Returns:
333
+ A dictionary representation of the object's properties.
334
+ """
335
+ return {
336
+ 'name': self.name,
337
+ 'group': self.group,
338
+ 'data_type': self.data_type,
339
+ 'data_desc': self.data_desc,
340
+ 'tensor_shape': self.tensor_shape,
341
+ 'tensor_dtype': str(self.tensor_dtype) if self.tensor_dtype is not None else None,
342
+ 'tensor_labels': self.tensor_labels.to_dict() if self.tensor_labels is not None else None,
343
+ 'delta': self.delta,
344
+ 'pubsub': self.pubsub,
345
+ 'public': self.public
346
+ }
347
+
348
+ @staticmethod
349
+ def from_dict(d_props):
350
+ """A static method that deserializes a dictionary into a `DataProps` object.
351
+
352
+ Args:
353
+ d_props: The dictionary containing the object's properties.
354
+
355
+ Returns:
356
+ A new `DataProps` object.
357
+ """
358
+ d_labels = d_props['tensor_labels']
359
+ return DataProps(name=d_props['name'],
360
+ group=d_props['group'],
361
+ data_type=d_props['data_type'],
362
+ data_desc=d_props['data_desc'],
363
+ tensor_shape=d_props['tensor_shape'],
364
+ tensor_dtype=d_props['tensor_dtype'],
365
+ tensor_labels=d_labels['labels'] if d_labels is not None else None,
366
+ tensor_labeling_rule=d_labels['labeling_rule'] if d_labels is not None else "max",
367
+ delta=d_props['delta'],
368
+ pubsub=d_props['pubsub'],
369
+ public=d_props['public'])
370
+
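to_dict() and from_dict() serialize everything except the transform callables, so a round trip preserves name, shape, dtype and labels but not tokenizers. A small sketch with illustrative values:

import torch
from unaiverse.dataprops import DataProps

original = DataProps(name="animal", data_type="tensor",
                     tensor_shape=(1, 3), tensor_dtype=torch.float32,
                     tensor_labels=["cat", "dog", "bird"])

restored = DataProps.from_dict(original.to_dict())
print(restored.get_tensor_labels())   # ['cat', 'dog', 'bird']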
371
+ def clone(self):
372
+ """Creates and returns a deep copy of the `DataProps` instance.
373
+ It preserves the original transformation objects rather than re-evaluating them.
374
+
375
+ Returns:
376
+ A new `DataProps` object that is a clone of the original.
377
+ """
378
+ return DataProps(name=self.name,
379
+ group=self.group,
380
+ data_type=self.data_type,
381
+ data_desc=self.data_desc,
382
+ tensor_shape=self.tensor_shape,
383
+ tensor_dtype=self.tensor_dtype,
384
+ tensor_labels=self.tensor_labels.labels if self.tensor_labels is not None else None,
385
+ tensor_labeling_rule=self.tensor_labels.original_labeling_rule
386
+ if self.tensor_labels is not None else "max",
387
+ stream_to_proc_transforms=self.__original_stream_to_proc_transforms,
388
+ proc_to_stream_transforms=self.__original_proc_to_stream_transforms,
389
+ delta=self.delta,
390
+ pubsub=self.pubsub,
391
+ public=self.public)
392
+
393
+ def get_name(self):
394
+ """Retrieves the name of the data stream.
395
+
396
+ Returns:
397
+ A string representing the stream's name.
398
+ """
399
+
400
+ return self.name
401
+
402
+ def get_group(self):
403
+ """Retrieves the name of the group to which the data stream belongs.
404
+
405
+ Returns:
406
+ A string representing the group name.
407
+ """
408
+
409
+ return self.group
410
+
411
+ def get_description(self):
412
+ """Retrieves the description of the data.
413
+
414
+ Returns:
415
+ A string with the data description.
416
+ """
417
+ return self.data_desc
418
+
419
+ def get_tensor_labels(self):
420
+ """Retrieves the list of tensor labels, if they exist.
421
+
422
+ Returns:
423
+ A list of strings or None.
424
+ """
425
+ return self.tensor_labels.labels if self.tensor_labels is not None else None
426
+
427
+ def set_name(self, name: str):
428
+ """Sets a new name for the stream, with a check for invalid characters.
429
+
430
+ Args:
431
+ name: The new name as a string.
432
+ """
433
+ assert "~" not in name, "Invalid chars in stream name"
434
+ self.name = name
435
+
436
+ def set_group(self, group: str):
437
+ """Sets a new group name for the stream, with a check for invalid characters.
438
+
439
+ Args:
440
+ group: The new group name as a string.
441
+ """
442
+ assert "~" not in group, "Invalid chars in group name"
443
+ self.group = group
444
+
445
+ def set_description(self, desc: str):
446
+ """Sets a new description for the data.
447
+
448
+ Args:
449
+ desc: The new description as a string.
450
+ """
451
+ self.data_desc = desc
452
+
453
+ def set_public(self, public: bool):
454
+ """Sets whether the stream is public or not.
455
+
456
+ Args:
457
+ public: A boolean value.
458
+ """
459
+ self.public = public
460
+
461
+ def set_pubsub(self, pubsub: bool):
462
+ """Sets whether the stream uses Pub/Sub or direct messaging.
463
+
464
+ Args:
465
+ pubsub: A boolean value.
466
+ """
467
+ self.pubsub = pubsub
468
+
469
+ def is_tensor(self):
470
+ """Checks if the data type is 'tensor'.
471
+
472
+ Returns:
473
+ True if the type is 'tensor', False otherwise.
474
+ """
475
+ return self.data_type == "tensor"
476
+
477
+ def is_img(self):
478
+ """Checks if the data type is 'img'.
479
+
480
+ Returns:
481
+ True if the type is 'img', False otherwise.
482
+ """
483
+ return self.data_type == "img"
484
+
485
+ def is_text(self):
486
+ """Checks if the data type is 'text'.
487
+
488
+ Returns:
489
+ True if the type is 'text', False otherwise.
490
+ """
491
+ return self.data_type == "text"
492
+
493
+ def is_tensor_long(self):
494
+ """Checks if the tensor's data type is `torch.long`.
495
+
496
+ Returns:
497
+ True if the dtype is `torch.long`, False otherwise.
498
+ """
499
+ return self.tensor_dtype == torch.long if self.tensor_dtype is not None else False
500
+
501
+ def is_tensor_float(self):
502
+ """Checks if the tensor's data type is a float type (e.g., `torch.float32`).
503
+
504
+ Returns:
505
+ True if the dtype is a float type, False otherwise.
506
+ """
507
+ return str(self.tensor_dtype).startswith("torch.float") if self.tensor_dtype is not None else False
508
+
509
+ def is_tensor_img(self):
510
+ """Checks if the tensor's shape corresponds to a typical image format (4D, with 1 or 3 channels).
511
+
512
+ Returns:
513
+ True if the shape matches, False otherwise.
514
+ """
515
+ return len(self.tensor_shape) == 4 and (self.tensor_shape[1] == 1 or self.tensor_shape[1] == 3) \
516
+ if self.tensor_shape is not None else False
517
+
518
+ def is_tensor_token_ids(self):
519
+ """Checks if the tensor represents token IDs, which is indicated by `torch.long` data type and a 2D shape
520
+ suitable for sequences.
521
+
522
+ Returns:
523
+ True if it matches, False otherwise.
524
+ """
525
+ return (self.tensor_dtype == torch.long and
526
+ len(self.tensor_shape) == 2 and (self.tensor_shape[1] is None or self.tensor_shape[1] >= 1)) \
527
+ if self.tensor_shape is not None else False
528
+
529
+ def is_tensor_target_id(self):
530
+ """Checks if the tensor represents a single target ID, indicated by `torch.long` data type and a 1D shape.
531
+
532
+ Returns:
533
+ True if it matches, False otherwise.
534
+ """
535
+ return (self.tensor_dtype == torch.long and
536
+ len(self.tensor_shape) == 1) \
537
+ if self.tensor_shape is not None else False
538
+
539
+ def is_all(self):
540
+ """Checks if the data type is 'all', which is a wildcard type.
541
+
542
+ Returns:
543
+ True if the type is 'all', False otherwise.
544
+ """
545
+ return self.data_type == "all"
546
+
547
+ def net_hash(self, prefix: str):
548
+ """Generates a unique network hash for the stream using a provided prefix, Pub/Sub status, and name/group.
549
+
550
+ Args:
551
+ prefix: The prefix, typically the peer ID.
552
+
553
+ Returns:
554
+ A string representing the network hash.
555
+ """
556
+ return DataProps.build_net_hash(prefix, self.pubsub, self.name_or_group())
557
+
558
+ @staticmethod
559
+ def peer_id_from_net_hash(net_hash):
560
+ """A static method to extract the peer ID from a network hash.
561
+
562
+ Args:
563
+ net_hash: The network hash string.
564
+
565
+ Returns:
566
+ A string representing the peer ID.
567
+ """
568
+ return net_hash.split("::")[0]
569
+
570
+ @staticmethod
571
+ def name_or_group_from_net_hash(net_hash):
572
+ """A static method to extract the name or group from a network hash.
573
+
574
+ Args:
575
+ net_hash: The network hash string.
576
+
577
+ Returns:
578
+ A string representing the name or group.
579
+ """
580
+ return net_hash.split("::ps:")[1] if DataProps.is_pubsub_from_net_hash(net_hash) else net_hash.split("::dm:")[1]
581
+
582
+ @staticmethod
583
+ def is_pubsub_from_net_hash(net_hash):
584
+ """A static method to check if a network hash belongs to a Pub/Sub stream.
585
+
586
+ Args:
587
+ net_hash: The network hash string.
588
+
589
+ Returns:
590
+ True if the hash is for a Pub/Sub stream, False otherwise.
591
+ """
592
+ return "::ps:" in net_hash
593
+
594
+ def name_or_group(self):
595
+ """Retrieves the group name if it's set, otherwise defaults to the stream name.
596
+
597
+ Returns:
598
+ A string representing the name or group.
599
+ """
600
+ group = self.get_group()
601
+ return group if group != 'none' else self.get_name()
602
+
603
+ @staticmethod
604
+ def build_net_hash(prefix: str, pubsub: bool, name_or_group: str):
605
+ """A static method to construct a complete network hash from a prefix, Pub/Sub status, and name/group.
606
+
607
+ Args:
608
+ prefix: The peer ID prefix.
609
+ pubsub: The Pub/Sub status.
610
+ name_or_group: The name or group of the stream.
611
+
612
+ Returns:
613
+ The constructed network hash string.
614
+ """
615
+ if pubsub:
616
+ return f"{prefix}::ps:{name_or_group}"
617
+ else:
618
+ return f"{prefix}::dm:{name_or_group}"
619
+
620
+ @staticmethod
621
+ def normalize_net_hash(not_normalized_net_hash: str):
622
+ """A static method that cleans up or normalizes a network hash string to a canonical format, particularly
623
+ for direct messages.
624
+
625
+ Args:
626
+ not_normalized_net_hash: The network hash to normalize.
627
+
628
+ Returns:
629
+ The normalized network hash string.
630
+ """
631
+ if not DataProps.is_pubsub_from_net_hash(not_normalized_net_hash):
632
+ if "~" in not_normalized_net_hash:
633
+ return not_normalized_net_hash.split("::dm:")[0] + "::dm:" + not_normalized_net_hash.split("~")[1]
634
+ else:
635
+ parts = not_normalized_net_hash.split("::dm:")
636
+ return parts[0] + "::dm:" + parts[1].split("-")[1]
637
+ else:
638
+ return not_normalized_net_hash
639
+
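The static helpers above define the "<peer_id>::ps:<group>" (Pub/Sub) and "<peer_id>::dm:<name>" (direct message) hash convention. A sketch; the peer ID is a made-up placeholder:

from unaiverse.dataprops import DataProps

h = DataProps.build_net_hash("QmPeer123", pubsub=True, name_or_group="chat")
print(h)                                          # QmPeer123::ps:chat
print(DataProps.peer_id_from_net_hash(h))         # QmPeer123
print(DataProps.is_pubsub_from_net_hash(h))       # True
print(DataProps.name_or_group_from_net_hash(h))   # chat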
640
+ def is_pubsub(self):
641
+ """Checks if the stream is set to use Pub/Sub.
642
+
643
+ Returns:
644
+ True if it's a Pub/Sub stream, False otherwise.
645
+ """
646
+ return self.pubsub
647
+
648
+ def is_public(self):
649
+ """Checks if the stream is set to be public.
650
+
651
+ Returns:
652
+ True if it's a public stream, False otherwise.
653
+ """
654
+ return self.public
655
+
656
+ def set_tensor_labels_from_auto_tokenizer(self, model_id):
657
+ """Initializes and sets the tensor labels by fetching the vocabulary from a Hugging Face `AutoTokenizer`
658
+ model ID.
659
+
660
+ Args:
661
+ model_id: The ID of the tokenizer model.
662
+ """
663
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
664
+ vocab_size = len(tokenizer.vocab)
665
+ reverse_vocab_list: list[str | None] = [None] * vocab_size
666
+ for i in range(vocab_size):
667
+ reverse_vocab_list[i] = tokenizer.convert_ids_to_tokens(i)
668
+ self.set_tensor_labels(reverse_vocab_list)
669
+
670
+ def set_tensor_labels(self, labels: list[str] | None, labeling_rule: str = "max"):
671
+ """Sets the labels for the data.
672
+
673
+ Args:
674
+ labels (list[str] or None): List of labels to associate with the data.
675
+ labeling_rule (str): The labeling rule for the labels.
676
+
677
+ Returns:
678
+ None
679
+ """
680
+ self.tensor_labels = TensorLabels(self, labels=labels, labeling_rule=labeling_rule)
681
+
682
+ def adapt_tensor_to_tensor_labels(self, data: torch.Tensor) -> torch.Tensor:
683
+ """Interleaves data in function of its corresponding labels and the current super-set labels.
684
+
685
+ Args:
686
+ data (torch.Tensor): The data tensor to interleave.
687
+
688
+ Returns:
689
+ torch.Tensor: The interleaved data tensor.
690
+ """
691
+ if self.is_tensor():
692
+ num_labels = len(self.tensor_labels) if self.tensor_labels is not None else 0
693
+ if num_labels > 0 and data.shape[1] < num_labels and self.tensor_labels.indices is not None:
694
+ data_padded = torch.zeros((data.shape[0], num_labels), device=data.device, dtype=data.dtype)
695
+ data_padded[:, self.tensor_labels.indices] = data
696
+ return data_padded
697
+ else:
698
+ return data # Do nothing
699
+ else:
700
+ return data # Do nothing
701
+
702
+ def clear_label_adaptation(self, data: torch.Tensor):
703
+ """Removes the padding and returns the original data from an adapted tensor.
704
+
705
+ Args:
706
+ data: The adapted tensor.
707
+
708
+ Returns:
709
+ The original, un-padded tensor.
710
+ """
711
+ return data[:, self.tensor_labels.indices] if self.tensor_labels.indices is not None else data
712
+
713
+ def is_flat_tensor_with_labels(self):
714
+ """Checks if the tensor is a 2D array and has labels, which is a common structure for general feature data.
715
+
716
+ Returns:
717
+ True if it is, False otherwise.
718
+ """
719
+ return self.is_tensor() and len(self.tensor_shape) == 2 and self.has_tensor_labels()
720
+
721
+ def has_tensor_labels(self):
722
+ """Checks if any tensor labels are associated with the stream.
723
+
724
+ Returns:
725
+ True if labels exist, False otherwise.
726
+ """
727
+ return self.tensor_labels is not None and len(self.tensor_labels) > 0
728
+
729
+ def to_text(self, data: torch.Tensor | str):
730
+ """Converts the tensor data into a text-based representation exploiting the given labels and the labeling rule.
731
+
732
+ Args:
733
+ data (torch.Tensor or str): The data tensor to convert into text (if a string, then pass-through only).
734
+
735
+ Returns:
736
+ str or None: The corresponding text representation of the data.
737
+
738
+ Raises:
739
+ ValueError: If the data type is not supported for conversion.
740
+ """
741
+ if isinstance(data, str):
742
+ return data
743
+ elif not isinstance(data, torch.Tensor):
744
+ return None
745
+ elif len(data.shape) > 2: # Can only print 1d data (recall that 1d data has 2 dimensions, due to batch size)
746
+ return None
747
+
748
+ if data.shape[0] != 1:
749
+ return None # "Code designed for a batch of only 1 element
750
+
751
+ if self.is_tensor():
752
+ if not self.has_tensor_labels():
753
+ return None
754
+
755
+ if self.is_tensor_token_ids():
756
+
757
+ # This is the case in which we assume to have a vector of token IDs
758
+ text = ""
759
+ for i in range(0, data.shape[1]):
760
+ if i > 0:
761
+ text += " "
762
+ text += self.tensor_labels[data[0][i].item()]
763
+ return text
764
+
765
+ elif self.is_tensor_float():
766
+
767
+ # This is the generic case of a 1d tensor
768
+ if self.tensor_labels.labeling_rule == "max":
769
+ j = torch.argmax(data, dim=1)
770
+ return self.tensor_labels[j.item()]
771
+ elif self.tensor_labels.labeling_rule == "geq":
772
+
773
+ # Warning: does not work for mini-batches
774
+ jj = torch.where(data >= self.tensor_labels.labeling_rule_thres)[1]
775
+ return ", ".join(self.tensor_labels[j] for j in jj.tolist())
776
+ else:
777
+ return None
778
+
779
+ elif self.is_text():
780
+ if self.proc_to_stream_transforms is None:
781
+ return None
782
+ if isinstance(self.proc_to_stream_transforms, PreTrainedTokenizerBase):
783
+ return self.proc_to_stream_transforms.decode(data[0])
784
+ elif isinstance(self.proc_to_stream_transforms, dict):
785
+ if data.dtype != torch.long:
786
+
787
+ # This is the case of probabilities
788
+ j = torch.argmax(data, dim=1) # Warning: does not work for mini-batches
789
+ return self.proc_to_stream_transforms[j.item()]
790
+ else:
791
+
792
+ # This is the case in which we assume to have a vector of token IDs
793
+ text = ""
794
+ for i in range(0, data.shape[1]):
795
+ if i > 0:
796
+ text += " "
797
+ text += self.proc_to_stream_transforms[data[0][i].item()]
798
+ return text
799
+ else:
800
+ return self.proc_to_stream_transforms(data)
801
+ else:
802
+ return None
803
+
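to_text() above turns a 1-by-N float tensor into label text, using argmax for the 'max' rule or thresholding for a 'geqX' rule. A sketch with an illustrative label set:

import torch
from unaiverse.dataprops import DataProps

props = DataProps(name="animal", data_type="tensor",
                  tensor_shape=(1, 3), tensor_dtype=torch.float32,
                  tensor_labels=["cat", "dog", "bird"],
                  tensor_labeling_rule="geq0.5")

print(props.to_text(torch.tensor([[0.9, 0.1, 0.7]])))   # "cat, bird"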
804
+ def check_and_preprocess(self, data: str | Image.Image | torch.Tensor,
805
+ allow_class_ids: bool = False, targets: bool = False,
806
+ device: torch.device = torch.device("cpu")):
807
+ """Prepares incoming data for a processor by validating its type and applying necessary transformations.
808
+ It handles different data types, including tensors, text (strings), and images, raising `ValueError` if
809
+ the data type is unexpected or incompatible with the stream's properties. For text and images, it can apply a
810
+ pre-configured transformation (like a tokenizer or a standard image transform) to convert the data into a
811
+ tensor format suitable for processing. For tensors, it performs validation on shape and data type.
812
+
813
+ Args:
814
+ data: The data sample to check and preprocess.
815
+ allow_class_ids: A boolean to allow single-element long tensors, typically for class IDs.
816
+ targets: A boolean to indicate if the data is a target (used to select the correct transformation in a
817
+ dual-transform setup).
818
+ device: The PyTorch device (e.g., 'cpu' or 'cuda') to which the tensor should be moved.
819
+
820
+ Returns:
821
+ The preprocessed data, typically a tensor on the specified device.
822
+ """
823
+ if self.is_tensor():
824
+ if isinstance(data, torch.Tensor):
825
+
826
+ # Skipping all checks, it is enough to know it is a tensor
827
+ if allow_class_ids and data.dtype == torch.long and len(data.shape) == 1:
828
+ return data.to(device)
829
+
830
+ # Checking dtype
831
+ if self.tensor_dtype != data.dtype:
832
+ raise ValueError(f"Expected data of type {self.tensor_dtype}, got {data.dtype} ("
833
+ f"shape {data.shape})")
834
+
835
+ # Checking shape
836
+ if len(self.tensor_shape) != len(data.shape):
837
+ raise ValueError(f"Expected data with shape {self.tensor_shape}, got {data.shape}")
838
+ for i, s in enumerate(self.tensor_shape):
839
+ if s is not None:
840
+ if s != data.shape[i]:
841
+ raise ValueError(f"Expected data with shape {self.tensor_shape}, got {data.shape}")
842
+
843
+ # Checking labels
844
+ if self.has_tensor_labels():
845
+ if data.ndim != 2:
846
+ raise ValueError("Only 2d tensors are expected for "
847
+ "labeled attributes (1st dimension is batch dim)")
848
+ if not (self.is_tensor_token_ids() or data.shape[1] == self.tensor_labels.num_labels):
849
+ raise ValueError(f"Expected data with {self.tensor_labels.num_labels} "
850
+ f"components (ignoring the 1st dimension), "
851
+ f"got {data[0].numel()}")
852
+
853
+ return data.to(device)
854
+ else:
855
+ raise ValueError(f"Expecting tensor data, got {type(data)}")
856
+ elif self.is_text():
857
+ if isinstance(data, str):
858
+ if self.stream_to_proc_transforms is not None:
859
+ text_to_tensor_transform = self.stream_to_proc_transforms[int(targets)]
860
+ if text_to_tensor_transform is not None:
861
+ if isinstance(text_to_tensor_transform, PreTrainedTokenizerBase):
862
+ return text_to_tensor_transform(data, return_tensors='pt')['input_ids'].to(device) # Tok
863
+ elif isinstance(text_to_tensor_transform, dict):
864
+ return torch.tensor(text_to_tensor_transform[data]
865
+ if data in text_to_tensor_transform else len(text_to_tensor_transform),
866
+ dtype=torch.long, device=device).view(1, -1) # Warning batch size 1
867
+ else:
868
+ return text_to_tensor_transform(data).to(device) # Custom callable function
869
+ else:
870
+ return data
871
+ else:
872
+ return data
873
+ else:
874
+ raise ValueError(f"Expecting text (string) data, got {type(data)}")
875
+ elif self.is_img():
876
+ if isinstance(data, Image.Image):
877
+ if self.stream_to_proc_transforms is not None:
878
+ img_to_tensor_transform = self.stream_to_proc_transforms[int(targets)]
879
+ if img_to_tensor_transform is not None:
880
+ return img_to_tensor_transform(data).to(device)
881
+ else:
882
+ return data
883
+ else:
884
+ return data
885
+ else:
886
+ raise ValueError(f"Expecting image (PIL.Image) data, got {type(data)}")
887
+ elif self.is_all():
888
+ return data
889
+ else:
890
+ raise ValueError(f"Unexpected data type, {self.data_type}")
891
+
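check_and_preprocess() routes a raw stream sample through the configured stream-to-processor transform; for a text stream with a str->int vocabulary it returns a long tensor of shape (1, -1). A sketch with an illustrative three-word vocabulary:

import torch
from unaiverse.dataprops import DataProps

props = DataProps(name="cmd", data_type="text",
                  stream_to_proc_transforms={"stop": 0, "go": 1, "turn": 2})

t = props.check_and_preprocess("go", device=torch.device("cpu"))
print(t)   # tensor([[1]]); unknown words map to len(vocab)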
892
+ def check_and_postprocess(self, data: str | Image.Image | torch.Tensor):
893
+ """Takes a processor's output and validates it before converting it back into a stream-compatible format.
894
+ It handles `torch.Tensor` data, applying a `proc_to_stream_transform` (if one exists) to convert the tensor
895
+ into an appropriate format for the stream, such as a string for text or a PIL `Image` for images. It performs
896
+ a final check on the data's format (shape, dtype, etc.) to ensure consistency with the stream's properties.
897
+
898
+ Args:
899
+ data: The output from the processor, typically a tensor.
900
+
901
+ Returns:
902
+ The post-processed data, in a stream-compatible format (e.g., a string, image, or CPU tensor).
903
+ """
904
+ if self.is_tensor():
905
+ if isinstance(data, torch.Tensor):
906
+ if self.proc_to_stream_transforms is not None:
907
+ data = self.proc_to_stream_transforms(data)
908
+ data = data.cpu()
909
+
910
+ # Checking dtype
911
+ if self.tensor_dtype != data.dtype:
912
+ raise ValueError(f"Expected data of type {self.tensor_dtype}, got {data.dtype}")
913
+
914
+ # Checking shape
915
+ if len(self.tensor_shape) != len(data.shape):
916
+ raise ValueError(f"Expected data with shape {self.tensor_shape}, got {data.shape}")
917
+ for i, s in enumerate(self.tensor_shape):
918
+ if s is not None:
919
+ if s != data.shape[i]:
920
+ raise ValueError(f"Expected data with shape {self.tensor_shape}, got {data.shape}")
921
+
922
+ # Checking labels
923
+ if self.has_tensor_labels():
924
+ if data.ndim != 2:
925
+ raise ValueError("Only 2d tensors are expected for "
926
+ "labeled attributes (1st dimension is batch dim)")
927
+ if not (self.is_tensor_token_ids() or data.shape[1] == self.tensor_labels.num_labels):
928
+ raise ValueError(f"Expected data with {self.tensor_labels.num_labels} "
929
+ f"components (ignoring the 1st dimension), "
930
+ f"got {data[0].numel()}")
931
+
932
+ return data
933
+ else:
934
+ raise ValueError(f"Expecting tensor data, got {type(data)}")
935
+ elif self.is_text():
936
+ if isinstance(data, str):
937
+ return data
938
+ elif isinstance(data, torch.Tensor):
939
+ data = data.cpu()
940
+ if self.proc_to_stream_transforms is not None:
941
+ assert data.shape[0] == 1, f"Code designed for a batch of only 1 element, got {data.shape[0]}"
942
+ if isinstance(self.proc_to_stream_transforms, PreTrainedTokenizerBase):
943
+ return self.proc_to_stream_transforms.decode(data[0]) # Tokenizer
944
+ elif isinstance(self.proc_to_stream_transforms, list):
945
+ if data.dtype != torch.long:
946
+
947
+ # This is the case of probabilities
948
+ j = torch.argmax(data, dim=1) # Warning: does not work for mini-batches
949
+ return self.proc_to_stream_transforms[j.item()]
950
+ else:
951
+
952
+ # This is the case in which we assume to have a vector of token IDs
953
+ text = ""
954
+ for i in range(0, data.shape[1]):
955
+ if i > 0:
956
+ text += " "
957
+ text += self.proc_to_stream_transforms[data[0][i].item()]
958
+ return text
959
+ else:
960
+ return self.proc_to_stream_transforms(data) # Custom callable function
961
+ else:
962
+ raise ValueError(f"Cannot decode torch.Tensor to text, since text_to_tensor_inv_transform is None")
963
+ else:
964
+ raise ValueError(f"Expecting text (string) or tensor data, got {type(data)}")
965
+ elif self.is_img():
966
+ if isinstance(data, Image.Image):
967
+ return data
968
+ elif isinstance(data, torch.Tensor):
969
+ data = data.cpu()
970
+ if self.proc_to_stream_transforms is not None:
971
+ return self.proc_to_stream_transforms(data)
972
+ else:
973
+ raise ValueError(f"Cannot convert a tensor to PIL.Image, since img_to_tensor_inv_transform is None")
974
+ else:
975
+ raise ValueError(f"Expecting image (PIL.Image) data or torch.Tensor, got {type(data)}")
976
+ elif self.is_all():
977
+ return data
978
+ else:
979
+ raise ValueError(f"Unexpected data type, {self.data_type}")
980
+
981
+ def is_compatible(self, props_to_compare: 'DataProps') -> bool:
982
+ """Checks if the current DataProps instance is compatible with another DataProps instance.
983
+ Checks include data type, shape, and labels.
984
+
985
+ Args:
986
+ props_to_compare (DataProps): The DataProps instance to check compatibility with.
987
+
988
+ Returns:
989
+ bool: True if compatible, False otherwise.
990
+ """
991
+
992
+ # Checking data type
993
+ if self.data_type != props_to_compare.data_type and self.data_type != "all":
994
+ return False
995
+
996
+ # In the case of tensors...
997
+ if self.is_tensor():
998
+
999
+ # Checking shape
1000
+ if len(self.tensor_shape) == len(props_to_compare.tensor_shape):
1001
+ for s, p in zip(self.tensor_shape, props_to_compare.tensor_shape):
1002
+ if s is not None and p is not None and s != p:
1003
+ return False
1004
+ else:
1005
+ return False
1006
+
1007
+ # Checking labels (if possible)
1008
+ if (not self.has_tensor_labels()) or (not props_to_compare.has_tensor_labels()):
1009
+ return True
1010
+ else:
1011
+ return self.tensor_labels == props_to_compare.tensor_labels
1012
+ else:
1013
+ return True
1014
+
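is_compatible() treats None dimensions as wildcards and the 'all' data type as a wildcard type. A sketch with illustrative shapes:

import torch
from unaiverse.dataprops import DataProps

fixed = DataProps(name="a", data_type="tensor",
                  tensor_shape=(1, 8), tensor_dtype=torch.float32)
batched = DataProps(name="b", data_type="tensor",
                    tensor_shape=(None, 8), tensor_dtype=torch.float32)

print(fixed.is_compatible(batched))   # True: a None dimension matches any size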
1015
+ def __str__(self):
1016
+ """Provides a string representation of the DataProps instance.
1017
+
1018
+ Returns:
1019
+ str: The string representation of the instance.
1020
+ """
1021
+ return f"[DataProps]\n{self.to_dict()}"
1022
+
1023
+
1024
+ class TensorLabels:
1025
+ """
1026
+ A class to manage labels associated with data and perform operations on them.
1027
+
1028
+ Attributes:
1029
+ VALID_LABELING_RULES (tuple): Tuple of valid labeling rules ('max', 'geq').
1030
+ """
1031
+
1032
+ VALID_LABELING_RULES = ('max', 'geq')
1033
+
1034
+ def __init__(self, data_props: DataProps, labels: list[str] | None, labeling_rule: str = "max"):
1035
+ """Initializes the TensorLabels instance.
1036
+
1037
+ Args:
1038
+ data_props (DataProps): The DataProps instance that owns these labels.
1039
+ labels (list[str] or None): List of labels.
1040
+ labeling_rule (str): The rule for labeling (either 'max' or 'geqX', where X is a number).
1041
+
1042
+ Returns:
1043
+ None
1044
+
1045
+ Raises:
1046
+ AssertionError: If the labels or labeling_rule are invalid.
1047
+ """
1048
+ assert data_props.is_tensor(), "Tensor labels can only be attached to tensor data properties"
1049
+ num_labels = len(labels) if labels is not None else 0
1050
+ assert num_labels == 0 or (data_props.is_tensor() and len(data_props.tensor_shape) == 2), \
1051
+ "Data attribute labels can only be specified for 2d arrays (batch size + data features)"
1052
+ assert len(labeling_rule) >= 3 and labeling_rule[0:3] in TensorLabels.VALID_LABELING_RULES, \
1053
+ "Invalid labeling rule"
1054
+ try:
1055
+ original_labeling_rule = labeling_rule
1056
+ if len(labeling_rule) > 3:
1057
+ labeling_rule_thres = float(labeling_rule[3:])
1058
+ labeling_rule = labeling_rule[0:3]
1059
+ else:
1060
+ labeling_rule_thres = None
1061
+ except ValueError:
1062
+ assert False, "Invalid labeling rule"
1063
+
1064
+ # Basic attributes
1065
+ self.data_props = data_props
1066
+ self.labels = labels
1067
+ self.labeling_rule = labeling_rule
1068
+ self.labeling_rule_thres = labeling_rule_thres
1069
+ self.original_labeling_rule = original_labeling_rule
1070
+
1071
+ # These are mostly operational stuff, similar to private info (but it could be useful to expose them)
1072
+ self.num_labels = num_labels
1073
+ self.indices = None
1074
+
1075
+ def to_dict(self):
1076
+ """Serializes the `TensorLabels` instance into a dictionary, which includes the list of labels and the original
1077
+ labeling rule.
1078
+
1079
+ Returns:
1080
+ A dictionary containing the labels and the original labeling rule.
1081
+ """
1082
+ return {
1083
+ 'labels': self.labels,
1084
+ 'labeling_rule': self.original_labeling_rule
1085
+ }
1086
+
1087
+ def clear_indices(self):
1088
+ """Resets the internal `indices` attribute to `None`. This effectively clears any previous label adaptation
1089
+ that was performed and allows the object to revert to its original, non-interleaved state.
1090
+ """
1091
+ self.indices = None
1092
+
1093
+ def __getitem__(self, idx):
1094
+ """Retrieves the label at the specified index.
1095
+
1096
+ Args:
1097
+ idx (int): The index of the label to retrieve.
1098
+
1099
+ Returns:
1100
+ str: The label at the specified index.
1101
+
1102
+ Raises:
1103
+ ValueError: If the index is out of bounds or labels are not defined.
1104
+ """
1105
+ if self.labels is None:
1106
+ raise ValueError(f"Cannot retrieve any labels, since they are not there at all (None)")
1107
+ if idx < 0 or idx >= self.num_labels:
1108
+ raise ValueError(f"Invalid index {idx} for attribute labels of size {self.num_labels}")
1109
+ return self.labels[idx]
1110
+
1111
+ def __len__(self):
1112
+ """Returns the number of labels.
1113
+
1114
+ Returns:
1115
+ int: The number of labels.
1116
+ """
1117
+ return self.num_labels
1118
+
1119
+ def __iter__(self):
1120
+ """Iterates over the labels.
1121
+
1122
+ Returns:
1123
+ iterator: An iterator over the labels.
1124
+ """
1125
+ return iter(self.labels) if self.labels is not None else iter([])
1126
+
1127
+ def __str__(self):
1128
+ """Provides a string representation of the DataLabels instance.
1129
+
1130
+ Returns:
1131
+ str: The string representation of the instance.
1132
+ """
1133
+ return (f"[TensorLabels] "
1134
+ f"labels: {self.labels}, labeling_rule: {self.labeling_rule}, "
1135
+ f"indices_in_superset: {self.indices})")
1136
+
1137
+ def __eq__(self, other):
1138
+ """Defines how two `TensorLabels` instances are compared for equality using the `==` operator. Two instances
1139
+ are considered equal if they have the same number of labels and the labels themselves match in order.
1140
+
1141
+ Args:
1142
+ other: The other object to compare with.
1143
+
1144
+ Returns:
1145
+ True if the instances are equal, False otherwise.
1146
+ """
1147
+ if not isinstance(other, TensorLabels):
1148
+ return ValueError("Cannot compare a TensorLabels instance and something else")
1149
+
1150
+ if self.num_labels == other.num_labels:
1151
+ if self.num_labels > 0:
1152
+ for i, j in zip(self.labels, other.labels):
1153
+ if i != j:
1154
+ return False
1155
+ return True
1156
+ else:
1157
+ return True
1158
+ else:
1159
+ return False
1160
+
1161
+ def interleave_with(self, superset_labels: list[str]):
1162
+ """Interleaves the current labels with a super-set of labels, determining how to index them.
1163
+
1164
+ Args:
1165
+ superset_labels (list[str]): The super-set of labels to interleave with.
1166
+
1167
+ Returns:
1168
+ None
1169
+
1170
+ Raises:
1171
+ AssertionError: If the super-set of labels is not compatible.
1172
+ """
1173
+ assert superset_labels is not None and self.labels is not None, \
1174
+ f"Can only interleave non-empty sets of attribute labels"
1175
+ assert len(superset_labels) >= len(self), f"You must provide a super-set of attribute labels"
1176
+
1177
+ # Ensuring it is a super-set of the current labels and finding its position
1178
+ if self.indices is not None:
1179
+ labels = []
1180
+ indices_list = self.indices.tolist()
1181
+ for i in indices_list:
1182
+ labels.append(self.labels[i])
1183
+ else:
1184
+ labels = self.labels
1185
+
1186
+ indices = []
1187
+ for label in labels:
1188
+ assert label in superset_labels, \
1189
+ f"Cannot find attribute label {label} in (expected) super-set {superset_labels}"
1190
+ indices.append(superset_labels.index(label))
1191
+
1192
+ if len(indices) == len(superset_labels):
1193
+ same_labels_and_order = True
1194
+ for j, i in enumerate(indices):
1195
+ if j != i:
1196
+ same_labels_and_order = False
1197
+ break
1198
+ else:
1199
+ same_labels_and_order = False
1200
+
1201
+ if not same_labels_and_order:
1202
+ self.labels = superset_labels
1203
+ self.num_labels = len(self.labels)
1204
+ self.indices = torch.tensor(indices, dtype=torch.long)
1205
+
1206
+ # Altering shape
1207
+ self.data_props.tensor_shape = (self.data_props.tensor_shape[0], self.num_labels)
1208
+ else:
1209
+ self.indices = None
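interleave_with() re-indexes a stream's labels against a shared super-set, and adapt_tensor_to_tensor_labels() then zero-pads the data into the super-set layout. A sketch with illustrative label sets:

import torch
from unaiverse.dataprops import DataProps

props = DataProps(name="animal", data_type="tensor",
                  tensor_shape=(1, 2), tensor_dtype=torch.float32,
                  tensor_labels=["cat", "dog"])

# Align the two local labels with a larger shared label set
props.tensor_labels.interleave_with(["bird", "cat", "dog", "fish"])

padded = props.adapt_tensor_to_tensor_labels(torch.tensor([[0.3, 0.7]]))
print(padded)   # tensor([[0.0000, 0.3000, 0.7000, 0.0000]])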