unaiverse 0.1.11__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of unaiverse might be problematic. Click here for more details.

Files changed (50) hide show
  1. unaiverse/__init__.py +19 -0
  2. unaiverse/agent.py +2090 -0
  3. unaiverse/agent_basics.py +1948 -0
  4. unaiverse/clock.py +221 -0
  5. unaiverse/dataprops.py +1236 -0
  6. unaiverse/hsm.py +1892 -0
  7. unaiverse/modules/__init__.py +18 -0
  8. unaiverse/modules/cnu/__init__.py +17 -0
  9. unaiverse/modules/cnu/cnus.py +536 -0
  10. unaiverse/modules/cnu/layers.py +261 -0
  11. unaiverse/modules/cnu/psi.py +60 -0
  12. unaiverse/modules/hl/__init__.py +15 -0
  13. unaiverse/modules/hl/hl_utils.py +411 -0
  14. unaiverse/modules/networks.py +1509 -0
  15. unaiverse/modules/utils.py +710 -0
  16. unaiverse/networking/__init__.py +16 -0
  17. unaiverse/networking/node/__init__.py +18 -0
  18. unaiverse/networking/node/connpool.py +1308 -0
  19. unaiverse/networking/node/node.py +2499 -0
  20. unaiverse/networking/node/profile.py +446 -0
  21. unaiverse/networking/node/tokens.py +79 -0
  22. unaiverse/networking/p2p/__init__.py +187 -0
  23. unaiverse/networking/p2p/go.mod +127 -0
  24. unaiverse/networking/p2p/go.sum +548 -0
  25. unaiverse/networking/p2p/golibp2p.py +18 -0
  26. unaiverse/networking/p2p/golibp2p.pyi +135 -0
  27. unaiverse/networking/p2p/lib.go +2662 -0
  28. unaiverse/networking/p2p/lib.go.sha256 +1 -0
  29. unaiverse/networking/p2p/lib_types.py +312 -0
  30. unaiverse/networking/p2p/message_pb2.py +50 -0
  31. unaiverse/networking/p2p/messages.py +362 -0
  32. unaiverse/networking/p2p/mylogger.py +77 -0
  33. unaiverse/networking/p2p/p2p.py +871 -0
  34. unaiverse/networking/p2p/proto-go/message.pb.go +846 -0
  35. unaiverse/networking/p2p/unailib.cpython-311-darwin.so +0 -0
  36. unaiverse/stats.py +1481 -0
  37. unaiverse/streamlib/__init__.py +15 -0
  38. unaiverse/streamlib/streamlib.py +210 -0
  39. unaiverse/streams.py +776 -0
  40. unaiverse/utils/__init__.py +16 -0
  41. unaiverse/utils/lone_wolf.json +24 -0
  42. unaiverse/utils/misc.py +310 -0
  43. unaiverse/utils/sandbox.py +293 -0
  44. unaiverse/utils/server.py +435 -0
  45. unaiverse/world.py +335 -0
  46. unaiverse-0.1.11.dist-info/METADATA +367 -0
  47. unaiverse-0.1.11.dist-info/RECORD +50 -0
  48. unaiverse-0.1.11.dist-info/WHEEL +6 -0
  49. unaiverse-0.1.11.dist-info/licenses/LICENSE +43 -0
  50. unaiverse-0.1.11.dist-info/top_level.txt +1 -0
unaiverse/streams.py ADDED
@@ -0,0 +1,776 @@
1
+ """
2
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
3
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
4
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
5
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
6
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
7
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
8
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
9
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
10
+ A Collectionless AI Project (https://collectionless.ai)
11
+ Registration/Login: https://unaiverse.io
12
+ Code Repositories: https://github.com/collectionlessai/
13
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
14
+ """
15
+ import os
16
+ import csv
17
+ import math
18
+ import torch
19
+ import random
20
+ import pathlib
21
+ from PIL import Image
22
+ from .clock import Clock
23
+ from .dataprops import DataProps
24
+ from datetime import datetime, timezone
25
+ from unaiverse.utils.misc import show_images_grid
26
+
27
+
28
class DataStream:
    """
    Base class of every stream: it holds the most recent sample only, together with its metadata
    (timestamp, tag, UUIDs) and the properties (DataProps) that describe the stream.
    """

    # NOTE(review): the default `clock` below is built once, when the class body is executed, so every stream
    # created without an explicit clock shares that same Clock instance - confirm this is intended.
    def __init__(self,
                 props: DataProps,
                 clock: Clock = Clock(current_time=datetime.now(timezone.utc).timestamp())) -> None:
        """Build a new stream with no data in it.

        Args:
            props (DataProps): Properties of the stream.
            clock (Clock): Clock object for time management (usually provided from outside).
        """
        self.props = props
        self.clock = clock

        # A stream can be turned off (see enable/disable)
        self.enabled = True

        # Latest sample and the time/tag attached to it
        self.data = None
        self.data_tag = -1
        self.data_timestamp = self.clock.get_time()

        # For each requester, the timestamp of the sample it got last (used by "get" to avoid repetitions)
        self.data_timestamp_when_got_by = {}

        # UUID-related information
        self.data_uuid = None
        self.data_uuid_expected = None
        self.data_uuid_clearable = False

    @staticmethod
    def create(stream: 'DataStream', name: str | None = None, group: str = 'none',
               public: bool = True, pubsub: bool = True):
        """Configure name, group and visibility of a given stream, also fixing a placeholder tensor label.

        Args:
            stream (DataStream): The stream object to modify (in-place).
            name (str | None): The name of the stream; if None, it becomes "<current name>@<group>".
            group (str): The name of the group to which the stream belongs ('none' means no group).
            public (bool): If the stream is going to be served in the public net or the private one.
            pubsub (bool): If the stream is going to be served by broadcasting (PubSub) or not.

        Returns:
            DataStream: The very same stream object that was passed, after the update.
        """
        has_group = group != 'none'
        assert name is not None or has_group, "Must provide either name or group name."

        props = stream.props
        props.set_group(group)

        # If name is None, then a group was provided (guaranteed by the assertion above)
        props.set_name(name if name is not None else props.get_name() + "@" + group)
        props.set_public(public)
        props.set_pubsub(pubsub)

        # A single label named 'unk' is replaced by the group name (or by the stream name, when no group is given)
        if props.is_flat_tensor_with_labels():
            labels = props.tensor_labels
            if len(labels) == 1 and labels[0] == 'unk':
                labels[0] = group if has_group else name
        return stream

    def enable(self):
        """Turn the stream on: "get" goes back to returning data.

        Returns:
            None
        """
        self.enabled = True

    def disable(self):
        """Turn the stream off: "get" returns None until the stream is enabled again.

        Returns:
            None
        """
        self.enabled = False

    def get_props(self) -> DataProps:
        """Return the DataProps object attached to this stream."""
        return self.props

    def net_hash(self, prefix: str):
        """Return the network hash of this stream, as computed by its DataProps for the given prefix."""
        return self.get_props().net_hash(prefix)

    @staticmethod
    def peer_id_from_net_hash(net_hash):
        """Forward to DataProps.peer_id_from_net_hash."""
        return DataProps.peer_id_from_net_hash(net_hash)

    @staticmethod
    def name_or_group_from_net_hash(net_hash):
        """Forward to DataProps.name_or_group_from_net_hash."""
        return DataProps.name_or_group_from_net_hash(net_hash)

    @staticmethod
    def is_pubsub_from_net_hash(net_hash):
        """Forward to DataProps.is_pubsub_from_net_hash."""
        return DataProps.is_pubsub_from_net_hash(net_hash)

    def is_pubsub(self):
        """Tell if the stream is marked as PubSub in its properties."""
        return self.get_props().is_pubsub()

    def is_public(self):
        """Tell if the stream is marked as public in its properties."""
        return self.get_props().is_public()

    def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
        """Indexed access to the samples. The base class has no buffer, so this always fails: it is implemented
        by buffered data streams or by streams that generate data on-the-fly.

        Args:
            idx (int): Index of the data to retrieve.

        Raises:
            ValueError: Always, since this method should be overridden.
        """
        raise ValueError("Not implemented (expected to be only present in data streams that are buffered or "
                         "that can generated on the fly)")

    def __str__(self) -> str:
        """Human-readable summary (enabled flag and properties).

        Returns:
            str: String representation of the object.
        """
        return f"[DataStream] enabled: {self.enabled}\n\tprops: {self.props}"

    def __len__(self):
        """Length of the stream.

        NOTE(review): the returned value is a float, so the built-in len() raises TypeError on a plain
        DataStream; only direct calls to "__len__()" work.

        Returns:
            float: Infinity (as this is a lifelong stream).
        """
        return math.inf

    def set(self, data: torch.Tensor | Image.Image | str, data_tag: int = -1) -> bool:
        """Store a new sample, that becomes the one returned by "get()".

        The sample is refused when the stream has a positive "delta" (from its properties) and less than
        "delta" time passed since the timestamp of the currently stored sample.

        Args:
            data (torch.Tensor | Image.Image | str): Data sample to set (None is allowed, and stored as it is).
            data_tag (int): Custom data time tag >= 0 (Default: -1, meaning that the clock cycle is used).

        Returns:
            bool: True if data was accepted based on time constraints, else False.
        """
        delta = self.props.delta
        accepted = delta <= 0. or delta <= (self.clock.get_time() - self.data_timestamp)
        if not accepted:
            return False

        self.data = None if data is None else self.props.adapt_tensor_to_tensor_labels(data)
        self.data_timestamp = self.clock.get_cycle_time()
        self.data_tag = self.clock.get_cycle() if data_tag < 0 else data_tag
        return True

    def get(self, requested_by: str | None = None) -> torch.Tensor | None:
        """Return the latest sample (i.e., the last one that was "set").

        Args:
            requested_by (str | None): Identifier of who is asking; when given, the same sample (same timestamp)
                is returned to that requester only once.

        Returns:
            torch.Tensor | None: The stored sample; None if the stream is disabled or if the requester already
                got the sample with the current timestamp.
        """
        if not self.enabled:
            return None
        if requested_by is None:
            return self.data

        seen = self.data_timestamp_when_got_by
        if requested_by in seen and seen[requested_by] == self.data_timestamp:
            return None  # Already delivered to this requester
        seen[requested_by] = self.data_timestamp
        return self.data

    def get_(self, requested_by: str | None = None) -> torch.Tensor:
        """Wrapper of 'get'."""
        return self.get(requested_by)

    def get_timestamp(self) -> float:
        """Return the timestamp associated to the stored sample."""
        return self.data_timestamp

    def get_tag(self) -> int:
        """Return the tag associated to the stored sample."""
        return self.data_tag

    def get_uuid(self, expected: bool = False) -> str | None:
        """Return the UUID of the stream data (or the expected one, if "expected" is True)."""
        return self.data_uuid_expected if expected else self.data_uuid

    def set_uuid(self, ref_uuid: str | None, expected: bool = False):
        """Set the UUID (or the expected UUID, if "expected" is True) and unmark the "clearable" flag."""
        if expected:
            self.data_uuid_expected = ref_uuid
        else:
            self.data_uuid = ref_uuid

        # When we "set" a UUID, even if None, we unmark the "clear UUID" flag, to be sure that nobody is going to
        # accidentally clear this information
        self.data_uuid_clearable = False

    def mark_uuid_as_clearable(self):
        """Flag the UUIDs as clearable by the next call to "clear_uuid_if_marked_as_clearable"."""
        self.data_uuid_clearable = True

    def clear_uuid_if_marked_as_clearable(self):
        """Reset both UUIDs (and the flag), but only if they were previously marked as clearable."""
        if not self.data_uuid_clearable:
            return
        self.data_uuid = None
        self.data_uuid_expected = None
        self.data_uuid_clearable = False

    def set_props(self, data_stream: 'DataStream'):
        """Replace the properties of this stream with the ones of another DataStream (same object, not a copy).

        Args:
            data_stream (DataStream): the source DataStream from which DataProps is taken.

        Returns:
            None
        """
        self.props = data_stream.props
231
+
232
+
233
class BufferedDataStream(DataStream):
    """
    Data stream with buffer support to store historical data.

    Samples are kept in "data_buffer" as (data, tag) pairs; the k-th entry is associated to clock cycle
    "first_cycle + k". Cycles in which nothing was set are represented by (None, -1) entries.
    """

    # NOTE(review): the default `clock` is evaluated once, at class-definition time, hence it is shared by all
    # the streams that do not receive an explicit clock - confirm this is intended.
    def __init__(self, props: DataProps, clock: Clock = Clock(current_time=datetime.now(timezone.utc).timestamp()),
                 is_static: bool = False):
        """Initialize a BufferedDataStream.

        Args:
            props (DataProps): Properties of the stream.
            clock (Clock): Clock object for time management (usually provided from outside).
            is_static (bool): If True, the buffer stores only one item that is reused.
        """
        super().__init__(props=props, clock=clock)

        # We store the data samples, and we cache their text representation (for speed)
        self.data_buffer = []
        self.text_buffer = []

        self.is_static = is_static  # A static stream stores only one sample and always yields it

        # We need to remember the first cycle in which we started buffering and the last one we buffered
        self.first_cycle = -1
        self.last_cycle = -1
        self.last_get_cycle = -2  # Keep it to -2 (since -1 is the starting value for cycles)
        self.buffered_data_index = -1

        # Requesters for which a restart was planned (see "plan_restart_before_next_get")
        self.restart_before_next_get = set()

    def get(self, requested_by: str | None = None) -> torch.Tensor | None:
        """Get the current data sample, moving to the next buffered element when a new clock cycle started.

        Args:
            requested_by (str | None): Identifier of who is asking (see DataStream.get); if a restart was planned
                for this requester, the stream is restarted before serving the data.

        Returns:
            torch.Tensor | None: Current buffered sample (None under the same conditions of DataStream.get, or
                when the buffer has no sample for the current position).
        """
        if requested_by is not None and requested_by in self.restart_before_next_get:
            self.restart_before_next_get.remove(requested_by)
            self.restart()

        cycle = self.clock.get_cycle() - self.first_cycle  # This ensures that first get clock = first sample

        # These lines might make you think "hey, call super().set(self[cycle]), it is the same!"
        # however, it is not like that, since "set" will also call "adapt_to_labels", that is not needed for
        # buffered streams
        if (self.last_get_cycle != cycle and
                (self.props.delta <= 0. or self.props.delta <= (self.clock.get_time() - self.data_timestamp))):
            self.last_get_cycle = cycle
            self.buffered_data_index += 1
            new_data, new_tag = self[self.buffered_data_index]
            self.data = new_data
            self.data_timestamp = self.clock.get_cycle_time()
            self.data_tag = new_tag
        return super().get(requested_by)

    def set(self, data: torch.Tensor, data_tag: int = -1):
        """Store a new data sample into the buffer.

        If one or more clock cycles passed without any sample being set, the missing positions are filled with
        (None, -1) entries *before* appending the new sample, so that buffer positions stay aligned to cycles.

        Args:
            data (torch.Tensor): Data to store.
            data_tag (int): Custom data time tag >= 0 (Default: -1, meaning no tags).

        Returns:
            bool: True if the data was accepted (see DataStream.set), False otherwise.
        """
        ret = super().set(data, data_tag)
        if not ret:
            return ret

        cycle = self.clock.get_cycle()
        appending = not self.is_static or len(self.data_buffer) == 0

        if self.first_cycle < 0:

            # Very first sample: buffering starts now (last_cycle is moved to "cycle" at the end of this method)
            self.first_cycle = cycle
            self.last_cycle = cycle - 1
        elif cycle > self.last_cycle + 1:

            # Filling gaps with "None" (pointless for static streams, that only keep a single sample)
            if appending:
                self.data_buffer.extend([(None, -1)] * (cycle - self.last_cycle - 1))
            self.last_cycle = cycle - 1

        if appending:
            self.data_buffer.append((self.props.adapt_tensor_to_tensor_labels(data), self.get_tag()))
            if self.props.is_flat_tensor_with_labels():
                self.text_buffer.append(self.props.to_text(data))
            elif self.props.is_text():
                self.text_buffer.append(data)

        self.last_cycle += 1
        return ret

    def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
        """Retrieve a sample from the buffer given its position.

        Args:
            idx (int): Index (>=0) of the sample (ignored by static streams, that always yield their only sample).

        Returns:
            tuple: (sample, tag); (None, -1) if the index is out of range or the buffer is empty. When the stored
                tag is negative, the tag is the current clock cycle minus the first cycle of the stream.
        """
        if self.is_static:
            if len(self.data_buffer) == 0:
                return None, -1  # Nothing was buffered yet
            data, data_tag = self.data_buffer[0]
        else:
            if idx >= self.__len__() or idx < 0:
                return None, -1
            data, data_tag = self.data_buffer[idx]
        return data, data_tag if data_tag >= 0 else self.clock.get_cycle() - self.first_cycle

    def __len__(self):
        """Get number of samples in the buffer.

        Returns:
            int: Number of buffered samples.
        """
        return len(self.data_buffer)

    def set_first_cycle(self, cycle):
        """Manually set the first cycle for the buffer (the last cycle is moved accordingly).

        Args:
            cycle (int): Global cycle to start from.
        """
        self.first_cycle = cycle

        # NOTE(review): "set" keeps last_cycle == first_cycle + len - 1, while here it is first_cycle + len:
        # left untouched, confirm which convention is expected by the code that reads last_cycle
        self.last_cycle = cycle + len(self)

    def get_first_cycle(self):
        """Get the first cycle of the stream.

        Returns:
            int: First cycle value.
        """
        return self.first_cycle

    def restart(self):
        """Restart the buffer using the current clock cycle: the next "get" yields the first buffered sample
        again, and the information about who already got what is dropped.
        """
        self.set_first_cycle(max(self.clock.get_cycle(), 0))
        self.buffered_data_index = -1
        self.data_timestamp_when_got_by = {}

    def plan_restart_before_next_get(self, requested_by: str):
        """Ask for a restart of the stream right before the next "get" issued by the given requester."""
        self.restart_before_next_get.add(requested_by)

    def clear_buffer(self):
        """Clear the data buffer, resetting all the buffering-related state to its initial value.
        """
        self.data_buffer = []
        self.text_buffer = []

        self.first_cycle = -1
        self.last_cycle = -1
        self.last_get_cycle = -2  # Keep it to -2 (since -1 is the starting value for cycles)
        self.buffered_data_index = -1

        self.restart_before_next_get = set()

    def shuffle_buffer(self, seed: int = -1):
        """Randomly permute the buffered samples; tags are not moved (the k-th entry keeps the k-th tag).

        Args:
            seed (int): If >= 0, the shuffle is seeded with this value and the state of the global random
                generator is restored afterward; otherwise the global generator is used as it is.
        """
        old_buffer = self.data_buffer
        indices = list(range(len(old_buffer)))

        state = random.getstate()
        if seed >= 0:
            random.seed(seed)
        random.shuffle(indices)
        if seed >= 0:
            random.setstate(state)

        self.data_buffer = [(old_buffer[i][0], old_buffer[k][1]) for k, i in enumerate(indices)]

        # The cached text is permuted in the same way, but only when it is aligned with the data buffer
        if self.text_buffer is not None and len(self.text_buffer) == len(self.data_buffer):
            old_text_buffer = self.text_buffer
            self.text_buffer = [old_text_buffer[i] for i in indices]

    def to_text_snippet(self, length: int | None = None):
        """Convert buffered text samples to a single long string.

        Args:
            length (int | None): Optional length (number of buffered text elements) of the resulting snippet: the
                first and the last "length // 2" elements are kept (at least one), separated by " ... ".

        Returns:
            str | None: Human-readable text sequence (None if there is no buffered text).
        """
        if self.text_buffer is None or len(self.text_buffer) == 0:
            return None
        if length is None:
            return " ".join(self.text_buffer)

        le = max(length // 2, 1)
        text = " ".join(self.text_buffer[0:min(le, len(self.text_buffer))])
        if len(self.text_buffer) > le:

            # The tail never overlaps with the head that was already added
            text += " ... " + (" ".join(self.text_buffer[max(le, len(self.text_buffer) - le):]))
        return text

    def get_since_timestamp(self, since_what_timestamp: float, stride: int = 1) -> (
            tuple[list[int] | None, list[torch.Tensor | None] | None, int, DataProps]):
        """Retrieve all samples starting from a given timestamp.

        Args:
            since_what_timestamp (float): Timestamp in seconds.
            stride (int): Sampling stride.

        Returns:
            Tuple containing list of cycles, data, current cycle, and data properties.
        """
        since_what_cycle = self.clock.time2cycle(since_what_timestamp)
        return self.get_since_cycle(since_what_cycle, stride)

    def get_since_cycle(self, since_what_cycle: int, stride: int = 1) -> (
            tuple[list[int] | None, list[torch.Tensor | None] | None, int, DataProps]):
        """Retrieve all samples starting from a given clock cycle (up to the current one).

        Args:
            since_what_cycle (int): Cycle number (negative values are clamped to 0).
            stride (int): Stride to skip cycles (>= 1).

        Returns:
            Tuple with cycles, data, current cycle, and properties. Cycles with no data are skipped; if the
            clock did not start yet (negative cycle), the tuple is (None, None, -1, props).
        """
        assert stride >= 1 and isinstance(stride, int), f"Invalid stride: {stride}"

        # Notice: this whole routine never calls ".get()", on purpose! it must be as it is
        global_cycle = self.clock.get_cycle()
        if global_cycle < 0:
            return None, None, -1, self.props

        # First check: ensure we do not go beyond the first clock and counting the resulting number of steps
        since_what_cycle = max(since_what_cycle, 0)
        num_steps = global_cycle - since_what_cycle + 1

        # Second check: now we compute the index we should pass to get item
        since_what_idx_in_getitem = since_what_cycle - self.first_cycle

        ret_cycles = []
        ret_data = []

        for k in range(0, num_steps, stride):
            _idx = since_what_idx_in_getitem + k
            _data, _ = self[_idx]

            if _data is not None:
                ret_cycles.append(since_what_cycle + k)
                ret_data.append(_data)

        return ret_cycles, ret_data, global_cycle, self.props
482
+
483
+
484
class System(DataStream):
    """
    Non-PubSub text stream named after the class, whose content is set to "ping" at construction time.
    """

    def __init__(self):
        """Create the stream properties, build the stream, and store the initial "ping" sample."""
        system_props = DataProps(name=System.__name__, data_type="text", data_desc="System stream",
                                 pubsub=False)
        super().__init__(props=system_props)
        self.set("ping")

    def get(self, requested_by: str | None = None):
        """Return the stored sample, ignoring "requested_by" (so the sample is never hidden from a requester
        that already got it)."""
        return super().get()
492
+
493
+
494
class Dataset(BufferedDataStream):
    """
    A buffered dataset that streams data from a PyTorch dataset and simulates data-streams for input/output.
    """

    def __init__(self, tensor_dataset: torch.utils.data.Dataset, shape: tuple, index: int = 0, batch_size: int = 1):
        """Initialize a Dataset instance, which wraps around a PyTorch Dataset.

        The whole dataset is read at construction time and buffered as a sequence of batches (stacked tensors);
        the last batch is smaller when the dataset size is not a multiple of the batch size.

        Args:
            tensor_dataset (torch.utils.data.Dataset): The PyTorch Dataset to wrap (it must support len()).
            shape (tuple): The shape of each sample from the data stream.
            index (int): The index of the element returned by __getitem__ to pick up.
            batch_size (int): Number of dataset elements stacked into each buffered sample (>= 1).

        Raises:
            ValueError: If the picked element is neither a tensor nor an int/float, or if batch_size < 1.
        """
        if batch_size < 1:
            raise ValueError(f"Invalid batch size: {batch_size}")

        # The type of the stream is guessed from the first element of the dataset
        sample = tensor_dataset[0][index]
        if isinstance(sample, torch.Tensor):
            dtype = sample.dtype
        elif isinstance(sample, int):
            dtype = torch.long
        elif isinstance(sample, float):
            dtype = torch.float32
        else:
            raise ValueError("Expected tensor data or a scalar")

        super().__init__(props=DataProps(name=Dataset.__name__,
                                         data_type="tensor",
                                         data_desc="dataset",
                                         tensor_shape=shape,
                                         tensor_dtype=dtype,
                                         pubsub=True))

        n = len(tensor_dataset)

        # Each batch starts at a multiple of the (full) batch size, also the last, possibly shorter, one
        for start in range(0, n, batch_size):
            batch = []
            for k in range(start, min(start + batch_size, n)):
                sample = tensor_dataset[k][index]
                if isinstance(sample, (int, float)):
                    sample = torch.tensor(sample, dtype=dtype)
                batch.append(sample)

            self.data_buffer.append((torch.stack(batch), -1))

        # It was buffered previously than every other thing
        self.restart()
544
+
545
+
546
class ImageFileStream(BufferedDataStream):
    """
    A buffered stream of images, that are read from disk when they are requested.
    """

    def __init__(self, image_dir: str, list_of_image_files: str,
                 device: torch.device = None, circular: bool = True, show_images: bool = False):
        """Initialize an ImageFileStream instance for streaming image data.

        Args:
            image_dir (str): The directory containing image files.
            list_of_image_files (str): Path to the file with list of file names of the images.
            device (torch.device): Device stored in the "device" attribute. Default is CPU.
            circular (bool): Whether to loop the dataset or not. Default is True.
            show_images (bool): Whether to show the images in a grid and print a clickable list of them.
        """
        self.image_dir = image_dir
        self.device = torch.device("cpu") if device is None else device
        self.circular = circular

        # Paths of the images, filled below
        self.image_paths = []

        # Calling the constructor
        stream_props = DataProps(name=ImageFileStream.__name__, data_type="img", pubsub=True)
        super().__init__(props=stream_props)

        # Reading the file with the image names: one filename per line, or a CSV whose first field is the
        # image file name (such as: cat.jpg,cat,mammal,animal)
        with open(list_of_image_files, 'r') as f:
            self.image_paths.extend(os.path.join(image_dir, line.strip().split(',')[0]) for line in f)

        # It was buffered previously than every other thing
        self.last_cycle = -1
        self.first_cycle = self.last_cycle - len(self.image_paths) + 1

        if not show_images:
            return

        # Showing the images and printing the "clickable" list of them (terminal hyperlinks)
        show_images_grid(self.image_paths)
        for i, image_path in enumerate(self.image_paths):
            abs_path = os.path.abspath(image_path)
            file_url = pathlib.Path(abs_path).as_uri()
            folder, basename = os.path.split(abs_path)  # e.g., '.../images' and 'photo.jpg'
            parent = os.path.basename(folder)  # e.g., 'images'
            label = os.path.join(parent, basename) if parent else basename
            clickable_label = f"\033]8;;{file_url}\033\\[{label}]\033]8;;\033\\"
            print(f"{i} => {clickable_label}")

    def __len__(self):
        """Return the number of images in the stream.

        Returns:
            int: Number of image paths that were listed.
        """
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
        """Open the image at the given position.

        Args:
            idx (int): Position of the image (wrapped around the number of images if the stream is circular).

        Returns:
            tuple: (image, tag), where image is what PIL's Image.open returns and tag is the current clock cycle
                minus the first cycle; (None, -1) for out-of-range positions of a non-circular stream.
        """
        n = self.__len__()
        if self.circular:
            idx %= n
        elif idx >= n or idx < 0:
            return None, -1

        return Image.open(self.image_paths[idx]), self.clock.get_cycle() - self.first_cycle
621
+
622
+
623
class LabelStream(BufferedDataStream):
    """
    A buffered stream for single and multi-label annotations.
    """

    def __init__(self, label_dir: str, label_file_csv: str,
                 device: torch.device = None, circular: bool = True, single_class: bool = False,
                 line_header: bool = False):
        """Initialize a LabelStream instance for streaming labels.

        Each line of the label file becomes a multi-hot vector over the classes found in the file (classes are
        ordered by first appearance). Empty lines yield an all-zero vector.

        Args:
            label_dir (str): Directory stored in the "label_dir" attribute (not used to read the label file).
            label_file_csv (str): Path to the CSV file with the labels, comma-separated, one element per line.
            device (torch.device): The device to which labels are moved when served. Default is CPU.
            circular (bool): Whether to loop the dataset or not. Default is True.
            single_class (bool): Whether to only consider a single class for labeling. Default is False.
            line_header (bool): Whether the first field of each line is a header (e.g., a file name such as in
                cat.jpg,cat,mammal,animal) that must be skipped. Default is False.
        """
        self.label_dir = label_dir
        self.device = device if device is not None else torch.device("cpu")
        self.circular = circular
        self.labels = []

        # Reading the label file once; the very same rule selects the labels both when collecting the class
        # names and when building the targets (otherwise some labels could be missing from the class list)
        labels_per_line = []
        with open(label_file_csv, 'r') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    labels_per_line.append([])  # Keeping the line, to preserve the alignment with other streams
                    continue
                parts = stripped.split(',')
                labels_per_line.append(parts[1:] if line_header else parts)

        # Unique class names, in order of first appearance
        class_names = list(dict.fromkeys(lab for line_labels in labels_per_line for lab in line_labels))

        # Call the constructor
        super().__init__(props=DataProps(name=LabelStream.__name__,
                                         data_type="tensor",
                                         data_desc="label stream",
                                         tensor_shape=(1, len(class_names)),
                                         tensor_dtype=str(torch.float),
                                         tensor_labels=class_names,
                                         tensor_labeling_rule="geq0.5" if not single_class else "max",
                                         pubsub=True))

        class_name_to_index = {class_name: idx for idx, class_name in enumerate(class_names)}

        # Multi-hot encoding of the labels of each line
        for line_labels in labels_per_line:
            target_vector = torch.zeros((1, len(class_names)), dtype=torch.float32)
            for lab in line_labels:
                target_vector[0, class_name_to_index[lab]] = 1.
            self.labels.append(target_vector)

        # It was buffered previously than every other thing
        self.last_cycle = -1
        self.first_cycle = self.last_cycle - len(self.labels) + 1

    def __len__(self):
        """Return the number of labels in the dataset.

        Returns:
            int: Number of label vectors (one per line of the label file).
        """
        return len(self.labels)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
        """Get the label vector at the given position.

        Args:
            idx (int): Position of the label vector (wrapped around the number of labels if circular).

        Returns:
            tuple: (label tensor, tag), where the tag is the current clock cycle minus the first cycle;
                (None, -1) if the stream is empty or, for non-circular streams, if idx is out of range.
        """
        n = self.__len__()
        if n == 0:
            return None, -1  # Nothing to serve (also avoids a modulo by zero in circular streams)
        if self.circular:
            idx %= n
        elif idx >= n or idx < 0:
            return None, -1

        label = self.labels[idx].unsqueeze(0).to(self.device)  # Multi-label vector, with an extra leading dim
        return self.props.adapt_tensor_to_tensor_labels(label), self.clock.get_cycle() - self.first_cycle
710
+
711
+
712
class TokensStream(BufferedDataStream):
    """
    A buffered stream of text tokens, read from a file with one token per line.
    """

    def __init__(self, tokens_file_csv: str, circular: bool = True, max_tokens: int = -1):
        """Initialize a TokensStream instance for streaming tokens.

        Args:
            tokens_file_csv (str): Path to the file containing token data: a token per line, or a CSV format with
                lines such as token,category_label1,category_label2,etc. (only the first field is used).
            circular (bool): Whether to loop the dataset or not. Default is True.
            max_tokens (int): Whether to cut the stream to a maximum number of tokens. Default is -1 (no cut).
        """
        self.circular = circular

        # Reading the data file, line by line
        tokens = []
        with open(tokens_file_csv, 'r') as f:
            for line in f:
                parts = next(csv.reader([line], quotechar='"', delimiter=','))
                if not parts:
                    continue  # Blank line: there is no first field to pick up
                tokens.append(parts[0])
                if 0 < max_tokens <= len(tokens):
                    break

        # Vocabulary: IDs follow the sorted order of the unique tokens
        id2word = sorted(set(tokens))
        word2id = {_word: _id for _id, _word in enumerate(id2word)}

        # Calling the constructor
        super().__init__(props=DataProps(name=TokensStream.__name__,
                                         data_type="text",
                                         data_desc="stream of words",
                                         stream_to_proc_transforms=word2id,
                                         proc_to_stream_transforms=id2word,
                                         pubsub=True))

        # Tokenized text (no custom tags)
        for token in tokens:
            self.data_buffer.append((token, -1))

        # It was buffered previously than every other thing
        self.restart()

    def __getitem__(self, idx: int) -> tuple[str | None, int]:
        """Get the token at the given position.

        Args:
            idx (int): The index to retrieve data for (wrapped around the buffer size if the stream is circular).

        Returns:
            tuple: (token, tag) as returned by BufferedDataStream.__getitem__; (None, -1) if there is no token at
                that position (e.g., empty stream, or out-of-range index in a non-circular stream).
        """
        n = self.__len__()
        if self.circular and n > 0:  # An empty buffer cannot be wrapped around (modulo by zero)
            idx %= n
        return super().__getitem__(idx)