unaiverse 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. unaiverse/__init__.py +19 -0
  2. unaiverse/agent.py +2226 -0
  3. unaiverse/agent_basics.py +2389 -0
  4. unaiverse/clock.py +234 -0
  5. unaiverse/dataprops.py +1282 -0
  6. unaiverse/hsm.py +2471 -0
  7. unaiverse/modules/__init__.py +18 -0
  8. unaiverse/modules/cnu/__init__.py +17 -0
  9. unaiverse/modules/cnu/cnus.py +536 -0
  10. unaiverse/modules/cnu/layers.py +261 -0
  11. unaiverse/modules/cnu/psi.py +60 -0
  12. unaiverse/modules/hl/__init__.py +15 -0
  13. unaiverse/modules/hl/hl_utils.py +411 -0
  14. unaiverse/modules/networks.py +1509 -0
  15. unaiverse/modules/utils.py +748 -0
  16. unaiverse/networking/__init__.py +16 -0
  17. unaiverse/networking/node/__init__.py +18 -0
  18. unaiverse/networking/node/connpool.py +1332 -0
  19. unaiverse/networking/node/node.py +2752 -0
  20. unaiverse/networking/node/profile.py +446 -0
  21. unaiverse/networking/node/tokens.py +79 -0
  22. unaiverse/networking/p2p/__init__.py +188 -0
  23. unaiverse/networking/p2p/go.mod +127 -0
  24. unaiverse/networking/p2p/go.sum +548 -0
  25. unaiverse/networking/p2p/golibp2p.py +18 -0
  26. unaiverse/networking/p2p/golibp2p.pyi +136 -0
  27. unaiverse/networking/p2p/lib.go +2765 -0
  28. unaiverse/networking/p2p/lib_types.py +311 -0
  29. unaiverse/networking/p2p/message_pb2.py +50 -0
  30. unaiverse/networking/p2p/messages.py +360 -0
  31. unaiverse/networking/p2p/mylogger.py +78 -0
  32. unaiverse/networking/p2p/p2p.py +900 -0
  33. unaiverse/networking/p2p/proto-go/message.pb.go +846 -0
  34. unaiverse/stats.py +1506 -0
  35. unaiverse/streamlib/__init__.py +15 -0
  36. unaiverse/streamlib/streamlib.py +210 -0
  37. unaiverse/streams.py +804 -0
  38. unaiverse/utils/__init__.py +16 -0
  39. unaiverse/utils/lone_wolf.json +28 -0
  40. unaiverse/utils/misc.py +441 -0
  41. unaiverse/utils/sandbox.py +292 -0
  42. unaiverse/world.py +384 -0
  43. unaiverse-0.1.12.dist-info/METADATA +366 -0
  44. unaiverse-0.1.12.dist-info/RECORD +47 -0
  45. unaiverse-0.1.12.dist-info/WHEEL +5 -0
  46. unaiverse-0.1.12.dist-info/licenses/LICENSE +177 -0
  47. unaiverse-0.1.12.dist-info/top_level.txt +1 -0
unaiverse/streams.py ADDED
@@ -0,0 +1,804 @@
1
+ """
2
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
3
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
4
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
5
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
6
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
7
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
8
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
9
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
10
+ A Collectionless AI Project (https://collectionless.ai)
11
+ Registration/Login: https://unaiverse.io
12
+ Code Repositories: https://github.com/collectionlessai/
13
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
14
+ """
15
+ import os
16
+ import csv
17
+ import math
18
+ import torch
19
+ import random
20
+ import pathlib
21
+ from PIL import Image
22
+ from .clock import Clock
23
+ from .dataprops import DataProps
24
+ from datetime import datetime, timezone
25
+ from unaiverse.utils.misc import show_images_grid
26
+
27
+
28
class DataStream:
    """
    Base class for handling a generic data stream.

    A stream keeps the most recent sample that was stored with "set()", together with the time at which it was
    stored, an integer tag and optional UUIDs. Buffered or generative streams are expected to subclass it.
    """

    def __init__(self,
                 props: DataProps,
                 clock: Clock | None = None) -> None:
        """Initialize a DataStream.

        Args:
            props (DataProps): Properties of the stream.
            clock (Clock | None): Clock object for time management (usually provided from outside). When None, a new
                Clock initialized with the current UTC timestamp is created for this stream.
        """

        # The fallback clock is built here, per instance: a call placed in the signature default is evaluated only
        # once (when the module is imported), and that single Clock would be shared by all the streams
        if clock is None:
            clock = Clock(current_time=datetime.now(timezone.utc).timestamp())

        self.props = props
        self.clock = clock
        self.data = None
        self.data_timestamp = self.clock.get_time()
        self.data_timestamp_when_got_by = {}  # Requester name -> timestamp of the last sample it received
        self.data_tag = -1
        self.data_uuid = None
        self.data_uuid_expected = None
        self.data_uuid_clearable = False
        self.enabled = True  # A stream can be turned off

    @staticmethod
    def create(stream: 'DataStream', name: str | None = None, group: str = 'none',
               public: bool = True, pubsub: bool = True):
        """Create and set the name for a given stream, also updates data stream labels.

        Args:
            stream (DataStream): The stream object to modify.
            name (str | None): The name of the stream. If None, the name becomes "<current name>@<group>".
            group (str): The name of the group to which the stream belongs.
            public (bool): If the stream is going to be served in the public net or the private one.
            pubsub (bool): If the stream is going to be served by broadcasting (PubSub) or not.

        Returns:
            DataStream: The same stream object that was passed, with updated properties.
        """
        assert name is not None or group != 'none', "Must provide either name or group name."
        stream.props.set_group(group)
        if name is not None:
            stream.props.set_name(name)
        else:
            stream.props.set_name(stream.props.get_name() + "@" + group)  # If name is None, then a group was provided
        stream.props.set_public(public)
        stream.props.set_pubsub(pubsub)

        # A single placeholder label ('unk') is replaced by the group name (or by the stream name, if no group)
        if (stream.props.is_flat_tensor_with_labels() and
                len(stream.props.tensor_labels) == 1 and stream.props.tensor_labels[0] == 'unk'):
            stream.props.tensor_labels[0] = group if group != 'none' else name
        return stream

    def enable(self):
        """Enable the stream, allowing new samples to be accepted by "set()".

        Returns:
            None
        """
        self.enabled = True

    def disable(self):
        """Disable the stream: "set()" will refuse new samples until the stream is enabled again.

        Returns:
            None
        """
        self.enabled = False

    def get_props(self) -> DataProps:
        """Return the properties object of this stream."""
        return self.props

    def net_hash(self, prefix: str):
        """Return the network hash computed by the stream properties for the given prefix."""
        return self.get_props().net_hash(prefix)

    @staticmethod
    def peer_id_from_net_hash(net_hash):
        """Delegate to DataProps.peer_id_from_net_hash."""
        return DataProps.peer_id_from_net_hash(net_hash)

    @staticmethod
    def name_or_group_from_net_hash(net_hash):
        """Delegate to DataProps.name_or_group_from_net_hash."""
        return DataProps.name_or_group_from_net_hash(net_hash)

    @staticmethod
    def is_pubsub_from_net_hash(net_hash):
        """Delegate to DataProps.is_pubsub_from_net_hash."""
        return DataProps.is_pubsub_from_net_hash(net_hash)

    def is_pubsub(self):
        """Tell whether the stream properties mark this stream as PubSub."""
        return self.get_props().is_pubsub()

    def is_public(self):
        """Tell whether the stream properties mark this stream as public."""
        return self.get_props().is_public()

    def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
        """Get item for a specific clock cycle. Not implemented for base class: it will be implemented in buffered data
        streams or stream that can generate data on-the-fly.

        Args:
            idx (int): Index of the data to retrieve.

        Raises:
            ValueError: Always, since this method should be overridden.
        """
        raise ValueError("Not implemented (expected to be only present in data streams that are buffered or "
                         "that can generated on the fly)")

    def __str__(self) -> str:
        """String representation of the object.

        Returns:
            str: String representation of the object.
        """
        return f"[DataStream] enabled: {self.enabled}\n\tprops: {self.props}"

    def __len__(self):
        """Get the length of the stream.

        Returns:
            float: Infinity (as this is a lifelong stream).
        """

        # NOTE(review): the built-in len() only accepts integer results, so this value can only be obtained by
        # calling stream.__len__() directly
        return math.inf

    def set(self, data: torch.Tensor | Image.Image | str, data_tag: int = -1, keep_existing_tag: bool = False) -> bool:
        """Set a new data sample into the stream, that will be provided when calling "get()".

        The sample is accepted only if the stream is enabled and either the stream has no minimum time distance
        between samples (props.delta <= 0) or at least props.delta has passed since the last stored sample.

        Args:
            data (torch.Tensor | Image.Image | str): Data sample to set (None clears the current sample).
            data_tag (int): Custom data time tag >= 0 (Default: -1, meaning no tags).
            keep_existing_tag: Keep the data tag that is already in the stream, if the provided data_tag arg is -1.

        Returns:
            bool: True if data was accepted based on time constraints, else False.
        """

        # NOTE(review): these prints look like debug traces left in the code; they are kept as they are
        print(f">>>>> [stream set, clock={self.clock.get_cycle()}] Trying to set sample to stream. Sample={data}")
        if (self.enabled and
                (self.props.delta <= 0. or self.props.delta <= (self.clock.get_time() - self.data_timestamp))):
            self.data = self.props.adapt_tensor_to_tensor_labels(data) if data is not None else None
            print(f">>>>> [stream set, clock={self.clock.get_cycle()}] Setting sample to stream. Sample={self.data}")
            self.data_timestamp = self.clock.get_cycle_time()

            # Tag priority: explicit tag, then the current clock cycle, unless a valid tag must be preserved
            if data_tag >= 0:
                self.data_tag = data_tag
            elif self.data_tag < 0 or not keep_existing_tag:
                self.data_tag = self.clock.get_cycle()
            return True
        else:
            print(f">>>>> [stream set, clock={self.clock.get_cycle()}] Cannot set. Sample={data}")
            return False

    def get(self, requested_by: str | None = None) -> torch.Tensor | None:
        """Get the most recent data sample from the stream (i.e., the last one that was "set").

        Args:
            requested_by (str | None): Name of who is asking for the sample. When provided, the same sample (i.e.,
                a sample with the same timestamp) is returned only once to the same requester.

        Returns:
            torch.Tensor | None: The current sample, or None if there is no sample or if this requester already got it.
        """
        if requested_by is not None:
            if (requested_by not in self.data_timestamp_when_got_by or
                    self.data_timestamp_when_got_by[requested_by] != self.data_timestamp):
                self.data_timestamp_when_got_by[requested_by] = self.data_timestamp
                return self.data
            else:
                return None
        else:
            return self.data

    def get_(self, requested_by: str | None = None) -> torch.Tensor:
        """Wrapper of 'get'."""
        return self.get(requested_by)

    def get_timestamp(self) -> float:
        """Return the timestamp associated to the current sample."""
        return self.data_timestamp

    def get_tag(self) -> int:
        """Return the tag associated to the current sample (-1 means no tag)."""
        return self.data_tag

    def set_tag(self, data_tag: int):
        """Overwrite the tag associated to the current sample."""
        self.data_tag = data_tag

    def get_uuid(self, expected: bool = False) -> str | None:
        """Return the UUID of the stream, or the expected one if `expected` is True."""
        return self.data_uuid if not expected else self.data_uuid_expected

    def set_uuid(self, ref_uuid: str | None, expected: bool = False):
        """Set the UUID of the stream, or the expected one if `expected` is True."""
        if not expected:
            self.data_uuid = ref_uuid
        else:
            self.data_uuid_expected = ref_uuid

    def mark_uuid_as_clearable(self):
        """Flag the UUIDs so that the next call to "clear_uuid_if_marked_as_clearable()" resets them."""
        self.data_uuid_clearable = True

    def clear_uuid_if_marked_as_clearable(self):
        """Reset both UUIDs (and the flag) if they were marked as clearable.

        Returns:
            bool: True if the UUIDs were cleared, False if they were not marked as clearable.
        """
        if self.data_uuid_clearable:
            self.data_uuid = None
            self.data_uuid_expected = None
            self.data_uuid_clearable = False
            return True
        return False

    def clear_uuid(self):
        """Unconditionally reset both UUIDs (the clearable flag is not touched).

        Returns:
            bool: Always True.
        """
        self.data_uuid = None
        self.data_uuid_expected = None
        return True

    def set_props(self, data_stream: 'DataStream'):
        """Set (edit) the data properties picking up the ones of another DataStream.

        The properties object is shared with the source stream (it is not copied).

        Args:
            data_stream (DataStream): the source DataStream from which DataProp is taken.

        Returns:
            None
        """
        self.props = data_stream.props
240
+
241
+
242
class BufferedDataStream(DataStream):
    """
    Data stream with buffer support to store historical data.

    Samples are stored as (data, tag) pairs in `data_buffer`, with a parallel `text_buffer` that caches their text
    representation.
    """

    def __init__(self, props: DataProps, clock: Clock | None = None,
                 is_static: bool = False, is_queue: bool = False):
        """Initialize a BufferedDataStream.

        Args:
            props (DataProps): Properties of the stream.
            clock (Clock | None): Clock object for time management. When None, a new Clock initialized with the
                current UTC timestamp is created for this stream.
            is_static (bool): If True, the buffer stores only one item that is reused.
            is_queue (bool): If True, the buffer acts as a data queue.
        """

        # The fallback clock is built here, per instance: a call placed in the signature default is evaluated only
        # once (when the module is imported), and that single Clock would be shared by all the streams
        if clock is None:
            clock = Clock(current_time=datetime.now(timezone.utc).timestamp())
        super().__init__(props=props, clock=clock)

        # We store the data samples, and we cache their text representation (for speed)
        self.data_buffer = []
        self.text_buffer = []

        self.is_static = is_static  # A static stream store only one sample and always yields it
        self.is_queue = is_queue

        # We need to remember the first cycle in which we started buffering and the last one we buffered
        self.first_cycle = -1
        self.last_cycle = -1
        self.last_get_cycle = -2  # Keep it to -2 (since -1 is the starting value for cycles)
        self.buffered_data_index = -1

        # Names of the requesters for which the stream must be restarted right before their next "get()"
        self.restart_before_next_get = set()

    def get(self, requested_by: str | None = None) -> torch.Tensor | None:
        """Get the current data sample based on cycle and buffer.

        At most once per clock cycle (and respecting props.delta), the next buffered sample is promoted to be the
        current sample of the stream; in queue mode the promoted sample is also removed from the buffer.

        Args:
            requested_by (str | None): Name of who is asking for the sample (see DataStream.get).

        Returns:
            torch.Tensor | None: Current buffered sample.
        """
        if requested_by is not None and requested_by in self.restart_before_next_get:
            self.restart_before_next_get.remove(requested_by)
            self.restart()

        cycle = self.clock.get_cycle() - self.first_cycle  # This ensures that first get clock = first sample

        # These lines might make you think "hey, call super().set(self[cycle]), it is the same!"
        # however, it is not like that, since "set" will also call "adapt_to_labels", that is not needed for
        # buffered streams
        if (self.last_get_cycle != cycle and
                (self.props.delta <= 0. or self.props.delta <= (self.clock.get_time() - self.data_timestamp))):
            self.last_get_cycle = cycle

            if not self.is_queue:
                self.buffered_data_index += 1
                new_data, new_tag = self[self.buffered_data_index]
            else:
                new_data, new_tag = self[0]
                if new_data is not None:
                    self.data_buffer.pop(0)

                    # The text cache can be empty when the data buffer was filled directly (not through "set()")
                    if self.text_buffer:
                        self.text_buffer.pop(0)
            self.data = new_data
            self.data_timestamp = self.clock.get_cycle_time()
            self.data_tag = new_tag
        return super().get(requested_by)

    def set(self, data: torch.Tensor, data_tag: int = -1, keep_existing_tag: bool = False):
        """Store a new data sample into the buffer.

        Args:
            data (torch.Tensor): Data to store.
            data_tag (int): Custom data time tag >= 0 (Default: -1, meaning no tags).
            keep_existing_tag: Keep the data tag that is already in the stream, if the provided data_tag arg is -1.

        Returns:
            bool: True if the data was buffered (or if a None sample was ignored by a queue).
        """
        if self.is_queue and data is None:
            return True  # Queues silently ignore empty samples

        ret = super().set(data, data_tag, keep_existing_tag)

        if ret:

            # A static stream only buffers its first sample
            if not self.is_static or len(self.data_buffer) == 0:
                self.data_buffer.append((self.props.adapt_tensor_to_tensor_labels(data), self.get_tag()))
                if self.props.is_flat_tensor_with_labels():
                    self.text_buffer.append(self.props.to_text(data))
                elif self.props.is_text():
                    self.text_buffer.append(data)
                else:
                    self.text_buffer.append("")

            # Boilerplate
            if self.first_cycle < 0:
                self.first_cycle = self.clock.get_cycle()
                self.last_cycle = self.first_cycle
            else:
                if not self.is_queue:
                    cycle = self.clock.get_cycle()

                    # NOTE(review): the previous code had a "filling gaps with None" loop here, written as
                    # "for cycle in range(cycle, self.last_cycle + 1)": in this branch that range is always empty,
                    # so no placeholder was ever added and only last_cycle jumped to the current cycle. The
                    # effective behavior is kept as it is, since "get()" moves forward by one buffer index at each
                    # accepted call and inserting placeholders would change what it returns - confirm the intended
                    # design before enabling gap filling
                    if cycle > self.last_cycle + 1:
                        self.last_cycle = cycle
                    else:
                        self.last_cycle += 1
                else:
                    self.last_cycle = self.clock.get_cycle()
        return ret

    def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
        """Retrieve a sample from the buffer at the given index.

        Args:
            idx (int): Index (>=0) of the sample (ignored by static streams, that always yield their only sample).

        Returns:
            tuple[torch.Tensor | None, int]: The sample and its tag, or (None, -1) if not available. When the
                buffered tag is negative, the number of cycles since the first cycle is returned as tag.
        """
        if not self.is_static:
            if idx >= self.__len__() or idx < 0:
                return None, -1
            data, data_tag = self.data_buffer[idx]
        else:
            if not self.data_buffer:
                return None, -1  # Static stream in which nothing was stored yet
            data, data_tag = self.data_buffer[0]
        return data, data_tag if data_tag >= 0 else self.clock.get_cycle() - self.first_cycle

    def __len__(self):
        """Get number of samples in the buffer.

        Returns:
            int: Number of buffered samples.
        """
        return len(self.data_buffer)

    def set_first_cycle(self, cycle):
        """Manually set the first cycle for the buffer (the last cycle is set to `cycle + len(self)`).

        Args:
            cycle (int): Global cycle to start from.
        """
        self.first_cycle = cycle
        self.last_cycle = cycle + len(self)

    def get_first_cycle(self):
        """Get the first cycle of the stream.

        Returns:
            int: First cycle value.
        """
        return self.first_cycle

    def restart(self):
        """Restart the buffer using the current clock cycle.

        The read index is rewound and the memory of what was already delivered to each requester is dropped.
        """
        self.set_first_cycle(max(self.clock.get_cycle(), 0))
        self.buffered_data_index = -1
        self.data_timestamp_when_got_by = {}

    def plan_restart_before_next_get(self, requested_by: str):
        """Ask for a restart of the stream right before the next "get()" issued by the given requester.

        Args:
            requested_by (str): Name of the requester.
        """
        self.restart_before_next_get.add(requested_by)

    def clear_buffer(self):
        """Clear the data buffer and reset all the buffer-related counters to their initial values.
        """
        self.data_buffer = []
        self.text_buffer = []

        self.first_cycle = -1
        self.last_cycle = -1
        self.last_get_cycle = -2  # Keep it to -2 (since -1 is the starting value for cycles)
        self.buffered_data_index = -1

        self.restart_before_next_get = set()

    def shuffle_buffer(self, seed: int = -1):
        """Randomly permute the buffered samples (tags keep their original positions).

        Args:
            seed (int): If >= 0, the shuffle is seeded with it and the state of the global random generator is
                restored afterward. If negative, the global random generator is used as it is.
        """
        old_buffer = self.data_buffer
        indices = list(range(len(old_buffer)))

        state = random.getstate()
        if seed >= 0:
            random.seed(seed)
        random.shuffle(indices)
        if seed >= 0:
            random.setstate(state)

        # The data of position i moves to position k, while the tag stored at position k stays there
        self.data_buffer = [(old_buffer[i][0], old_buffer[k][1]) for k, i in enumerate(indices)]

        # The text cache is permuted in the same way, but only if it is aligned with the data buffer
        if self.text_buffer is not None and len(self.text_buffer) == len(self.data_buffer):
            old_text_buffer = self.text_buffer
            self.text_buffer = [old_text_buffer[i] for i in indices]

    def to_text_snippet(self, length: int | None = None):
        """Convert buffered text samples to a single long string.

        Args:
            length (int | None): Optional length (number of buffered items) of the resulting text snippet: the first
                and the last `length // 2` items are kept, separated by " ... ".

        Returns:
            str | None: Human-readable text sequence (None if there is no buffered text).
        """
        if self.text_buffer is not None and len(self.text_buffer) > 0:
            if length is not None:
                le = max(length // 2, 1)
                text = " ".join(self.text_buffer[0:min(le, len(self.text_buffer))])
                text += (" ... " + (" ".join(self.text_buffer[max(le, len(self.text_buffer) - le):]))) \
                    if len(self.text_buffer) > le else ""
            else:
                text = " ".join(self.text_buffer)
        else:
            text = None
        return text

    def get_since_timestamp(self, since_what_timestamp: float, stride: int = 1) -> (
            tuple[list[int] | None, list[torch.Tensor | None] | None, int, DataProps]):
        """Retrieve all samples starting from a given timestamp.

        Args:
            since_what_timestamp (float): Timestamp in seconds.
            stride (int): Sampling stride.

        Returns:
            Tuple containing list of cycles, data, current cycle, and data properties.
        """
        since_what_cycle = self.clock.time2cycle(since_what_timestamp)
        return self.get_since_cycle(since_what_cycle, stride)

    def get_since_cycle(self, since_what_cycle: int, stride: int = 1) -> (
            tuple[list[int] | None, list[torch.Tensor | None] | None, int, DataProps]):
        """Retrieve all samples starting from a given clock cycle.

        Args:
            since_what_cycle (int): Cycle number (negative values are clamped to 0).
            stride (int): Stride to skip cycles.

        Returns:
            Tuple with cycles, data, current cycle, and properties. Cycles with no data are skipped. If the clock did
            not start yet (negative current cycle), the tuple is (None, None, -1, properties).
        """
        assert stride >= 1 and isinstance(stride, int), f"Invalid stride: {stride}"

        # Notice: this whole routine never calls ".get()", on purpose! it must be as it is
        global_cycle = self.clock.get_cycle()
        if global_cycle < 0:
            return None, None, -1, self.props

        # First check: ensure we do not go beyond the first clock and counting the resulting number of steps
        since_what_cycle = max(since_what_cycle, 0)
        num_steps = global_cycle - since_what_cycle + 1

        # Second check: now we compute the index we should pass to get item
        since_what_idx_in_getitem = since_what_cycle - self.first_cycle

        ret_cycles = []
        ret_data = []

        for k in range(0, num_steps, stride):
            _idx = since_what_idx_in_getitem + k
            _data, _ = self[_idx]

            if _data is not None:
                ret_cycles.append(since_what_cycle + k)
                ret_data.append(_data)

        return ret_cycles, ret_data, global_cycle, self.props
510
+
511
+
512
class System(DataStream):
    """
    Text stream named after this class, not served by PubSub, that is created holding the "ping" sample.
    """

    def __init__(self):
        """Build the properties of the stream and store the initial "ping" sample."""
        system_props = DataProps(name=System.__name__,
                                 data_type="text",
                                 data_desc="System stream",
                                 pubsub=False)
        super().__init__(props=system_props)
        self.set("ping")

    def get(self, requested_by: str | None = None):
        """Return the current sample of the stream.

        Args:
            requested_by (str | None): Accepted for interface compatibility; it is not forwarded to the base class,
                so the once-per-requester filtering of DataStream.get is not applied.

        Returns:
            The current sample.
        """
        return super().get(requested_by=None)
520
+
521
+
522
class Dataset(BufferedDataStream):
    """
    A buffered dataset that streams data from a PyTorch dataset and simulates data-streams for input/output.
    """

    def __init__(self, tensor_dataset: torch.utils.data.Dataset, shape: tuple, index: int = 0, batch_size: int = 1):
        """Initialize a Dataset instance, which wraps around a PyTorch Dataset.

        The whole dataset is read at construction time and buffered as a sequence of stacked batches.

        Args:
            tensor_dataset (torch.utils.data.Dataset): The PyTorch Dataset to wrap.
            shape (tuple): The shape of each sample from the data stream.
            index (int): The index of the element returned by __getitem__ to pick up.
            batch_size (int): Number of consecutive dataset elements stacked into each buffered sample (the last
                batch can be smaller). Default is 1.

        Raises:
            ValueError: If batch_size is smaller than 1, or if the picked element is neither a tensor nor a scalar.
        """
        if batch_size < 1:
            raise ValueError(f"Invalid batch size: {batch_size}")

        # The data type of the stream is guessed from the first element of the dataset
        sample = tensor_dataset[0][index]
        if isinstance(sample, torch.Tensor):
            dtype = sample.dtype
        elif isinstance(sample, int):
            dtype = torch.long
        elif isinstance(sample, float):
            dtype = torch.float32
        else:
            raise ValueError("Expected tensor data or a scalar")

        super().__init__(props=DataProps(name=Dataset.__name__,
                                         data_type="tensor",
                                         data_desc="dataset",
                                         tensor_shape=shape,
                                         tensor_dtype=dtype,
                                         pubsub=True))

        # Batches are sliced using absolute offsets, so that the (possibly smaller) last batch still starts right
        # after the previous one
        n = len(tensor_dataset)
        for start in range(0, n, batch_size):
            stop = min(start + batch_size, n)

            batch = []
            for k in range(start, stop):
                sample = tensor_dataset[k][index]
                if isinstance(sample, (int, float)):
                    sample = torch.tensor(sample, dtype=dtype)
                batch.append(sample)

            self.data_buffer.append((torch.stack(batch), -1))

        # It was buffered previously than every other thing
        self.restart()
572
+
573
+
574
class ImageFileStream(BufferedDataStream):
    """
    A buffered stream of images that are listed in a text file and opened from disk when they are requested.
    """

    def __init__(self, image_dir: str, list_of_image_files: str,
                 device: torch.device = None, circular: bool = True, show_images: bool = False):
        """Initialize an ImageFileStream instance for streaming image data.

        Args:
            image_dir (str): The directory containing image files.
            list_of_image_files (str): Path to the file with list of file names of the images (one per line; if the
                line is comma-separated, only the first field is used).
            device (torch.device): The device associated to the stream. Default is CPU.
            circular (bool): Whether to loop the dataset or not. Default is True.
            show_images (bool): Whether to show the grid of images and print a clickable list of them.
        """
        self.image_dir = image_dir
        self.device = torch.device("cpu") if device is None else device
        self.circular = circular
        self.image_paths = []

        # The base constructor must run before touching the cycle counters
        stream_props = DataProps(name=ImageFileStream.__name__, data_type="img", pubsub=True)
        super().__init__(props=stream_props)

        # One image per line: the file name is whatever comes before the first comma (if any)
        with open(list_of_image_files, 'r') as f:
            self.image_paths.extend(os.path.join(image_dir, row.strip().split(',')[0]) for row in f)

        # The images are considered as if they were buffered before everything else
        self.last_cycle = -1
        self.first_cycle = self.last_cycle - len(self.image_paths) + 1

        if not show_images:
            return

        # Showing the images and printing the "clickable" list of files
        show_images_grid(self.image_paths)
        for i, image_path in enumerate(self.image_paths):
            abs_path = os.path.abspath(image_path)
            file_url = pathlib.Path(abs_path).as_uri()
            folder_name, file_name = os.path.basename(os.path.dirname(abs_path)), os.path.basename(abs_path)
            if folder_name:
                label = os.path.join(folder_name, file_name)
            else:
                label = file_name
            clickable_label = f"\033]8;;{file_url}\033\\[{label}]\033]8;;\033\\"
            print(str(i) + " => " + clickable_label)

    def __len__(self):
        """Return the number of images in the dataset.

        Returns:
            int: Number of listed image files.
        """
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
        """Open the image at the given index.

        Args:
            idx (int): Index of the image (wrapped around the number of images, if the stream is circular).

        Returns:
            tuple: The opened PIL image and the number of cycles since the first cycle, or (None, -1) if the index
                is out of range in a non-circular stream.
        """
        num_images = self.__len__()
        if self.circular:
            idx %= num_images
        elif idx >= num_images or idx < 0:
            return None, -1

        return Image.open(self.image_paths[idx]), self.clock.get_cycle() - self.first_cycle
649
+
650
+
651
class LabelStream(BufferedDataStream):
    """
    A buffered stream for single and multi-label annotations.
    """

    def __init__(self, label_dir: str, label_file_csv: str,
                 device: torch.device = None, circular: bool = True, single_class: bool = False,
                 line_header: bool = False):
        """Initialize a LabelStream instance for streaming labels.

        Args:
            label_dir (str): The directory associated to the labels (it is only stored).
            label_file_csv (str): Path to the CSV file with the labels, one labeled element per line.
            device (torch.device): The device to move the label tensors to. Default is CPU.
            circular (bool): Whether to loop the dataset or not. Default is True.
            single_class (bool): Whether to only consider a single class for labeling. Default is False.
            line_header (bool): If True, the first comma-separated field of each line is not a label and it is
                skipped (e.g., cat.jpg,cat,mammal,animal). If False, every field of the line is a label.
        """
        self.label_dir = label_dir
        self.device = device if device is not None else torch.device("cpu")
        self.circular = circular
        self.labels = []

        # Reading the label file only once, using the same rule both to collect the class names and to build the
        # targets (otherwise the first field would not be a known class when line_header is False)
        labels_per_line = []
        with open(label_file_csv, 'r') as f:
            for line in f:
                parts = line.strip().split(',')
                labels_per_line.append(parts[1:] if line_header else parts)

        # Class names, in order of first appearance
        class_names = list(dict.fromkeys(lab for line_labels in labels_per_line for lab in line_labels))

        # Call the constructor
        super().__init__(props=DataProps(name=LabelStream.__name__,
                                         data_type="tensor",
                                         data_desc="label stream",
                                         tensor_shape=(1, len(class_names)),
                                         tensor_dtype=str(torch.float),
                                         tensor_labels=class_names,
                                         tensor_labeling_rule="geq0.5" if not single_class else "max",
                                         pubsub=True))

        class_name_to_index = {class_name: idx for idx, class_name in enumerate(class_names)}

        # One multi-hot row vector per line of the file
        for line_labels in labels_per_line:
            target_vector = torch.zeros((1, len(class_names)), dtype=torch.float32)
            for lab in line_labels:
                target_vector[0, class_name_to_index[lab]] = 1.
            self.labels.append(target_vector)

        # It was buffered previously than every other thing
        self.last_cycle = -1
        self.first_cycle = self.last_cycle - len(self.labels) + 1

    def __len__(self):
        """Return the number of labels in the dataset.

        Returns:
            int: Number of labels in the dataset.
        """
        return len(self.labels)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor | None, int]:
        """Get the label vector at the given index.

        Args:
            idx (int): Index of the label (wrapped around the number of labels, if the stream is circular).

        Returns:
            tuple: The label tensor (moved to the stream device and adapted by the stream properties) and the number
                of cycles since the first cycle, or (None, -1) if nothing is available at the given index.
        """
        num_labels = self.__len__()
        if self.circular:
            if num_labels == 0:
                return None, -1  # Nothing to wrap around
            idx %= num_labels
        else:
            if idx >= num_labels or idx < 0:
                return None, -1

        label = self.labels[idx].to(self.device)  # Multi-label vector for the label
        return self.props.adapt_tensor_to_tensor_labels(label), self.clock.get_cycle() - self.first_cycle
738
+
739
+
740
class TokensStream(BufferedDataStream):
    """
    A buffered stream of text tokens read from a file, with a vocabulary built from the tokens themselves.
    """

    def __init__(self, tokens_file_csv: str, circular: bool = True, max_tokens: int = -1):
        """Initialize a TokensStream instance for streaming tokens.

        Args:
            tokens_file_csv (str): Path to the CSV file containing token data (one token per line; if the line has
                multiple comma-separated fields, only the first one is used).
            circular (bool): Whether to loop the dataset or not. Default is True.
            max_tokens (int): Whether to cut the stream to a maximum number of tokens. Default is -1 (no cut).
        """
        self.circular = circular

        # Reading the data file (assume a token per line or a CSV format with lines such as:
        # token,category_label1,category_label2,etc.)
        # Each line is parsed on its own, on purpose: tokens are expected to never span multiple lines
        tokens = []
        with open(tokens_file_csv, 'r') as f:
            for line in f:
                parts = next(csv.reader([line], quotechar='"', delimiter=','))
                if not parts:
                    continue  # Blank line: there is no token to pick up
                tokens.append(parts[0])
                if 0 < max_tokens <= len(tokens):
                    break

        # Vocabulary: ids follow the sorted order of the distinct tokens
        id2word = sorted(set(tokens))
        word2id = {word: idx for idx, word in enumerate(id2word)}

        # Calling the constructor
        super().__init__(props=DataProps(name=TokensStream.__name__,
                                         data_type="text",
                                         data_desc="stream of words",
                                         stream_to_proc_transforms=word2id,
                                         proc_to_stream_transforms=id2word,
                                         pubsub=True))

        # Tokenized text (no custom tags)
        self.data_buffer.extend((token, -1) for token in tokens)

        # It was buffered previously than every other thing
        self.restart()

    def __getitem__(self, idx: int) -> tuple[str | None, int]:
        """Get the token at the given index.

        Args:
            idx (int): The index to retrieve data for (wrapped around the number of tokens, if the stream is
                circular).

        Returns:
            tuple: The token and its tag, or (None, -1) if nothing is available at the given index.
        """
        if self.circular:
            num_tokens = self.__len__()
            if num_tokens == 0:
                return None, -1  # Nothing to wrap around
            idx %= num_tokens
        return super().__getitem__(idx)