rbbt-dm 1.2.1 → 1.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: ab775c0224960820e5c62e294e6a183be49201da15710b66544762e1aaf97ebf
-  data.tar.gz: 8fffb47ba226f06d1f41a8893d085bdc12c33c021cf2f0152f4cc741db36e420
+  metadata.gz: 2ff72107967b0f7c654697f3a7b3c0ef10f7a5264d775117f12f74084a2819b2
+  data.tar.gz: 6b9a58b5a2723c095332f79a37d9c1c7f4bc1431410f23a55beeed1c3b52f7ad
 SHA512:
-  metadata.gz: 8be084156063cd93c7fe905bc4b6248dd376bbfcff8e650cbb03a4cc5c28f29dbcdaa2801895a3663067d7660e8bc2cf96682829519ebd6511a7a74cec021da0
-  data.tar.gz: 9c8570722319caf5afe60c90778d0b8517e70064030e0081968d919a314cfe1af90f63d13a7b8ce8dd56da6644db89916a729a4005f1e20745c8d4b45d50394c
+  metadata.gz: a0fb4198cb0be3aa5253df0f655ee230621dd26a31956a774fffe95eac35f4c8b558a41c0e340c25c6eef463760ff6230b967f09eb09671b2078a50066067384
+  data.tar.gz: 0bd6c3667a8ec26ed092c54e78176671807ed0634e136e497303e34a17b2740e5e023041bc06389ac187de39b54942b9b1c5cd77abbc067c89250424654b6974
@@ -436,7 +436,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
     ensure
       @features = orig_features
       @labels = orig_labels
-    end
+    end unless folds == -1
     self.train unless folds == 1
     res
   end
python/rbbt_dm/huggingface.py ADDED
@@ -0,0 +1,112 @@
+#{{{ LOAD MODEL
+
+def import_module_class(module, class_name):
+    exec(f"from {module} import {class_name}")
+    return eval(class_name)
+
+def load_model(task, checkpoint):
+    class_name = 'AutoModelFor' + task
+    return import_module_class('transformers', class_name).from_pretrained(checkpoint)
+
+def load_tokenizer(task, checkpoint):
+    class_name = 'AutoTokenizer'
+    return import_module_class('transformers', class_name).from_pretrained(checkpoint)
+
+def load_model_and_tokenizer(task, checkpoint):
+    model = load_model(task, checkpoint)
+    tokenizer = load_tokenizer(task, checkpoint)
+    return model, tokenizer
+
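These loaders resolve transformers Auto* classes dynamically from a task name. A minimal usage sketch, assuming a sequence-classification task and the distilbert-base-uncased checkpoint (both illustrative, not taken from the package):

    # Hypothetical usage of the loaders above
    model, tokenizer = load_model_and_tokenizer('SequenceClassification', 'distilbert-base-uncased')
    # load_model builds the class name 'AutoModelForSequenceClassification',
    # so this is equivalent to
    #   AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased')
    #   AutoTokenizer.from_pretrained('distilbert-base-uncased')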
+#{{{ SIMPLE EVALUATE
+
+def forward(model, features):
+    return model(**features)
+
+def logits(predictions):
+    logits = predictions["logits"]
+    return [v.detach().cpu().numpy() for v in logits]
+
+def eval_model(model, tokenizer, texts, return_logits = True):
+    features = tokenizer(texts, return_tensors='pt', truncation=True).to(model.device)
+    predictions = forward(model, features)
+    if (return_logits):
+        return logits(predictions)
+    else:
+        return predictions
+
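eval_model tokenizes raw texts, runs a forward pass, and by default extracts the logits as numpy arrays. A quick sketch (the input text is illustrative; note the tokenizer call uses no padding, so batching multiple texts assumes equal token lengths):

    # Hypothetical call against the model/tokenizer loaded above
    scores = eval_model(model, tokenizer, ["An example sentence"])
    # scores is a list with one numpy logit vector per input text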
+#{{{ TRAIN AND PREDICT
+
+def load_tsv(tsv_file):
+    from datasets import load_dataset
+    return load_dataset('csv', data_files=[tsv_file], sep="\t")
+
+def tsv_dataset(tokenizer, tsv_file):
+    dataset = load_tsv(tsv_file)
+    tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True), batched=True)
+    return tokenized_dataset
+
+def training_args(*args, **kwargs):
+    from transformers import TrainingArguments
+    training_args = TrainingArguments(*args, **kwargs)
+    return training_args
+
+
+def train_model(model, tokenizer, training_args, tsv_file):
+    from transformers import Trainer
+
+    tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
+
+    trainer = Trainer(
+        model,
+        training_args,
+        train_dataset = tokenized_dataset["train"],
+        tokenizer = tokenizer
+    )
+
+    trainer.train()
+
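train_model wires a TSV file with a "text" column (that is what tsv_dataset tokenizes) through a transformers Trainer. A sketch of the full training path, with a hypothetical TSV path and illustrative hyperparameters; the TSV is assumed to also carry whatever label column the model's task expects:

    # training_args forwards to transformers.TrainingArguments;
    # output_dir is its first positional parameter
    args = training_args('model.out', num_train_epochs=3)
    train_model(model, tokenizer, args, 'train.tsv')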
+def find_tokens_in_input(dataset, token_ids):
+    position_rows = []
+
+    for row in dataset:
+        input_ids = row["input_ids"]
+
+        if (not hasattr(token_ids, "__len__")):
+            token_ids = [token_ids]
+
+        positions = []
+        for token_id in token_ids:
+
+            item_positions = []
+            for i in range(len(input_ids)):
+                if input_ids[i] == token_id:
+                    item_positions.append(i)
+
+            positions.append(item_positions)
+
+
+        position_rows.append(positions)
+
+    return position_rows
+
+
+
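find_tokens_in_input returns, for every dataset row, one list of match positions per requested token id. A toy illustration with made-up input_ids:

    rows = [{"input_ids": [5, 7, 9, 7]}, {"input_ids": [9, 9]}]
    find_tokens_in_input(rows, [7, 9])
    # => [[[1, 3], [2]], [[], [0, 1]]]
    # token 7 occurs at positions 1 and 3 in the first row, never in the second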
+def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = None):
+    from transformers import Trainer
+
+    tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
+
+    trainer = Trainer(
+        model,
+        training_args,
+        tokenizer = tokenizer
+    )
+
+    result = trainer.predict(test_dataset = tokenized_dataset["train"])
+    if (locate_tokens != None):
+        token_ids = tokenizer.convert_tokens_to_ids(locate_tokens)
+        token_positions = find_tokens_in_input(tokenized_dataset["train"], token_ids)
+        return dict(result=result, token_positions=token_positions)
+    else:
+        return result
+
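predict_model runs Trainer.predict over the tokenized file and, when locate_tokens is given, pairs the predictions with the positions of those tokens in each input. A hypothetical call (the '[MASK]' token and the test.tsv path are illustrative):

    out = predict_model(model, tokenizer, args, 'test.tsv', locate_tokens=['[MASK]'])
    out['result'].predictions    # raw outputs from transformers Trainer.predict
    out['token_positions']       # per-row positions of the requested tokens

Note that the function reads the 'train' split of the dataset it builds: load_tsv loads the whole TSV as a single 'train' split, regardless of its role here.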
metadata CHANGED
@@ -1,7 +1,7 @@
 --- !ruby/object:Gem::Specification
 name: rbbt-dm
 version: !ruby/object:Gem::Version
-  version: 1.2.1
+  version: 1.2.3
 platform: ruby
 authors:
 - Miguel Vazquez
@@ -113,6 +113,7 @@ files:
 - lib/rbbt/vector/model/svm.rb
 - lib/rbbt/vector/model/tensorflow.rb
 - lib/rbbt/vector/model/util.rb
+- python/rbbt_dm/huggingface.py
 - share/R/MA.R
 - share/R/barcode.R
 - share/R/heatmap.3.R