unaiverse-0.1.6-cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of unaiverse might be problematic.

Files changed (50)
  1. unaiverse/__init__.py +19 -0
  2. unaiverse/agent.py +2008 -0
  3. unaiverse/agent_basics.py +1846 -0
  4. unaiverse/clock.py +191 -0
  5. unaiverse/dataprops.py +1209 -0
  6. unaiverse/hsm.py +1880 -0
  7. unaiverse/modules/__init__.py +18 -0
  8. unaiverse/modules/cnu/__init__.py +17 -0
  9. unaiverse/modules/cnu/cnus.py +536 -0
  10. unaiverse/modules/cnu/layers.py +261 -0
  11. unaiverse/modules/cnu/psi.py +60 -0
  12. unaiverse/modules/hl/__init__.py +15 -0
  13. unaiverse/modules/hl/hl_utils.py +411 -0
  14. unaiverse/modules/networks.py +1509 -0
  15. unaiverse/modules/utils.py +680 -0
  16. unaiverse/networking/__init__.py +16 -0
  17. unaiverse/networking/node/__init__.py +18 -0
  18. unaiverse/networking/node/connpool.py +1261 -0
  19. unaiverse/networking/node/node.py +2223 -0
  20. unaiverse/networking/node/profile.py +446 -0
  21. unaiverse/networking/node/tokens.py +79 -0
  22. unaiverse/networking/p2p/__init__.py +198 -0
  23. unaiverse/networking/p2p/go.mod +127 -0
  24. unaiverse/networking/p2p/go.sum +548 -0
  25. unaiverse/networking/p2p/golibp2p.py +18 -0
  26. unaiverse/networking/p2p/golibp2p.pyi +135 -0
  27. unaiverse/networking/p2p/lib.go +2714 -0
  28. unaiverse/networking/p2p/lib.go.sha256 +1 -0
  29. unaiverse/networking/p2p/lib_types.py +312 -0
  30. unaiverse/networking/p2p/message_pb2.py +63 -0
  31. unaiverse/networking/p2p/messages.py +265 -0
  32. unaiverse/networking/p2p/mylogger.py +77 -0
  33. unaiverse/networking/p2p/p2p.py +929 -0
  34. unaiverse/networking/p2p/proto-go/message.pb.go +616 -0
  35. unaiverse/networking/p2p/unailib.cpython-312-darwin.so +0 -0
  36. unaiverse/streamlib/__init__.py +15 -0
  37. unaiverse/streamlib/streamlib.py +210 -0
  38. unaiverse/streams.py +770 -0
  39. unaiverse/utils/__init__.py +16 -0
  40. unaiverse/utils/ask_lone_wolf.json +27 -0
  41. unaiverse/utils/lone_wolf.json +19 -0
  42. unaiverse/utils/misc.py +305 -0
  43. unaiverse/utils/sandbox.py +293 -0
  44. unaiverse/utils/server.py +435 -0
  45. unaiverse/world.py +175 -0
  46. unaiverse-0.1.6.dist-info/METADATA +365 -0
  47. unaiverse-0.1.6.dist-info/RECORD +50 -0
  48. unaiverse-0.1.6.dist-info/WHEEL +6 -0
  49. unaiverse-0.1.6.dist-info/licenses/LICENSE +43 -0
  50. unaiverse-0.1.6.dist-info/top_level.txt +1 -0
unaiverse/streams.py ADDED
@@ -0,0 +1,770 @@
1
+ """
2
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
3
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
4
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
5
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
6
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
7
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
8
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
9
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
10
+ A Collectionless AI Project (https://collectionless.ai)
11
+ Registration/Login: https://unaiverse.io
12
+ Code Repositories: https://github.com/collectionlessai/
13
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
14
+ """
15
+ import os
16
+ import csv
17
+ import math
18
+ import torch
19
+ import random
20
+ import pathlib
21
+ from PIL import Image
22
+ from .clock import Clock
23
+ from .dataprops import DataProps
24
+ from unaiverse.utils.misc import show_images_grid
25
+
26
+
27
+ class DataStream:
28
+ """
29
+ Base class for handling a generic data stream.
30
+ """
31
+
32
+ def __init__(self,
33
+ props: DataProps,
34
+ clock: Clock = Clock()) -> None:
35
+ """Initialize a DataStream.
36
+
37
+ Args:
38
+ props (DataProps): Properties of the stream.
39
+ clock (Clock): Clock object for time management (usually provided from outside).
40
+ """
41
+
42
+ # A stream can be turned off
43
+ self.props = props
44
+ self.clock = clock
45
+ self.data = None
46
+ self.data_timestamp = self.clock.get_time()
47
+ self.data_timestamp_when_got_by = {}
48
+ self.data_tag = -1
49
+ self.data_uuid = None
50
+ self.data_uuid_expected = None
51
+ self.data_uuid_clearable = False
52
+ self.enabled = True
53
+
54
+ @staticmethod
55
+ def create(stream: 'DataStream', name: str | None = None, group: str = 'none',
56
+ public: bool = True, pubsub: bool = True):
57
+ """Create and set the name for a given stream, also updates data stream labels.
58
+
59
+ Args:
60
+ stream (DataStream): The stream object to modify.
61
+ name (str): The name of the stream.
62
+ group (str): The name of the group to which the stream belongs.
63
+ public (bool): If the stream is going to be served in the public net or the private one.
64
+ pubsub (bool): If the stream is going to be served by broadcasting (PubSub) or not.
65
+
66
+ Returns:
67
+ DataStream: The modified stream with updated name, group, and serving options.
68
+ """
69
+ assert name is not None or group != 'none', "Must provide either name or group name."
70
+ stream.props.set_group(group)
71
+ if name is not None:
72
+ stream.props.set_name(name)
73
+ else:
74
+ stream.props.set_name(stream.props.get_name() + "@" + group) # If name is None, then a group was provided
75
+ stream.props.set_public(public)
76
+ stream.props.set_pubsub(pubsub)
77
+ if (stream.props.is_flat_tensor_with_labels() and
78
+ len(stream.props.tensor_labels) == 1 and stream.props.tensor_labels[0] == 'unk'):
79
+ stream.props.tensor_labels[0] = group if group != 'none' else name
80
+ return stream
81
+
82
+ def enable(self):
83
+ """Enable the stream, allowing data to be retrieved.
84
+
85
+ Returns:
86
+ None
87
+ """
88
+ self.enabled = True
89
+
90
+ def disable(self):
91
+ """Disable the stream, preventing data from being retrieved.
92
+
93
+ Returns:
94
+ None
95
+ """
96
+ self.enabled = False
97
+
98
+ def get_props(self) -> DataProps:
99
+ return self.props
100
+
101
+ def net_hash(self, prefix: str):
102
+ return self.get_props().net_hash(prefix)
103
+
104
+ @staticmethod
105
+ def peer_id_from_net_hash(net_hash):
106
+ return DataProps.peer_id_from_net_hash(net_hash)
107
+
108
+ @staticmethod
109
+ def name_or_group_from_net_hash(net_hash):
110
+ return DataProps.name_or_group_from_net_hash(net_hash)
111
+
112
+ @staticmethod
113
+ def is_pubsub_from_net_hash(net_hash):
114
+ return DataProps.is_pubsub_from_net_hash(net_hash)
115
+
116
+ def is_pubsub(self):
117
+ return self.get_props().is_pubsub()
118
+
119
+ def is_public(self):
120
+ return self.get_props().is_public()
121
+
122
+ def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
123
+ """Get item for a specific clock cycle. Not implemented for base class: it will be implemented in buffered data
124
+ streams or streams that can generate data on-the-fly.
125
+
126
+ Args:
127
+ idx (int): Index of the data to retrieve.
128
+
129
+ Raises:
130
+ ValueError: Always, since this method should be overridden.
131
+ """
132
+ raise ValueError("Not implemented (expected to be only present in data streams that are buffered or "
133
+ "that can generated on the fly)")
134
+
135
+ def __str__(self) -> str:
136
+ """String representation of the object.
137
+
138
+ Returns:
139
+ str: String representation of the object.
140
+ """
141
+ return f"[DataStream] enabled: {self.enabled}\n\tprops: {self.props}"
142
+
143
+ def __len__(self):
144
+ """Get the length of the stream.
145
+
146
+ Returns:
147
+ int: Infinity (as this is a lifelong stream).
148
+ """
149
+ return math.inf
150
+
151
+ def set(self, data: torch.Tensor | Image.Image | str, data_tag: int = -1) -> bool:
152
+ """Set a new data sample into the stream, that will be provided when calling "get()".
153
+
154
+ Args:
155
+ data (torch.Tensor): Data sample to set.
156
+ data_tag (int): Custom data time tag >= 0 (Default: -1, meaning no tags).
157
+
158
+ Returns:
159
+ bool: True if data was accepted based on time constraints, else False.
160
+ """
161
+ if self.props.delta <= 0. or self.props.delta <= (self.clock.get_time() - self.data_timestamp):
162
+ self.data = self.props.adapt_tensor_to_tensor_labels(data) if data is not None else None
163
+ self.data_timestamp = self.clock.get_cycle_time()
164
+ self.data_tag = data_tag if data_tag >= 0 else self.clock.get_cycle()
165
+ return True
166
+ else:
167
+ return False
168
+
169
+ def get(self, requested_by: str | None = None) -> torch.Tensor | None:
170
+ """Get the most recent data sample from the stream (i.e., the last one that was "set").
171
+
172
+ Returns:
173
+ torch.Tensor | None: Adapted data sample if available.
174
+ """
175
+ if self.enabled:
176
+ if requested_by is not None:
177
+ if (requested_by not in self.data_timestamp_when_got_by or
178
+ self.data_timestamp_when_got_by[requested_by] != self.data_timestamp):
179
+ self.data_timestamp_when_got_by[requested_by] = self.data_timestamp
180
+ return self.data
181
+ else:
182
+ return None
183
+ else:
184
+ return self.data
185
+ else:
186
+ return None
187
+
188
+ def get_timestamp(self) -> float:
189
+ return self.data_timestamp
190
+
191
+ def get_tag(self) -> int:
192
+ return self.data_tag
193
+
194
+ def get_uuid(self, expected: bool = False) -> str | None:
195
+ return self.data_uuid if not expected else self.data_uuid_expected
196
+
197
+ def set_uuid(self, ref_uuid: str | None, expected: bool = False):
198
+ if not expected:
199
+ self.data_uuid = ref_uuid
200
+ else:
201
+ self.data_uuid_expected = ref_uuid
202
+
203
+ # When we "set" a UUID, even if None, we unmark the "clear UUID" flag, to be sure that nobody is going to
204
+ # accidentally clear this information
205
+ self.data_uuid_clearable = False
206
+
207
+ def mark_uuid_as_clearable(self):
208
+ self.data_uuid_clearable = True
209
+
210
+ def clear_uuid_if_marked_as_clearable(self):
211
+ if self.data_uuid_clearable:
212
+ self.data_uuid = None
213
+ self.data_uuid_expected = None
214
+ self.data_uuid_clearable = False
215
+
216
+ def set_props(self, data_stream: 'DataStream'):
217
+ """Set (edit) the data properties picking up the ones of another DataStream.
218
+
219
+ Args:
220
+ data_stream (DataStream): the source DataStream from which DataProp is taken.
221
+
222
+ Returns:
223
+ None
224
+ """
225
+ self.props = data_stream.props
226
+
227
+
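As a rough illustration of the base DataStream API above (an editorial sketch, not part of streams.py; it assumes the DataProps keyword arguments used elsewhere in this file and a default Clock), a stream is configured through DataProps, fed with set(), and polled with get():

    import torch
    from unaiverse.clock import Clock
    from unaiverse.dataprops import DataProps
    from unaiverse.streams import DataStream

    # Hypothetical properties: a flat float tensor with two components
    props = DataProps(name="sensor", data_type="tensor", data_desc="toy sensor",
                      tensor_shape=(1, 2), tensor_dtype=torch.float32, pubsub=True)
    stream = DataStream(props=props, clock=Clock())

    stream.set(torch.tensor([[0.1, 0.9]]))       # Accepted unless props.delta throttles it
    sample = stream.get(requested_by="peer_a")   # Returns the sample once per timestamp for "peer_a"
    again = stream.get(requested_by="peer_a")    # None until a newer sample is set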
228
+ class BufferedDataStream(DataStream):
229
+ """
230
+ Data stream with buffer support to store historical data.
231
+ """
232
+
233
+ def __init__(self, props: DataProps, clock: Clock = Clock(), is_static: bool = False):
234
+ """Initialize a BufferedDataStream.
235
+
236
+ Args:
237
+ is_static (bool): If True, the buffer stores only one item that is reused.
238
+ """
239
+ super().__init__(props=props, clock=clock)
240
+
241
+ # We store the data samples, and we cache their text representation (for speed)
242
+ self.data_buffer = []
243
+ self.text_buffer = []
244
+
245
+ self.is_static = is_static # A static stream stores only one sample and always yields it
246
+
247
+ # We need to remember the first cycle in which we started buffering and the last one we buffered
248
+ self.first_cycle = -1
249
+ self.last_cycle = -1
250
+ self.last_get_cycle = -2 # Keep it to -2 (since -1 is the starting value for cycles)
251
+ self.buffered_data_index = -1
252
+
253
+ self.restart_before_next_get = set()
254
+
255
+ def get(self, requested_by: str | None = None) -> torch.Tensor | None:
256
+ """Get the current data sample based on cycle and buffer.
257
+
258
+ Returns:
259
+ torch.Tensor | None: Current buffered sample.
260
+ """
261
+ if requested_by is not None and requested_by in self.restart_before_next_get:
262
+ self.restart_before_next_get.remove(requested_by)
263
+ self.restart()
264
+
265
+ cycle = self.clock.get_cycle() - self.first_cycle # This ensures that first get clock = first sample
266
+
267
+ # These two lines might make you think "hey, call super().set(self[cycle]), it is the same!"
268
+ # however, it is not the same, since "set" also calls "adapt_tensor_to_tensor_labels", which is not needed for
269
+ # buffered streams
270
+ if (self.last_get_cycle != cycle and
271
+ (self.props.delta <= 0. or self.props.delta <= (self.clock.get_time() - self.data_timestamp))):
272
+ self.last_get_cycle = cycle
273
+ self.buffered_data_index += 1
274
+ new_data, new_tag = self[self.buffered_data_index]
275
+ self.data = new_data
276
+ self.data_timestamp = self.clock.get_cycle_time()
277
+ self.data_tag = new_tag
278
+ return super().get(requested_by)
279
+
280
+ def set(self, data: torch.Tensor, data_tag: int = -1):
281
+ """Store a new data sample into the buffer.
282
+
283
+ Args:
284
+ data (torch.Tensor): Data to store.
285
+ data_tag (int): Custom data time tag >= 0 (Default: -1, meaning no tags).
286
+
287
+ Returns:
288
+ bool: True if the data was buffered.
289
+ """
290
+ ret = super().set(data, data_tag)
291
+
292
+ if ret:
293
+ if not self.is_static or len(self.data_buffer) == 0:
294
+ self.data_buffer.append((self.props.adapt_tensor_to_tensor_labels(data), self.get_tag()))
295
+ if self.props.is_flat_tensor_with_labels():
296
+ self.text_buffer.append(self.props.to_text(data))
297
+ elif self.props.is_text():
298
+ self.text_buffer.append(data)
299
+
300
+ # Boilerplate
301
+ if self.first_cycle < 0:
302
+ self.first_cycle = self.clock.get_cycle()
303
+ self.last_cycle = self.first_cycle
304
+ else:
305
+
306
+ # Filling gaps with "None"
307
+ cycle = self.clock.get_cycle()
308
+ if cycle > self.last_cycle + 1:
309
+ for _ in range(self.last_cycle + 1, cycle): # Pad the skipped cycles with empty entries
310
+ self.data_buffer.append((None, -1))
311
+ self.last_cycle = cycle - 1
312
+
313
+ self.last_cycle += 1
314
+ return ret
315
+
316
+ def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
317
+ """Retrieve a sample from the buffer based on the given clock cycle.
318
+
319
+ Args:
320
+ idx (int): Index (>=0) of the sample.
321
+
322
+ Returns:
323
+ torch.Tensor | None: The sample, if available.
324
+ """
325
+ if not self.is_static:
326
+ if idx >= self.__len__() or idx < 0:
327
+ return None, -1
328
+ data, data_tag = self.data_buffer[idx]
329
+ else:
330
+ data, data_tag = self.data_buffer[0]
331
+ return data, data_tag if data_tag >= 0 else self.clock.get_cycle() - self.first_cycle
332
+
333
+ def __len__(self):
334
+ """Get number of samples in the buffer.
335
+
336
+ Returns:
337
+ int: Number of buffered samples.
338
+ """
339
+ return len(self.data_buffer)
340
+
341
+ def set_first_cycle(self, cycle):
342
+ """Manually set the first cycle for the buffer.
343
+
344
+ Args:
345
+ cycle (int): Global cycle to start from.
346
+ """
347
+ self.first_cycle = cycle
348
+ self.last_cycle = cycle + len(self)
349
+
350
+ def get_first_cycle(self):
351
+ """Get the first cycle of the stream.
352
+
353
+ Returns:
354
+ int: First cycle value.
355
+ """
356
+ return self.first_cycle
357
+
358
+ def restart(self):
359
+ """Restart the buffer using the current clock cycle.
360
+ """
361
+ self.set_first_cycle(max(self.clock.get_cycle(), 0))
362
+ self.buffered_data_index = -1
363
+ self.data_timestamp_when_got_by = {}
364
+
365
+ def plan_restart_before_next_get(self, requested_by: str):
366
+ self.restart_before_next_get.add(requested_by)
367
+
368
+ def clear_buffer(self):
369
+ """Clear the data buffer
370
+ """
371
+ self.data_buffer = []
372
+ self.text_buffer = []
373
+
374
+ self.first_cycle = -1
375
+ self.last_cycle = -1
376
+ self.last_get_cycle = -2 # Keep it to -2 (since -1 is the starting value for cycles)
377
+ self.buffered_data_index = -1
378
+
379
+ self.restart_before_next_get = set()
380
+
381
+ def shuffle_buffer(self, seed: int = -1):
382
+ old_buffer = self.data_buffer
383
+ indices = list(range(len(old_buffer)))
384
+
385
+ state = random.getstate()
386
+ if seed >= 0:
387
+ random.seed(seed)
388
+ random.shuffle(indices)
389
+ if seed >= 0:
390
+ random.setstate(state)
391
+
392
+ self.data_buffer = []
393
+ k = 0
394
+ for i in indices:
395
+ self.data_buffer.append((old_buffer[i][0], old_buffer[k][1]))
396
+ k += 1
397
+
398
+ if self.text_buffer is not None and len(self.text_buffer) == len(self.data_buffer):
399
+ old_text_buffer = self.text_buffer
400
+ self.text_buffer = []
401
+ for i in indices:
402
+ self.text_buffer.append(old_text_buffer[i])
403
+
404
+ def to_text_snippet(self, length: int | None = None):
405
+ """Convert buffered text samples to a single long string.
406
+
407
+ Args:
408
+ length (int | None): Optional length of the resulting text snippet.
409
+
410
+ Returns:
411
+ str | None: Human-readable text sequence.
412
+ """
413
+ if self.text_buffer is not None and len(self.text_buffer) > 0:
414
+ if length is not None:
415
+ le = max(length // 2, 1)
416
+ text = " ".join(self.text_buffer[0:min(le, len(self.text_buffer))])
417
+ text += (" ... " + (" ".join(self.text_buffer[max(le, len(self.text_buffer) - le):]))) \
418
+ if len(self.text_buffer) > le else ""
419
+ else:
420
+ text = " ".join(self.text_buffer)
421
+ else:
422
+ text = None
423
+ return text
424
+
425
+ def get_since_timestamp(self, since_what_timestamp: float, stride: int = 1) -> (
426
+ tuple[list[int] | None, list[torch.Tensor | None] | None, int, DataProps]):
427
+ """Retrieve all samples starting from a given timestamp.
428
+
429
+ Args:
430
+ since_what_timestamp (float): Timestamp in seconds.
431
+ stride (int): Sampling stride.
432
+
433
+ Returns:
434
+ Tuple containing list of cycles, data, current cycle, and data properties.
435
+ """
436
+ since_what_cycle = self.clock.time2cycle(since_what_timestamp)
437
+ return self.get_since_cycle(since_what_cycle, stride)
438
+
439
+ def get_since_cycle(self, since_what_cycle: int, stride: int = 1) -> (
440
+ tuple[list[int] | None, list[torch.Tensor | None] | None, int, DataProps]):
441
+ """Retrieve all samples starting from a given clock cycle.
442
+
443
+ Args:
444
+ since_what_cycle (int): Cycle number.
445
+ stride (int): Stride to skip cycles.
446
+
447
+ Returns:
448
+ Tuple with cycles, data, current cycle, and properties.
449
+ """
450
+ assert stride >= 1 and isinstance(stride, int), f"Invalid stride: {stride}"
451
+
452
+ # Notice: this whole routine never calls ".get()", on purpose! It must stay as it is
453
+ global_cycle = self.clock.get_cycle()
454
+ if global_cycle < 0:
455
+ return None, None, -1, self.props
456
+
457
+ # First check: ensure we do not go before the first cycle, and count the resulting number of steps
458
+ since_what_cycle = max(since_what_cycle, 0)
459
+ num_steps = global_cycle - since_what_cycle + 1
460
+
461
+ # Second check: now we compute the index we should pass to get item
462
+ since_what_idx_in_getitem = since_what_cycle - self.first_cycle
463
+
464
+ ret_cycles = []
465
+ ret_data = []
466
+
467
+ for k in range(0, num_steps, stride):
468
+ _idx = since_what_idx_in_getitem + k
469
+ _data, _ = self[_idx]
470
+
471
+ if _data is not None:
472
+ ret_cycles.append(since_what_cycle + k)
473
+ ret_data.append(_data)
474
+
475
+ return ret_cycles, ret_data, global_cycle, self.props
476
+
477
+
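A minimal sketch of how the buffered variant behaves (illustration only, with the same assumptions on DataProps and Clock as in the earlier sketch): accepted samples accumulate in the buffer, get() replays them one per clock cycle, and get_since_cycle() returns a slice of the history:

    import torch
    from unaiverse.clock import Clock
    from unaiverse.dataprops import DataProps
    from unaiverse.streams import BufferedDataStream

    props = DataProps(name="sensor", data_type="tensor", data_desc="toy sensor",
                      tensor_shape=(1, 2), tensor_dtype=torch.float32, pubsub=True)
    buffered = BufferedDataStream(props=props, clock=Clock())

    for v in (1., 2., 3.):
        buffered.set(torch.tensor([[v, v]]))      # Accepted samples are appended to the buffer

    buffered.restart()                            # Replay from the current clock cycle
    sample = buffered.get(requested_by="peer_a")  # Advances one buffered sample per clock cycle

    cycles, history, now, out_props = buffered.get_since_cycle(0, stride=1)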
478
+ class System(DataStream):
479
+ def __init__(self):
480
+ super().__init__(props=DataProps(name=System.__name__, data_type="text", data_desc="System stream",
481
+ pubsub=False))
482
+ self.set("ping")
483
+
484
+ def get(self, requested_by: str | None = None):
485
+ return super().get()
486
+
487
+
488
+ class Dataset(BufferedDataStream):
489
+ """
490
+ A buffered dataset that streams data from a PyTorch dataset and simulates data-streams for input/output.
491
+ """
492
+
493
+ def __init__(self, tensor_dataset: torch.utils.data.Dataset, shape: tuple, index: int = 0, batch_size: int = 1):
494
+ """Initialize a Dataset instance, which wraps around a PyTorch Dataset.
495
+
496
+ Args:
497
+ tensor_dataset (torch.utils.data.Dataset): The PyTorch Dataset to wrap.
498
+ shape (tuple): The shape of each sample from the data stream.
499
+ index (int): The index of the element returned by __getitem__ to pick up.
+ batch_size (int): Number of consecutive samples stacked into each streamed element. Default is 1.
500
+ """
501
+ sample = tensor_dataset[0][index]
502
+ if isinstance(sample, torch.Tensor):
503
+ dtype = sample.dtype
504
+ elif isinstance(sample, int):
505
+ dtype = torch.long
506
+ elif isinstance(sample, float):
507
+ dtype = torch.float32
508
+ else:
509
+ raise ValueError("Expected tensor data or a scalar")
510
+
511
+ super().__init__(props=DataProps(name=Dataset.__name__,
512
+ data_type="tensor",
513
+ data_desc="dataset",
514
+ tensor_shape=shape,
515
+ tensor_dtype=dtype,
516
+ pubsub=True))
517
+
518
+ n = len(tensor_dataset)
519
+ b = batch_size
520
+ nb = math.ceil(float(n) / float(b))
521
+ r = n - b * (nb - 1)
522
+
523
+ for i in range(0, nb):
524
+ batch = []
525
+ if i == (nb - 1):
526
+ b = r
527
+
528
+ for j in range(0, b):
529
+ sample = tensor_dataset[i * batch_size + j][index] # Index with the original batch size (b may shrink for the last batch)
530
+ if isinstance(sample, (int, float)):
531
+ sample = torch.tensor(sample, dtype=dtype)
532
+ batch.append(sample)
533
+
534
+ self.data_buffer.append((torch.stack(batch), -1))
535
+
536
+ # The buffer was filled before anything else, so restart it from the current cycle
537
+ self.restart()
538
+
539
+
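For example (an editorial sketch, not part of streams.py), a PyTorch dataset can be wrapped so that its inputs and targets are exposed as two synchronized streams, using the index argument to select which element of each (input, target) pair is streamed:

    import torch
    from torch.utils.data import TensorDataset
    from unaiverse.streams import Dataset

    xs = torch.randn(10, 4)             # Ten 4-dimensional inputs
    ys = torch.randint(0, 3, (10,))     # Ten integer targets
    toy = TensorDataset(xs, ys)

    inputs = Dataset(toy, shape=(1, 4), index=0, batch_size=1)   # Streams element 0 of each pair
    targets = Dataset(toy, shape=(1,), index=1, batch_size=1)    # Streams element 1 (the targets)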
540
+ class ImageFileStream(BufferedDataStream):
541
+ """
542
+ A buffered dataset for image data.
543
+ """
544
+
545
+ def __init__(self, image_dir: str, list_of_image_files: str,
546
+ device: torch.device = None, circular: bool = True, show_images: bool = False):
547
+ """Initialize an ImageFileStream instance for streaming image data.
548
+
549
+ Args:
550
+ image_dir (str): The directory containing image files.
551
+ list_of_image_files (str): Path to a file listing the image file names (one per line, or a CSV with the name as first field).
552
+ device (torch.device): The device to store the tensors on. Default is CPU.
553
+ circular (bool): Whether to loop the dataset or not. Default is True.
+ show_images (bool): Whether to display the buffered images and print a clickable list of their paths. Default is False.
554
+ """
555
+ self.image_dir = image_dir
556
+ self.device = device if device is not None else torch.device("cpu")
557
+ self.circular = circular
558
+
559
+ # Reading the image file
560
+ # (assume a file with one filename per line or a CSV format with lines such as: cat.jpg,cat,mammal,animal)
561
+ self.image_paths = []
562
+
563
+ # Calling the constructor
564
+ super().__init__(props=DataProps(name=ImageFileStream.__name__,
565
+ data_type="img",
566
+ pubsub=True))
567
+
568
+ with open(list_of_image_files, 'r') as f:
569
+ for line in f:
570
+ parts = line.strip().split(',') # Tolerates a CSV whose first field is the image file name
571
+ image_name = parts[0]
572
+ self.image_paths.append(os.path.join(image_dir, image_name))
573
+
574
+ # The buffer was filled before anything else, so shift the first/last cycles back accordingly
575
+ self.last_cycle = -1
576
+ self.first_cycle = self.last_cycle - len(self.image_paths) + 1
577
+
578
+ # Possibly print to screen the "clickable" list of images
579
+ if show_images:
580
+ show_images_grid(self.image_paths)
581
+ for i, image_path in enumerate(self.image_paths):
582
+ abs_path = os.path.abspath(image_path)
583
+ file_url = pathlib.Path(abs_path).as_uri()
584
+ basename = os.path.basename(abs_path) # 'photo.jpg'
585
+ parent = os.path.basename(os.path.dirname(abs_path)) # 'images'
586
+ label = os.path.join(parent, basename) if parent else basename
587
+ clickable_label = f"\033]8;;{file_url}\033\\[{label}]\033]8;;\033\\"
588
+ print(str(i) + " => " + clickable_label)
589
+
590
+ def __len__(self):
591
+ """Return the number of images in the dataset.
592
+
593
+ Returns:
594
+ int: Number of images in the dataset.
595
+ """
596
+ return len(self.image_paths)
597
+
598
+ def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
599
+ """Get the image and label for the specified cycle number.
600
+
601
+ Args:
602
+ idx (int): The cycle number to retrieve data for.
603
+
604
+ Returns:
605
+ tuple: The PIL image and its data tag for the specified index.
606
+ """
607
+ if self.circular:
608
+ idx %= self.__len__()
609
+ else:
610
+ if idx >= self.__len__() or idx < 0:
611
+ return None, -1
612
+
613
+ image = Image.open(self.image_paths[idx])
614
+ return image, self.clock.get_cycle() - self.first_cycle
615
+
616
+
617
+ class LabelStream(BufferedDataStream):
618
+ """
619
+ A buffered stream for single and multi-label annotations.
620
+ """
621
+
622
+ def __init__(self, label_dir: str, label_file_csv: str,
623
+ device: torch.device = None, circular: bool = True, single_class: bool = False,
624
+ line_header: bool = False):
625
+ """Initialize an LabelStream instance for streaming labels.
626
+
627
+ Args:
628
+ label_dir (str): The directory containing the label files.
629
+ label_file_csv (str): Path to the CSV file with labels for the images.
630
+ device (torch.device): The device to store the tensors on. Default is CPU.
631
+ circular (bool): Whether to loop the dataset or not. Default is True.
632
+ single_class (bool): Whether to only consider a single class for labeling. Default is False.
+ line_header (bool): Whether each line starts with a non-label field (e.g., the file name), which is skipped. Default is False.
633
+ """
634
+ self.label_dir = label_dir
635
+ self.device = device if device is not None else torch.device("cpu")
636
+ self.circular = circular
637
+
638
+ # Reading the label file
639
+ # (assume a file with a labeled element per line or a CSV format with lines such as: cat.jpg,cat,mammal,animal)
640
+ self.labels = []
641
+
642
+ class_names = {}
643
+ with open(label_file_csv, 'r') as f:
644
+ for line in f:
645
+ parts = line.strip().split(',')
646
+ label = parts if not line_header else parts[1:] # Keep the same convention as the second pass below
647
+ for lab in label:
648
+ class_names[lab] = True
649
+ class_name_to_index = {}
650
+ class_names = list(class_names.keys())
651
+
652
+ # Call the constructor
653
+ super().__init__(props=DataProps(name=LabelStream.__name__,
654
+ data_type="tensor",
655
+ data_desc="label stream",
656
+ tensor_shape=(1, len(class_names)),
657
+ tensor_dtype=str(torch.float),
658
+ tensor_labels=class_names,
659
+ tensor_labeling_rule="geq0.5" if not single_class else "max",
660
+ pubsub=True))
661
+
662
+ for idx, class_name in enumerate(class_names):
663
+ class_name_to_index[class_name] = idx
664
+
665
+ with open(label_file_csv, 'r') as f:
666
+ for line in f:
667
+ parts = line.strip().split(',')
668
+ label = parts if not line_header else parts[1:]
669
+ target_vector = torch.zeros((1, len(class_names)), dtype=torch.float32)
670
+ for lab in label:
671
+ idx = class_name_to_index[lab]
672
+ target_vector[0, idx] = 1.
673
+ self.labels.append(target_vector)
674
+
675
+ # The buffer was filled before anything else, so shift the first/last cycles back accordingly
676
+ self.last_cycle = -1
677
+ self.first_cycle = self.last_cycle - len(self.labels) + 1
678
+
679
+ def __len__(self):
680
+ """Return the number of labels in the dataset.
681
+
682
+ Returns:
683
+ int: Number of labels in the dataset.
684
+ """
685
+ return len(self.labels)
686
+
687
+ def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
688
+ """Get the image and label for the specified cycle number.
689
+
690
+ Args:
691
+ idx (int): The cycle number to retrieve data for.
692
+
693
+ Returns:
694
+ tuple: The label tensor and its data tag for the specified index.
695
+ """
696
+ if self.circular:
697
+ idx %= self.__len__()
698
+ else:
699
+ if idx >= self.__len__() or idx < 0:
700
+ return None, -1
701
+
702
+ label = self.labels[idx].unsqueeze(0).to(self.device) # Multi-label vector for the image
703
+ return self.props.adapt_tensor_to_tensor_labels(label), self.clock.get_cycle() - self.first_cycle
704
+
705
+
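A sketch of how the two file-based streams above could be paired (illustration only; the paths and the CSV layout are hypothetical, following the convention assumed in the comments, e.g. cat.jpg,cat,mammal,animal):

    from unaiverse.streams import ImageFileStream, LabelStream

    # Hypothetical layout: images/list.csv with lines such as "cat.jpg,cat,mammal,animal"
    images = ImageFileStream(image_dir="images", list_of_image_files="images/list.csv")
    labels = LabelStream(label_dir="images", label_file_csv="images/list.csv", line_header=True)

    image, image_tag = images[0]    # PIL image and its tag for the first buffered entry
    target, label_tag = labels[0]   # Multi-hot label tensor aligned with the same entry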
706
+ class TokensStream(BufferedDataStream):
707
+ """
708
+ A buffered data stream for tokenized text, read from a CSV file of tokens (only the token field is used).
709
+ """
710
+
711
+ def __init__(self, tokens_file_csv: str, circular: bool = True, max_tokens: int = -1):
712
+ """Initialize a Tokens instance for streaming tokenized data and associated labels.
713
+
714
+ Args:
715
+ tokens_file_csv (str): Path to the CSV file containing token data.
716
+ circular (bool): Whether to loop the dataset or not. Default is True.
717
+ max_tokens (int): Maximum number of tokens to read from the file. Default is -1 (no limit).
718
+ """
719
+ self.circular = circular
720
+
721
+ # Reading the data file (assume a token per line or a CSV format with lines such as:
722
+ # token,category_label1,category_label2,etc.)
723
+ tokens = []
724
+ with open(tokens_file_csv, 'r') as f:
725
+ for line in f:
726
+ parts = next(csv.reader([line], quotechar='"', delimiter=','))
727
+ tokens.append(parts[0])
728
+ if 0 < max_tokens <= len(tokens):
729
+ break
730
+
731
+ # Vocabulary
732
+ idx = 0
733
+ word2id = {}
734
+ sorted_stream_of_tokens = sorted(tokens)
735
+ for token in sorted_stream_of_tokens:
736
+ if token not in word2id:
737
+ word2id[token] = idx
738
+ idx += 1
739
+ id2word = [""] * len(word2id)
740
+ for _word, _id in word2id.items():
741
+ id2word[_id] = _word
742
+
743
+ # Calling the constructor
744
+ super().__init__(props=DataProps(name=TokensStream.__name__,
745
+ data_type="text",
746
+ data_desc="stream of words",
747
+ stream_to_proc_transforms=word2id,
748
+ proc_to_stream_transforms=id2word,
749
+ pubsub=True))
750
+
751
+ # Tokenized text
752
+ for i, token in enumerate(tokens):
753
+ data = token
754
+ self.data_buffer.append((data, -1))
755
+
756
+ # The buffer was filled before anything else, so restart it from the current cycle
757
+ self.restart()
758
+
759
+ def __getitem__(self, idx: int) -> torch.Tensor | None:
760
+ """Get the image and label for the specified cycle number.
761
+
762
+ Args:
763
+ idx (int): The index to retrieve data for.
764
+
765
+ Returns:
766
+ tuple: The token and its data tag for the specified index.
767
+ """
768
+ if self.circular:
769
+ idx %= self.__len__()
770
+ return super().__getitem__(idx)
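For completeness, a short illustrative sketch (not part of the package) of the token stream defined at the end of the file, assuming a hypothetical tokens.csv with one token per line:

    from unaiverse.streams import TokensStream

    tokens = TokensStream("tokens.csv", circular=True, max_tokens=1000)

    first_token, tag = tokens[0]      # Tokens are stored as plain strings in the buffer
    print(len(tokens), first_token)   # Number of buffered tokens and the first one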