rxnn 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/dataset.py CHANGED
@@ -41,7 +41,6 @@ class BaseDataset(Dataset):
  self.bg_next.append(self.get_tokenized_text(i))
  self.last_idx = self.batch_size - 1
 
-
  def __len__(self):
  return len(self.texts if not self.is_pre_tokenized else self.inputs)
 
@@ -105,7 +104,8 @@ class BaseDataset(Dataset):
  return_attention_mask=True
  )
  if not (inputs['input_ids'][0] < self.tokenizer.vocab_size).all():
- inputs['input_ids'][0][(inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
+ inputs['input_ids'][0][
+ (inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
  if not (inputs['input_ids'][0] >= 0).all():
  inputs['input_ids'][0][inputs['input_ids'][0] < 0] = self.tokenizer.unk_token_id
 
@@ -123,7 +123,8 @@ class BaseDataset(Dataset):
  split_point = int(len(self.texts) * ((1 - size) if not from_start else size))
  if not isinstance(self.texts, list):
  subset = self.texts.select(range(split_point, len(self.texts)) if not from_start else range(split_point))
- self.texts = self.texts.select(range(split_point) if not from_start else range(split_point, len(self.texts)))
+ self.texts = self.texts.select(
+ range(split_point) if not from_start else range(split_point, len(self.texts)))
  else:
  subset = self.texts[split_point:-1] if not from_start else self.texts[0:split_point]
  self.texts = self.texts[0:split_point] if not from_start else self.texts[split_point:-1]
@@ -155,7 +156,6 @@ class BaseDataset(Dataset):
  print(f'Processed {index + 1}/{num_texts}')
  self.is_pre_tokenized = True
 
-
  @classmethod
  def from_hf_hub(
  cls,
@@ -247,7 +247,8 @@ class BaseDataset(Dataset):
  tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
 
  hf_datasets = [
- load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
+ load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in
+ zip(dataset_ids, subsets)
  ] if subsets is not None else [
  load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
  ]
@@ -302,7 +303,8 @@ class BaseDataset(Dataset):
  tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
 
  hf_datasets = [
- load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
+ load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in
+ zip(dataset_ids, subsets)
  ] if subsets is not None else [
  load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
  ]
@@ -312,7 +314,8 @@ class BaseDataset(Dataset):
  hf_dataset = concatenate_datasets([ds_dict['train'] for ds_dict in hf_ds_dicts])
  hf_valid_dataset = concatenate_datasets([ds_dict['test'] for ds_dict in hf_ds_dicts])
 
- return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs), cls(hf_valid_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
+ return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs), cls(
+ hf_valid_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
 
 
  class JointLMDataset(BaseDataset):
@@ -427,6 +430,7 @@ class AutoregressiveLMDataset(BaseDataset):
  'targets': targets
  }
 
+
  class BaseInteractionDataset(Dataset):
  def __init__(
  self,
@@ -463,7 +467,6 @@ class BaseInteractionDataset(Dataset):
  self.bg_next.append(self.get_tokenized_text(i))
  self.last_idx = self.batch_size - 1
 
-
  def __len__(self):
  return len(self.interactions if not self.is_pre_tokenized else self.inputs)
 
@@ -528,7 +531,8 @@ class BaseInteractionDataset(Dataset):
  return_attention_mask=True
  )
  if not (inputs['input_ids'][0] < self.tokenizer.vocab_size).all():
- inputs['input_ids'][0][(inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
+ inputs['input_ids'][0][
+ (inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
  if not (inputs['input_ids'][0] >= 0).all():
  inputs['input_ids'][0][inputs['input_ids'][0] < 0] = self.tokenizer.unk_token_id
 
@@ -545,12 +549,16 @@ class BaseInteractionDataset(Dataset):
  def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "BaseInteractionDataset":
  split_point = int(len(self.interactions) * ((1 - size) if not from_start else size))
  if not isinstance(self.interactions, list):
- subset = self.interactions.select(range(split_point, len(self.interactions)) if not from_start else range(split_point))
- self.interactions = self.interactions.select(range(split_point) if not from_start else range(split_point, len(self.interactions)))
+ subset = self.interactions.select(
+ range(split_point, len(self.interactions)) if not from_start else range(split_point))
+ self.interactions = self.interactions.select(
+ range(split_point) if not from_start else range(split_point, len(self.interactions)))
  else:
  subset = self.interactions[split_point:-1] if not from_start else self.interactions[0:split_point]
- self.interactions = self.interactions[0:split_point] if not from_start else self.interactions[split_point:-1]
- return self.__class__(subset, self.tokenizer, max_seq_len=self.max_seq_len, query_field=self.query_field, answer_field=self.answer_field, **kwargs)
+ self.interactions = self.interactions[0:split_point] if not from_start else self.interactions[
+ split_point:-1]
+ return self.__class__(subset, self.tokenizer, max_seq_len=self.max_seq_len, query_field=self.query_field,
+ answer_field=self.answer_field, **kwargs)
 
  def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000):
  """
@@ -577,7 +585,6 @@ class BaseInteractionDataset(Dataset):
  print(f'Processed {index + 1}/{num_texts}')
  self.is_pre_tokenized = True
 
-
  @classmethod
  def from_hf_hub(
  cls,
@@ -624,7 +631,8 @@ class BaseInteractionDataset(Dataset):
 
  query_field, answer_field = target_fields
 
- return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field, **kwargs)
+ return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field,
+ **kwargs)
 
  @classmethod
  def concat_from_hf_hub(
@@ -671,7 +679,8 @@ class BaseInteractionDataset(Dataset):
  tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
 
  hf_datasets = [
- load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
+ load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in
+ zip(dataset_ids, subsets)
  ] if subsets is not None else [
  load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
  ]
@@ -679,7 +688,8 @@ class BaseInteractionDataset(Dataset):
 
  query_field, answer_field = target_fields
 
- return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field, **kwargs)
+ return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field,
+ **kwargs)
 
  @classmethod
  def concat_from_hf_hub_with_subset(
@@ -728,7 +738,8 @@ class BaseInteractionDataset(Dataset):
  tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
 
  hf_datasets = [
- load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
+ load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in
+ zip(dataset_ids, subsets)
  ] if subsets is not None else [
  load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
  ]
@@ -740,22 +751,25 @@ class BaseInteractionDataset(Dataset):
 
  query_field, answer_field = target_fields
 
- return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field, **kwargs), cls(hf_valid_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field, **kwargs)
+ return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field,
+ **kwargs), cls(hf_valid_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field,
+ answer_field=answer_field, **kwargs)
+
 
  class DecoderSftDataset(BaseInteractionDataset):
  def __init__(
- self,
- interactions: Union[list[dict], HfDataset],
- tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
- max_seq_len: int = 1024,
- query_field: str = 'query',
- answer_field: str = 'answer',
- cache_tokenized: bool = False,
- cache_remove_text: bool = True,
- tokenize_in_background: bool = False,
- batch_size: int = 1,
- *args,
- **kwargs
+ self,
+ interactions: Union[list[dict], HfDataset],
+ tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+ max_seq_len: int = 1024,
+ query_field: str = 'query',
+ answer_field: str = 'answer',
+ cache_tokenized: bool = False,
+ cache_remove_text: bool = True,
+ tokenize_in_background: bool = False,
+ batch_size: int = 1,
+ *args,
+ **kwargs
  ):
  super(DecoderSftDataset, self).__init__(
  interactions,
@@ -784,21 +798,22 @@ class DecoderSftDataset(BaseInteractionDataset):
  'targets': targets
  }
 
+
  class EncoderSftDataset(BaseInteractionDataset):
  def __init__(
- self,
- interactions: Union[list[dict], HfDataset],
- tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
- max_seq_len: int = 1024,
- query_field: str = 'query',
- answer_field: str = 'answer',
- cache_tokenized: bool = False,
- cache_remove_text: bool = True,
- tokenize_in_background: bool = False,
- batch_size: int = 1,
- mask_prob: float = 0.15,
- *args,
- **kwargs
+ self,
+ interactions: Union[list[dict], HfDataset],
+ tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+ max_seq_len: int = 1024,
+ query_field: str = 'query',
+ answer_field: str = 'answer',
+ cache_tokenized: bool = False,
+ cache_remove_text: bool = True,
+ tokenize_in_background: bool = False,
+ batch_size: int = 1,
+ mask_prob: float = 0.15,
+ *args,
+ **kwargs
  ):
  super(EncoderSftDataset, self).__init__(
  interactions,
@@ -839,8 +854,10 @@ class EncoderSftDataset(BaseInteractionDataset):
  'labels': labels
  }
 
+
  MrlDataItem: TypeAlias = dict[str, Union[dict[str, torch.Tensor], list[dict[str, dict[str, torch.Tensor]]]]]
 
+
  class MrlCurriculumDataset(Dataset):
  def __init__(
  self,
@@ -931,12 +948,16 @@ class MrlCurriculumDataset(Dataset):
  def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "MRlCurriculumDataset":
  split_point = int(len(self.episodes) * ((1 - size) if not from_start else size))
  if not isinstance(self.episodes, list):
- subset = self.episodes.select(range(split_point, len(self.episodes)) if not from_start else range(split_point))
- self.episodes = self.episodes.select(range(split_point) if not from_start else range(split_point, len(self.episodes)))
+ subset = self.episodes.select(
+ range(split_point, len(self.episodes)) if not from_start else range(split_point))
+ self.episodes = self.episodes.select(
+ range(split_point) if not from_start else range(split_point, len(self.episodes)))
  else:
  subset = self.episodes[split_point:-1] if not from_start else self.episodes[0:split_point]
  self.episodes = self.episodes[0:split_point] if not from_start else self.episodes[split_point:-1]
- return self.__class__(subset, tokenizer=self.tokenizer, query_field=self.query_field, answer_field=self.answer_field, interactions_field=self.interactions_field, **kwargs)
+ return self.__class__(subset, tokenizer=self.tokenizer, query_field=self.query_field,
+ answer_field=self.answer_field, interactions_field=self.interactions_field,
+ max_seq_len=self.max_seq_len, **kwargs)
 
  def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000, keep_order: bool = False):
  """
@@ -977,6 +998,7 @@ class MrlCurriculumDataset(Dataset):
  answer_field: str = 'answer',
  interactions_field: str = 'interactions',
  load_kwargs: dict = None,
+ max_seq_len: int = 1024,
  **kwargs
  ):
  """
@@ -993,19 +1015,23 @@ class MrlCurriculumDataset(Dataset):
  answer_field (str): Answer field (default: "answer")
  interactions_field (str): Interactions field (default: "interactions")
  load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+ max_seq_len (int): Maximum sequence length (default: 1024)
  **kwargs: Additional args for RxNN Dataset class
  """
  if load_kwargs is None:
- load_kwargs = {}
+ load_kwargs = {}
 
  hf_dataset = load_dataset(dataset_id, mrl_subset, split=split, **load_kwargs)
 
- return cls(hf_dataset, tokenizer, query_field=query_field, answer_field=answer_field, interactions_field=interactions_field, **kwargs)
+ return cls(hf_dataset, tokenizer, query_field=query_field, answer_field=answer_field,
+ interactions_field=interactions_field, max_seq_len=max_seq_len, **kwargs)
 
  @staticmethod
  def collate_mrl_batch(batch: list[MrlDataItem]) -> MrlDataItem:
  """Collate function for MRL curriculum dataset with nested interactions"""
- def collate_interaction_batch(interaction_batch: Union[list[dict[str, dict[str, torch.Tensor]]], tuple[Any]]) -> dict[str, dict[str, torch.Tensor]]:
+
+ def collate_interaction_batch(interaction_batch: Union[list[dict[str, dict[str, torch.Tensor]]], tuple[Any]]) -> \
+ dict[str, dict[str, torch.Tensor]]:
  """Helper to collate a batch of interactions"""
  return {
  'query': {
@@ -1022,23 +1048,26 @@ class MrlCurriculumDataset(Dataset):
  transposed_interactions = list(zip(*batch_interactions))
 
  return {
- **collate_interaction_batch(batch), # Collate initial query and answer
+ **collate_interaction_batch(batch),  # Collate initial query and answer
  'interactions': [
  collate_interaction_batch(step_batch) for step_batch in transposed_interactions
  ]
  }
 
+
  class MrlDatasetItem(TypedDict):
  steps: int
  is_long_range: bool
  dataset: MrlCurriculumDataset
  eval_dataset: Optional[MrlCurriculumDataset]
 
+
  class MrlDatasetLoadItem(TypedDict):
  subset_name: str
  steps: int
  is_long_range: bool
 
+
  class MrlDatasets:
  def __init__(self, datasets: list[MrlDatasetItem]):
  self.datasets = datasets
@@ -1061,7 +1090,8 @@ class MrlDatasets:
  @property
  def is_pre_tokenized(self) -> bool:
  train_tokenized = all(item['dataset'].is_pre_tokenized for item in self.datasets)
- eval_tokenized = all(item['eval_dataset'].is_pre_tokenized for item in self.datasets if item['eval_dataset'] is not None)
+ eval_tokenized = all(
+ item['eval_dataset'].is_pre_tokenized for item in self.datasets if item['eval_dataset'] is not None)
  return train_tokenized and eval_tokenized
 
  def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000, keep_order: bool = False):
@@ -1147,4 +1177,4 @@ class MrlDatasets:
 
  mrl_datasets = [dataset_item(item) for item in mrl_curriculum_steps]
 
- return cls(mrl_datasets)
+ return cls(mrl_datasets)
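Most of the dataset.py hunks above are line-length reformatting only; the substantive change in 0.2.16 is that MrlCurriculumDataset now threads max_seq_len through get_subset and from_hf_hub instead of dropping it. A minimal sketch of what that means for callers is shown below; the tokenizer and episode data are hypothetical placeholders, not part of the package.

# Sketch under assumptions: any HF tokenizer and toy episodes stand in for real data.
from transformers import AutoTokenizer
from rxnn.training.dataset import MrlCurriculumDataset

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder tokenizer

# Hypothetical MRL episodes using the default field names seen in the diff
# ('query', 'answer', 'interactions').
episodes = [
    {
        'query': 'What is RxNN?',
        'answer': 'A reactive neural networks platform.',
        'interactions': [
            {'query': 'Is it open source?', 'answer': 'Yes, Apache-2.0.'},
        ],
    }
    for _ in range(10)
]

dataset = MrlCurriculumDataset(episodes, tokenizer=tokenizer, max_seq_len=512)

# As of 0.2.16, get_subset() passes max_seq_len=self.max_seq_len to the subset it
# returns (see the @@ -931,12 +948,16 @@ hunk); previously the subset presumably
# fell back to the constructor default.
eval_split = dataset.get_subset(0.2)
assert eval_split.max_seq_len == dataset.max_seq_len == 512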
{rxnn-0.2.14.dist-info → rxnn-0.2.16.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.2.14
+ Version: 0.2.16
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.14.dist-info → rxnn-0.2.16.dist-info}/RECORD RENAMED
@@ -14,7 +14,7 @@ rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
  rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
  rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
- rxnn/training/dataset.py,sha256=m1opjNA7XHl6Ys-NtERM00c0BLN2xuu84lsfXp-3GQA,50478
+ rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
  rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
  rxnn/training/mrl.py,sha256=CezloyaXOKrc_F_eDt99EZ1fmKAMCCCMh5Ry6vF82Ro,39607
  rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
- rxnn-0.2.14.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
- rxnn-0.2.14.dist-info/METADATA,sha256=dutamudjxMj9IzykuCONpMyqnU4emEEwvseD4nmKkfs,25960
- rxnn-0.2.14.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- rxnn-0.2.14.dist-info/RECORD,,
+ rxnn-0.2.16.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+ rxnn-0.2.16.dist-info/METADATA,sha256=rSE80lLQZTqjOO9CJfLtyIQMCmgLep1suzXj2_DmPNI,25960
+ rxnn-0.2.16.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ rxnn-0.2.16.dist-info/RECORD,,
{rxnn-0.2.14.dist-info → rxnn-0.2.16.dist-info}/LICENSE RENAMED
File without changes
{rxnn-0.2.14.dist-info → rxnn-0.2.16.dist-info}/WHEEL RENAMED
File without changes