unaiverse 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of unaiverse might be problematic. Click here for more details.

Files changed (45) hide show
  1. unaiverse/__init__.py +19 -0
  2. unaiverse/agent.py +2008 -0
  3. unaiverse/agent_basics.py +1844 -0
  4. unaiverse/clock.py +186 -0
  5. unaiverse/dataprops.py +1209 -0
  6. unaiverse/hsm.py +1880 -0
  7. unaiverse/modules/__init__.py +18 -0
  8. unaiverse/modules/cnu/__init__.py +17 -0
  9. unaiverse/modules/cnu/cnus.py +536 -0
  10. unaiverse/modules/cnu/layers.py +261 -0
  11. unaiverse/modules/cnu/psi.py +60 -0
  12. unaiverse/modules/hl/__init__.py +15 -0
  13. unaiverse/modules/hl/hl_utils.py +411 -0
  14. unaiverse/modules/networks.py +1509 -0
  15. unaiverse/modules/utils.py +680 -0
  16. unaiverse/networking/__init__.py +16 -0
  17. unaiverse/networking/node/__init__.py +18 -0
  18. unaiverse/networking/node/connpool.py +1265 -0
  19. unaiverse/networking/node/node.py +2203 -0
  20. unaiverse/networking/node/profile.py +446 -0
  21. unaiverse/networking/node/tokens.py +79 -0
  22. unaiverse/networking/p2p/__init__.py +259 -0
  23. unaiverse/networking/p2p/golibp2p.py +18 -0
  24. unaiverse/networking/p2p/golibp2p.pyi +135 -0
  25. unaiverse/networking/p2p/lib.go +2495 -0
  26. unaiverse/networking/p2p/lib_types.py +312 -0
  27. unaiverse/networking/p2p/message_pb2.py +63 -0
  28. unaiverse/networking/p2p/messages.py +265 -0
  29. unaiverse/networking/p2p/mylogger.py +77 -0
  30. unaiverse/networking/p2p/p2p.py +963 -0
  31. unaiverse/streamlib/__init__.py +15 -0
  32. unaiverse/streamlib/streamlib.py +210 -0
  33. unaiverse/streams.py +763 -0
  34. unaiverse/utils/__init__.py +16 -0
  35. unaiverse/utils/ask_lone_wolf.json +27 -0
  36. unaiverse/utils/lone_wolf.json +19 -0
  37. unaiverse/utils/misc.py +305 -0
  38. unaiverse/utils/sandbox.py +293 -0
  39. unaiverse/utils/server.py +435 -0
  40. unaiverse/world.py +175 -0
  41. unaiverse-0.1.0.dist-info/METADATA +363 -0
  42. unaiverse-0.1.0.dist-info/RECORD +45 -0
  43. unaiverse-0.1.0.dist-info/WHEEL +5 -0
  44. unaiverse-0.1.0.dist-info/licenses/LICENSE +43 -0
  45. unaiverse-0.1.0.dist-info/top_level.txt +1 -0
unaiverse/streams.py ADDED
@@ -0,0 +1,763 @@
1
+ """
2
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
3
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
4
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
5
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
6
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
7
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
8
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
9
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
10
+ A Collectionless AI Project (https://collectionless.ai)
11
+ Registration/Login: https://unaiverse.io
12
+ Code Repositories: https://github.com/collectionlessai/
13
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
14
+ """
15
+ import os
16
+ import csv
17
+ import math
18
+ import torch
19
+ import random
20
+ import pathlib
21
+ from PIL import Image
22
+ from .clock import Clock
23
+ from .dataprops import DataProps
24
+ from unaiverse.utils.misc import show_images_grid
25
+
26
+
27
class DataStream:
    """
    Base class for a generic data stream.

    A stream only keeps the most recently "set" sample, together with its timestamp, its tag and the (optional)
    UUIDs associated to it. Buffering or on-the-fly generation is left to subclasses.
    """

    def __init__(self,
                 props: DataProps,
                 clock: Clock = Clock()) -> None:
        """Build an enabled stream that holds no data yet.

        Args:
            props (DataProps): Properties of the stream.
            clock (Clock): Clock object for time management (usually provided from outside). Notice that the default
                instance is created only once, when the class is defined, so it is shared by every stream that
                relies on it.
        """
        self.props = props
        self.clock = clock
        self.enabled = True  # A stream can be turned off

        # Last sample that was set, and the information attached to it
        self.data = None
        self.data_tag = -1
        self.data_timestamp = self.clock.get_time()
        self.data_timestamp_when_got_by = {}  # Requester name -> timestamp of the sample it got last

        # UUIDs (actual and expected), and the flag that tells whether they can be cleared
        self.data_uuid = None
        self.data_uuid_expected = None
        self.data_uuid_clearable = False

    @staticmethod
    def create(stream: 'DataStream', name: str | None = None, group: str = 'none',
               public: bool = True, pubsub: bool = True):
        """Configure name, group and visibility of a given stream, also updating its placeholder label.

        Args:
            stream (DataStream): The stream object to modify.
            name (str | None): The name of the stream (if None, it is built as "<current name>@<group>").
            group (str): The name of the group to which the stream belongs ('none' means no group).
            public (bool): If the stream is going to be served in the public net or the private one.
            pubsub (bool): Value forwarded to the "set_pubsub" method of the stream properties.

        Returns:
            DataStream: The same stream object that was passed, after having modified its properties.
        """
        assert name is not None or group != 'none', "Must provide either name or group name."

        props = stream.props
        props.set_group(group)

        # If name is None, then a group was provided (see the assertion above)
        new_name = name if name is not None else props.get_name() + "@" + group
        props.set_name(new_name)
        props.set_public(public)
        props.set_pubsub(pubsub)

        # A single 'unk' label is replaced by the group name (or by the stream name, when there is no group)
        if props.is_flat_tensor_with_labels():
            labels = props.tensor_labels
            if len(labels) == 1 and labels[0] == 'unk':
                labels[0] = name if group == 'none' else group
        return stream

    def enable(self):
        """Turn the stream on, so that "get" can return data.

        Returns:
            None
        """
        self.enabled = True

    def disable(self):
        """Turn the stream off, so that "get" returns None.

        Returns:
            None
        """
        self.enabled = False

    def get_props(self) -> DataProps:
        """Return the properties of this stream."""
        return self.props

    def net_hash(self, prefix: str):
        """Return the network hash of the stream, as computed by its properties for the given prefix."""
        return self.get_props().net_hash(prefix)

    @staticmethod
    def peer_id_from_net_hash(net_hash):
        """Forward to DataProps.peer_id_from_net_hash."""
        return DataProps.peer_id_from_net_hash(net_hash)

    @staticmethod
    def name_or_group_from_net_hash(net_hash):
        """Forward to DataProps.name_or_group_from_net_hash."""
        return DataProps.name_or_group_from_net_hash(net_hash)

    @staticmethod
    def is_pubsub_from_net_hash(net_hash):
        """Forward to DataProps.is_pubsub_from_net_hash."""
        return DataProps.is_pubsub_from_net_hash(net_hash)

    def is_pubsub(self):
        """Return what the stream properties report in "is_pubsub"."""
        return self.get_props().is_pubsub()

    def is_public(self):
        """Return what the stream properties report in "is_public"."""
        return self.get_props().is_public()

    def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
        """Indexed access, not available in the base class (buffered or generative streams override it).

        Args:
            idx (int): Index of the data to retrieve.

        Raises:
            ValueError: Always, since this method should be overridden.
        """
        raise ValueError("Not implemented (expected to be only present in data streams that are buffered or "
                         "that can generated on the fly)")

    def __str__(self) -> str:
        """Describe the stream in a human-readable way.

        Returns:
            str: Text reporting the enabled flag and the stream properties.
        """
        return f"[DataStream] enabled: {self.enabled}\n\tprops: {self.props}"

    def __len__(self):
        """Get the length of the stream.

        Returns:
            float: Infinity (as this is a lifelong stream). Being it a float, the built-in "len()" cannot be used on
                the base class: call "__len__()" directly.
        """
        return math.inf

    def set(self, data: torch.Tensor | Image.Image | str, data_tag: int = -1) -> bool:
        """Store a new data sample, that becomes the one returned by "get()".

        Args:
            data (torch.Tensor | Image.Image | str): Data sample to set (None is stored as it is).
            data_tag (int): Custom data time tag >= 0 (Default: -1, meaning that the current clock cycle is used).

        Returns:
            bool: True if data was accepted based on time constraints, else False.
        """
        accepted = (self.props.delta <= 0. or
                    self.props.delta <= (self.clock.get_time() - self.data_timestamp))
        if not accepted:
            return False

        if data is None:
            self.data = None
        else:
            self.data = self.props.adapt_tensor_to_tensor_labels(data)
        self.data_timestamp = self.clock.get_cycle_time()
        self.data_tag = self.clock.get_cycle() if data_tag < 0 else data_tag
        return True

    def get(self, requested_by: str | None = None) -> torch.Tensor | None:
        """Get the most recent data sample from the stream (i.e., the last one that was "set").

        Args:
            requested_by (str | None): Name of who is asking for data. When provided, the same sample (same timestamp)
                is returned only once to the same requester.

        Returns:
            torch.Tensor | None: The current sample, or None if the stream is disabled or if the named requester
                already got the sample with the current timestamp.
        """
        if not self.enabled:
            return None
        if requested_by is None:
            return self.data

        already_got = self.data_timestamp_when_got_by
        if requested_by in already_got and not (already_got[requested_by] != self.data_timestamp):
            return None
        already_got[requested_by] = self.data_timestamp
        return self.data

    def get_timestamp(self) -> float:
        """Return the timestamp associated to the current sample."""
        return self.data_timestamp

    def get_tag(self) -> int:
        """Return the tag associated to the current sample."""
        return self.data_tag

    def get_uuid(self, expected: bool = False) -> str | None:
        """Return the UUID of the stream (or the expected one, if "expected" is True)."""
        if expected:
            return self.data_uuid_expected
        return self.data_uuid

    def set_uuid(self, ref_uuid: str | None, expected: bool = False):
        """Set the UUID of the stream (or the expected one, if "expected" is True)."""
        if expected:
            self.data_uuid_expected = ref_uuid
        else:
            self.data_uuid = ref_uuid

        # Whenever a UUID is "set", even if None, the "clear UUID" flag is unmarked, to be sure that nobody is going
        # to accidentally clear this information
        self.data_uuid_clearable = False

    def mark_uuid_as_clearable(self):
        """Flag the UUIDs as clearable by "clear_uuid_if_marked_as_clearable"."""
        self.data_uuid_clearable = True

    def clear_uuid_if_marked_as_clearable(self):
        """Reset both the UUIDs (and the flag), but only if they were previously marked as clearable."""
        if not self.data_uuid_clearable:
            return
        self.data_uuid = None
        self.data_uuid_expected = None
        self.data_uuid_clearable = False

    def set_props(self, data_stream: 'DataStream'):
        """Replace the data properties with the ones of another DataStream (same object, not a copy).

        Args:
            data_stream (DataStream): The source DataStream from which the properties are taken.

        Returns:
            None
        """
        self.props = data_stream.props
225
+
226
+
227
class BufferedDataStream(DataStream):
    """
    Data stream with buffer support to store historical data.
    """

    def __init__(self, props: DataProps, clock: Clock = Clock(), is_static: bool = False):
        """Initialize a BufferedDataStream.

        Args:
            props (DataProps): Properties of the stream.
            clock (Clock): Clock object for time management (usually provided from outside).
            is_static (bool): If True, the buffer stores only one item that is reused.
        """
        super().__init__(props=props, clock=clock)

        # We store the data samples as (data, tag) pairs, and we cache their text representation (for speed)
        self.data_buffer = []
        self.text_buffer = []

        self.is_static = is_static  # A static stream stores only one sample and always yields it

        # We need to remember the first cycle in which we started buffering and the last one we buffered
        self.first_cycle = -1
        self.last_cycle = -1
        self.last_get_cycle = -2  # Keep it to -2 (since -1 is the starting value for cycles)
        self.buffered_data_index = -1

        # Names of the requesters for which a restart was planned before their next "get"
        self.restart_before_next_get = set()

    def get(self, requested_by: str | None = None) -> torch.Tensor | None:
        """Get the current data sample, moving to the next buffered element when a new clock cycle is detected.

        Args:
            requested_by (str | None): Name of who is asking for data (see DataStream.get).

        Returns:
            torch.Tensor | None: Current buffered sample, as filtered by DataStream.get.
        """
        if requested_by is not None and requested_by in self.restart_before_next_get:
            self.restart_before_next_get.remove(requested_by)
            self.restart()

        cycle = self.clock.get_cycle() - self.first_cycle  # This ensures that first get clock = first sample

        # These lines might make you think "hey, call super().set(self[cycle]), it is the same!"
        # however, it is not like that, since "set" will also call "adapt_to_labels", that is not needed for
        # buffered streams
        if (self.last_get_cycle != cycle and
                (self.props.delta <= 0. or self.props.delta <= (self.clock.get_time() - self.data_timestamp))):
            self.last_get_cycle = cycle
            self.buffered_data_index += 1
            new_data, new_tag = self[self.buffered_data_index]
            self.data = new_data
            self.data_timestamp = self.clock.get_cycle_time()
            self.data_tag = new_tag
        return super().get(requested_by)

    def set(self, data: torch.Tensor, data_tag: int = -1):
        """Store a new data sample into the buffer.

        Clock cycles that were skipped since the last buffered sample are filled with (None, -1) placeholders, so
        that the position in the buffer keeps matching the cycle offset with respect to the first buffered cycle.

        Args:
            data (torch.Tensor): Data to store.
            data_tag (int): Custom data time tag >= 0 (Default: -1, meaning no tags).

        Returns:
            bool: True if the data was accepted (see DataStream.set), else False (nothing is buffered).
        """
        ret = super().set(data, data_tag)
        if not ret:
            return ret

        cycle = self.clock.get_cycle()
        is_first_sample = self.first_cycle < 0

        # Filling gaps with "None": this must be done BEFORE appending the new sample, otherwise the placeholders
        # would end up after it (a static stream holds a single sample, so it has no gaps to fill)
        if not is_first_sample and not self.is_static and cycle > self.last_cycle + 1:
            num_missing = cycle - (self.last_cycle + 1)
            self.data_buffer.extend([(None, -1)] * num_missing)

        if not self.is_static or len(self.data_buffer) == 0:
            self.data_buffer.append((self.props.adapt_tensor_to_tensor_labels(data), self.get_tag()))
            if self.props.is_flat_tensor_with_labels():
                self.text_buffer.append(self.props.to_text(data))
            elif self.props.is_text():
                self.text_buffer.append(data)

        # Boilerplate
        if is_first_sample:
            self.first_cycle = cycle
            self.last_cycle = cycle
        else:

            # Jump to the current cycle if there was a gap, otherwise just move one step ahead
            self.last_cycle = max(cycle, self.last_cycle + 1)
        return ret

    def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
        """Retrieve a sample from the buffer.

        Args:
            idx (int): Index (>=0) of the sample (ignored by static streams, that always yield their only sample).

        Returns:
            tuple: The (sample, tag) pair, or (None, -1) if the index is out of range or the buffer is empty. When
                the stored tag is negative, the current clock cycle minus the first cycle is returned as tag.
        """
        if self.is_static:
            if len(self.data_buffer) == 0:
                return None, -1  # Nothing was buffered yet
            data, data_tag = self.data_buffer[0]
        else:
            if idx >= self.__len__() or idx < 0:
                return None, -1
            data, data_tag = self.data_buffer[idx]
        return data, data_tag if data_tag >= 0 else self.clock.get_cycle() - self.first_cycle

    def __len__(self):
        """Get number of samples in the buffer.

        Returns:
            int: Number of buffered samples.
        """
        return len(self.data_buffer)

    def set_first_cycle(self, cycle):
        """Manually set the first cycle for the buffer.

        Args:
            cycle (int): Global cycle to start from.
        """
        self.first_cycle = cycle

        # NOTE(review): "set" keeps last_cycle == first_cycle + (number of samples) - 1, while here it is set one
        # step further; it looks like an off-by-one, to be confirmed before changing it
        self.last_cycle = cycle + len(self)

    def get_first_cycle(self):
        """Get the first cycle of the stream.

        Returns:
            int: First cycle value.
        """
        return self.first_cycle

    def restart(self):
        """Restart the buffer using the current clock cycle (clamped to 0), forgetting who already got what.
        """
        self.set_first_cycle(max(self.clock.get_cycle(), 0))
        self.buffered_data_index = -1
        self.data_timestamp_when_got_by = {}

    def plan_restart_before_next_get(self, requested_by: str):
        """Ask for a restart of the stream right before the next "get" issued by the given requester.

        Args:
            requested_by (str): Name of the requester whose next "get" will trigger the restart.
        """
        self.restart_before_next_get.add(requested_by)

    def clear_buffer(self):
        """Clear the data buffer, resetting all the cycle-related counters and the planned restarts.
        """
        self.data_buffer = []
        self.text_buffer = []

        self.first_cycle = -1
        self.last_cycle = -1
        self.last_get_cycle = -2  # Keep it to -2 (since -1 is the starting value for cycles)
        self.buffered_data_index = -1

        self.restart_before_next_get = set()

    def shuffle_buffer(self, seed: int = -1):
        """Randomly permute the buffered samples (tags are not moved: they stay attached to their positions).

        Args:
            seed (int): If >= 0, the permutation is generated with this seed, and the state of the global random
                number generator is restored afterward (Default: -1, meaning that the current state is used).
        """
        old_buffer = self.data_buffer
        indices = list(range(len(old_buffer)))

        state = random.getstate()
        if seed >= 0:
            random.seed(seed)
        random.shuffle(indices)
        if seed >= 0:
            random.setstate(state)

        # The k-th slot gets the data of the i-th sample, keeping the tag that was already in the k-th slot
        self.data_buffer = [(old_buffer[i][0], old_buffer[k][1]) for k, i in enumerate(indices)]

        # The text cache is permuted as well, but only when it is aligned with the data buffer
        if self.text_buffer is not None and len(self.text_buffer) == len(self.data_buffer):
            old_text_buffer = self.text_buffer
            self.text_buffer = [old_text_buffer[i] for i in indices]

    def to_text_snippet(self, length: int | None = None):
        """Convert buffered text samples to a single long string.

        Args:
            length (int | None): Optional length (number of buffered text elements) of the resulting snippet: the
                first half and the last half are kept, separated by " ... ".

        Returns:
            str | None: Human-readable text sequence (None if there is no cached text).
        """
        if self.text_buffer is not None and len(self.text_buffer) > 0:
            if length is not None:
                le = max(length // 2, 1)
                text = " ".join(self.text_buffer[0:min(le, len(self.text_buffer))])
                text += (" ... " + (" ".join(self.text_buffer[max(le, len(self.text_buffer) - le):]))) \
                    if len(self.text_buffer) > le else ""
            else:
                text = " ".join(self.text_buffer)
        else:
            text = None
        return text

    def get_since_timestamp(self, since_what_timestamp: float, stride: int = 1) -> (
            tuple[list[int] | None, list[torch.Tensor | None] | None, int, DataProps]):
        """Retrieve all samples starting from a given timestamp.

        Args:
            since_what_timestamp (float): Timestamp in seconds.
            stride (int): Sampling stride.

        Returns:
            Tuple containing list of cycles, data, current cycle, and data properties.
        """
        since_what_cycle = self.clock.time2cycle(since_what_timestamp)
        return self.get_since_cycle(since_what_cycle, stride)

    def get_since_cycle(self, since_what_cycle: int, stride: int = 1) -> (
            tuple[list[int] | None, list[torch.Tensor | None] | None, int, DataProps]):
        """Retrieve all samples starting from a given clock cycle.

        Args:
            since_what_cycle (int): Cycle number (negative values are clamped to 0).
            stride (int): Stride to skip cycles.

        Returns:
            Tuple with cycles, data, current cycle, and properties. Cycles with no data are skipped. If the clock
            did not start yet (negative cycle), the tuple is (None, None, -1, properties).
        """

        # The type is checked first, so that a non-integer stride is reported as such instead of being compared
        assert isinstance(stride, int) and stride >= 1, f"Invalid stride: {stride}"

        # Notice: this whole routine never calls ".get()", on purpose! it must be as it is
        global_cycle = self.clock.get_cycle()
        if global_cycle < 0:
            return None, None, -1, self.props

        # First check: ensure we do not go beyond the first clock and counting the resulting number of steps
        since_what_cycle = max(since_what_cycle, 0)
        num_steps = global_cycle - since_what_cycle + 1

        # Second check: now we compute the index we should pass to get item
        since_what_idx_in_getitem = since_what_cycle - self.first_cycle

        ret_cycles = []
        ret_data = []

        for k in range(0, num_steps, stride):
            _idx = since_what_idx_in_getitem + k
            _data, _ = self[_idx]

            if _data is not None:
                ret_cycles.append(since_what_cycle + k)
                ret_data.append(_data)

        return ret_cycles, ret_data, global_cycle, self.props
475
+
476
+
477
class System(DataStream):
    """
    Non-pubsub text stream named after this class, whose sample is set to "ping" when it is created.
    """

    def __init__(self):
        """Create the stream properties, build the stream, and store the "ping" sample."""
        system_props = DataProps(name=System.__name__,
                                 data_type="text",
                                 data_desc="System stream",
                                 pubsub=False)
        super().__init__(props=system_props)
        self.set("ping")

    def get(self, requested_by: str | None = None):
        """Return the current sample, no matter who asks for it ("requested_by" is not forwarded on purpose)."""
        return super().get()
485
+
486
+
487
class Dataset(BufferedDataStream):
    """
    A buffered dataset that streams data from a PyTorch dataset and simulates data-streams for input/output.
    """

    def __init__(self, tensor_dataset: torch.utils.data.Dataset, shape: tuple, index: int = 0, batch_size: int = 1):
        """Initialize a Dataset instance, which wraps around a PyTorch Dataset.

        The whole dataset is read in advance and buffered as a sequence of stacked batches (the last batch is
        smaller when the dataset size is not a multiple of the batch size).

        Args:
            tensor_dataset (torch.utils.data.Dataset): The PyTorch Dataset to wrap.
            shape (tuple): The shape of each sample from the data stream.
            index (int): The index of the element returned by __getitem__ to pick up.
            batch_size (int): Number of consecutive dataset elements stacked into each buffered sample (>= 1).

        Raises:
            ValueError: If the batch size is not positive, or if the picked element is not a tensor or a scalar.
        """
        if batch_size < 1:
            raise ValueError(f"Invalid batch size: {batch_size}")

        # The first element tells what is the data type of the stream
        sample = tensor_dataset[0][index]
        if isinstance(sample, torch.Tensor):
            dtype = sample.dtype
        elif isinstance(sample, int):
            dtype = torch.long
        elif isinstance(sample, float):
            dtype = torch.float32
        else:
            raise ValueError("Expected tensor data or a scalar")

        super().__init__(props=DataProps(name=Dataset.__name__,
                                         data_type="tensor",
                                         data_desc="dataset",
                                         tensor_shape=shape,
                                         tensor_dtype=dtype,
                                         pubsub=True))

        n = len(tensor_dataset)

        # Each batch starts at a multiple of the (original) batch size, also the last, possibly smaller, one
        for start in range(0, n, batch_size):
            batch = []
            for k in range(start, min(start + batch_size, n)):
                sample = tensor_dataset[k][index]
                if isinstance(sample, (int, float)):
                    sample = torch.tensor(sample, dtype=dtype)
                batch.append(sample)

            self.data_buffer.append((torch.stack(batch), -1))

        # It was buffered previously than every other thing
        self.restart()
537
+
538
+
539
class ImageFileStream(BufferedDataStream):
    """
    A buffered dataset for image data, where images are opened from disk only when they are requested.
    """

    def __init__(self, image_dir: str, list_of_image_files: str,
                 device: torch.device = None, circular: bool = True, show_images: bool = False):
        """Initialize an ImageFileStream instance for streaming image data.

        Args:
            image_dir (str): The directory containing image files.
            list_of_image_files (str): Path to the file with list of file names of the images.
            device (torch.device): The device to store the tensors on. Default is CPU.
            circular (bool): Whether to loop the dataset or not. Default is True.
            show_images (bool): Whether to show the grid of images and to print the list of image files.
                Default is False.
        """
        self.image_dir = image_dir
        self.device = device if device is not None else torch.device("cpu")
        self.circular = circular

        # Reading the image file
        # (assume a file with one filename per line or a CSV format with lines such as: cat.jpg,cat,mammal,animal)
        self.image_paths = []

        # Calling the constructor
        super().__init__(props=DataProps(name=ImageFileStream.__name__,
                                         data_type="img",
                                         pubsub=True))

        with open(list_of_image_files, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # A blank line has no file name (it would yield the image directory itself)
                image_name = line.split(',')[0]  # Tolerates if it is a CSV and the first field is the file name
                self.image_paths.append(os.path.join(image_dir, image_name))

        # It was buffered previously than every other thing
        self.last_cycle = -1
        self.first_cycle = self.last_cycle - len(self.image_paths) + 1

        # Possibly print to screen the "clickable" list of images
        if show_images:
            show_images_grid(self.image_paths)
            for i, image_path in enumerate(self.image_paths):
                abs_path = os.path.abspath(image_path)
                file_url = pathlib.Path(abs_path).as_uri()
                basename = os.path.basename(abs_path)  # 'photo.jpg'
                parent = os.path.basename(os.path.dirname(abs_path))  # 'images'
                label = os.path.join(parent, basename) if parent else basename

                # Terminal hyperlink escape sequence (OSC 8) wrapping the label
                clickable_label = f"\033]8;;{file_url}\033\\[{label}]\033]8;;\033\\"
                print(str(i) + " => " + clickable_label)

    def __len__(self):
        """Return the number of images in the dataset.

        Returns:
            int: Number of images in the dataset.
        """
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> tuple[Image.Image | None, int]:
        """Open the image at the specified position.

        Args:
            idx (int): The index of the image to retrieve (wrapped around the list of images if circular).

        Returns:
            tuple: The (image, tag) pair, where the image is a PIL image and the tag is the current clock cycle
                minus the first cycle; (None, -1) if there are no images or if the index is out of range.
        """
        num_images = self.__len__()
        if num_images == 0:
            return None, -1  # Avoiding the modulo-by-zero of the circular case
        if self.circular:
            idx %= num_images
        elif idx < 0 or idx >= num_images:
            return None, -1  # Same convention of BufferedDataStream.__getitem__

        image = Image.open(self.image_paths[idx])
        return image, self.clock.get_cycle() - self.first_cycle
611
+
612
+
613
class LabelStream(BufferedDataStream):
    """
    A buffered stream for single and multi-label annotations.
    """

    def __init__(self, label_dir: str, label_file_csv: str,
                 device: torch.device = None, circular: bool = True, single_class: bool = False,
                 line_header: bool = False):
        """Initialize a LabelStream instance for streaming labels.

        Args:
            label_dir (str): The directory containing image files.
            label_file_csv (str): Path to the CSV file with labels for the images.
            device (torch.device): The device to store the tensors on. Default is CPU.
            circular (bool): Whether to loop the dataset or not. Default is True.
            single_class (bool): Whether to only consider a single class for labeling. Default is False.
            line_header (bool): Whether the first field of each line is a header (e.g., a file name) to be skipped,
                instead of being a label. Default is False.
        """
        self.label_dir = label_dir
        self.device = device if device is not None else torch.device("cpu")
        self.circular = circular

        # Reading the label file
        # (assume a file with a labeled element per line or a CSV format with lines such as: cat.jpg,cat,mammal,animal)
        self.labels = []

        # The file is read once, and the same rule (skip the first field or not) is used both to collect the class
        # names and to build the targets, so that every label that is looked up is a known class
        labels_per_line = []
        with open(label_file_csv, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # Blank lines are not labeled elements
                parts = line.split(',')
                labels_per_line.append(parts[1:] if line_header else parts)

        # Class names are sorted by first appearance
        class_names = list(dict.fromkeys(lab for label in labels_per_line for lab in label))

        # Call the constructor
        super().__init__(props=DataProps(name=LabelStream.__name__,
                                         data_type="tensor",
                                         data_desc="label stream",
                                         tensor_shape=(1, len(class_names)),
                                         tensor_dtype=str(torch.float),
                                         tensor_labels=class_names,
                                         tensor_labeling_rule="geq0.5" if not single_class else "max",
                                         pubsub=True))

        class_name_to_index = {class_name: idx for idx, class_name in enumerate(class_names)}

        # One multi-hot row vector per labeled element
        for label in labels_per_line:
            target_vector = torch.zeros((1, len(class_names)), dtype=torch.float32)
            for lab in label:
                target_vector[0, class_name_to_index[lab]] = 1.
            self.labels.append(target_vector)

        # It was buffered previously than every other thing
        self.last_cycle = -1
        self.first_cycle = self.last_cycle - len(self.labels) + 1

    def __len__(self):
        """Return the number of labels in the dataset.

        Returns:
            int: Number of labels in the dataset.
        """
        return len(self.labels)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
        """Get the label vector at the specified position.

        Args:
            idx (int): The index of the label vector to retrieve (wrapped around the list of labels if circular).

        Returns:
            tuple: The (label, tag) pair, where the tag is the current clock cycle minus the first cycle;
                (None, -1) if there are no labels or if the index is out of range.
        """
        num_labels = self.__len__()
        if num_labels == 0:
            return None, -1  # Avoiding the modulo-by-zero of the circular case
        if self.circular:
            idx %= num_labels
        elif idx < 0 or idx >= num_labels:
            return None, -1  # Same convention of BufferedDataStream.__getitem__

        # NOTE(review): stored vectors are already (1, num_classes), so this adds another leading dimension;
        # presumably "adapt_tensor_to_tensor_labels" copes with it - to be confirmed
        label = self.labels[idx].unsqueeze(0).to(self.device)  # Multi-label vector for the image
        return self.props.adapt_tensor_to_tensor_labels(label), self.clock.get_cycle() - self.first_cycle
697
+
698
+
699
class TokensStream(BufferedDataStream):
    """
    A buffered dataset for tokenized text, that streams one token (the first field of each line) at a time.
    """

    def __init__(self, tokens_file_csv: str, circular: bool = True, max_tokens: int = -1):
        """Initialize a TokensStream instance for streaming tokenized data.

        Args:
            tokens_file_csv (str): Path to the CSV file containing token data.
            circular (bool): Whether to loop the dataset or not. Default is True.
            max_tokens (int): Whether to cut the stream to a maximum number of tokens. Default is -1 (no cut).
        """
        self.circular = circular

        # Reading the data file (assume a token per line or a CSV format with lines such as:
        # token,category_label1,category_label2,etc.)
        tokens = []
        with open(tokens_file_csv, 'r') as f:
            for line in f:
                parts = next(csv.reader([line], quotechar='"', delimiter=','), None)
                if not parts:
                    continue  # A blank line is parsed as an empty row, that has no token to pick up
                tokens.append(parts[0])
                if 0 < max_tokens <= len(tokens):
                    break

        # Vocabulary (identifiers follow the sorted order of the distinct tokens)
        id2word = sorted(set(tokens))
        word2id = {_word: _id for _id, _word in enumerate(id2word)}

        # Calling the constructor
        super().__init__(props=DataProps(name=TokensStream.__name__,
                                         data_type="text",
                                         data_desc="stream of words",
                                         stream_to_proc_transforms=word2id,
                                         proc_to_stream_transforms=id2word,
                                         pubsub=True))

        # Tokenized text (no custom tags)
        self.data_buffer.extend((token, -1) for token in tokens)

        # It was buffered previously than every other thing
        self.restart()

    def __getitem__(self, idx: int) -> tuple[str | None, int]:
        """Get the token at the specified position.

        Args:
            idx (int): The index to retrieve data for (wrapped around the buffer if circular).

        Returns:
            tuple: The (token, tag) pair, as returned by BufferedDataStream.__getitem__; (None, -1) if the buffer
                is empty.
        """
        if self.circular:
            num_tokens = self.__len__()
            if num_tokens == 0:
                return None, -1  # Avoiding the modulo-by-zero on an empty buffer
            idx %= num_tokens
        return super().__getitem__(idx)