rbbt-dm 1.2.1 → 1.2.3
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: ab775c0224960820e5c62e294e6a183be49201da15710b66544762e1aaf97ebf
-  data.tar.gz: 8fffb47ba226f06d1f41a8893d085bdc12c33c021cf2f0152f4cc741db36e420
+  metadata.gz: 2ff72107967b0f7c654697f3a7b3c0ef10f7a5264d775117f12f74084a2819b2
+  data.tar.gz: 6b9a58b5a2723c095332f79a37d9c1c7f4bc1431410f23a55beeed1c3b52f7ad
 SHA512:
-  metadata.gz: 8be084156063cd93c7fe905bc4b6248dd376bbfcff8e650cbb03a4cc5c28f29dbcdaa2801895a3663067d7660e8bc2cf96682829519ebd6511a7a74cec021da0
-  data.tar.gz: 9c8570722319caf5afe60c90778d0b8517e70064030e0081968d919a314cfe1af90f63d13a7b8ce8dd56da6644db89916a729a4005f1e20745c8d4b45d50394c
+  metadata.gz: a0fb4198cb0be3aa5253df0f655ee230621dd26a31956a774fffe95eac35f4c8b558a41c0e340c25c6eef463760ff6230b967f09eb09671b2078a50066067384
+  data.tar.gz: 0bd6c3667a8ec26ed092c54e78176671807ed0634e136e497303e34a17b2740e5e023041bc06389ac187de39b54942b9b1c5cd77abbc067c89250424654b6974
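
For reference, these digests cover the metadata.gz and data.tar.gz members of the .gem package, which is itself a plain tar archive. A minimal sketch of how they could be reproduced locally, assuming a downloaded rbbt-dm-1.2.3.gem in the working directory (the file name is illustrative):

    import hashlib
    import tarfile

    # A .gem is an uncompressed tar; hash the two members that checksums.yaml covers
    with tarfile.open("rbbt-dm-1.2.3.gem") as gem:
        for member in ("metadata.gz", "data.tar.gz"):
            data = gem.extractfile(member).read()
            print(member, "SHA256:", hashlib.sha256(data).hexdigest())
            print(member, "SHA512:", hashlib.sha512(data).hexdigest())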
@@ -436,7 +436,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
   ensure
     @features = orig_features
     @labels = orig_labels
-  end
+  end unless folds == -1
   self.train unless folds == 1
   res
 end
python/rbbt_dm/huggingface.py ADDED
@@ -0,0 +1,103 @@
+#{{{ LOAD MODEL
+
+import importlib
+
+def import_module_class(module, class_name):
+    # Resolve a class from its module by name (e.g. 'transformers',
+    # 'AutoModelForSequenceClassification') instead of using exec/eval
+    return getattr(importlib.import_module(module), class_name)
+
+def load_model(task, checkpoint):
+    class_name = 'AutoModelFor' + task
+    return import_module_class('transformers', class_name).from_pretrained(checkpoint)
+
+def load_tokenizer(task, checkpoint):
+    return import_module_class('transformers', 'AutoTokenizer').from_pretrained(checkpoint)
+
+def load_model_and_tokenizer(task, checkpoint):
+    model = load_model(task, checkpoint)
+    tokenizer = load_tokenizer(task, checkpoint)
+    return model, tokenizer
+
+#{{{ SIMPLE EVALUATE
+
+def forward(model, features):
+    return model(**features)
+
+def logits(predictions):
+    # Detach the logits from the graph and return them as numpy arrays
+    logits = predictions["logits"]
+    return [v.detach().cpu().numpy() for v in logits]
+
+def eval_model(model, tokenizer, texts, return_logits=True):
+    features = tokenizer(texts, return_tensors='pt', truncation=True).to(model.device)
+    predictions = forward(model, features)
+    if return_logits:
+        return logits(predictions)
+    else:
+        return predictions
+
+#{{{ TRAIN AND PREDICT
+
+def load_tsv(tsv_file):
+    from datasets import load_dataset
+    return load_dataset('csv', data_files=[tsv_file], sep="\t")
+
+def tsv_dataset(tokenizer, tsv_file):
+    dataset = load_tsv(tsv_file)
+    return dataset.map(lambda example: tokenizer(example["text"], truncation=True), batched=True)
+
+def training_args(*args, **kwargs):
+    from transformers import TrainingArguments
+    return TrainingArguments(*args, **kwargs)
+
+def train_model(model, tokenizer, training_args, tsv_file):
+    from transformers import Trainer
+
+    tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
+
+    trainer = Trainer(
+        model,
+        training_args,
+        train_dataset=tokenized_dataset["train"],
+        tokenizer=tokenizer
+    )
+
+    trainer.train()
+
+def find_tokens_in_input(dataset, token_ids):
+    # For every row, collect the positions where each token id appears
+    if not hasattr(token_ids, "__len__"):
+        token_ids = [token_ids]
+
+    position_rows = []
+    for row in dataset:
+        input_ids = row["input_ids"]
+
+        positions = []
+        for token_id in token_ids:
+            item_positions = [i for i, input_id in enumerate(input_ids) if input_id == token_id]
+            positions.append(item_positions)
+
+        position_rows.append(positions)
+
+    return position_rows
+
+def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens=None):
+    from transformers import Trainer
+
+    tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
+
+    trainer = Trainer(
+        model,
+        training_args,
+        tokenizer=tokenizer
+    )
+
+    result = trainer.predict(test_dataset=tokenized_dataset["train"])
+    if locate_tokens is not None:
+        token_ids = tokenizer.convert_tokens_to_ids(locate_tokens)
+        token_positions = find_tokens_in_input(tokenized_dataset["train"], token_ids)
+        return dict(result=result, token_positions=token_positions)
+    else:
+        return result
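
Taken together, the helpers in huggingface.py load an AutoModel/AutoTokenizer pair for a task, evaluate the model on raw texts, and fine-tune or predict over a TSV file with a 'text' column. A minimal usage sketch follows; the checkpoint name, file names, output directory, epoch count, and the 'SequenceClassification' task suffix are illustrative placeholders rather than anything fixed by rbbt-dm:

    # Hypothetical driver for the helpers above (all names are placeholders)
    model, tokenizer = load_model_and_tokenizer('SequenceClassification', 'bert-base-uncased')

    # Logits for a single input text
    print(eval_model(model, tokenizer, ["a sentence to classify"]))

    # Fine-tune on train.tsv (columns: text, label), then predict on test.tsv,
    # also reporting the positions of the token 'protein' in each input
    args = training_args('model_output', num_train_epochs=1)
    train_model(model, tokenizer, args, 'train.tsv')
    predictions = predict_model(model, tokenizer, args, 'test.tsv', locate_tokens='protein')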
metadata CHANGED
@@ -1,7 +1,7 @@
 --- !ruby/object:Gem::Specification
 name: rbbt-dm
 version: !ruby/object:Gem::Version
-  version: 1.2.1
+  version: 1.2.3
 platform: ruby
 authors:
 - Miguel Vazquez
@@ -113,6 +113,7 @@ files:
 - lib/rbbt/vector/model/svm.rb
 - lib/rbbt/vector/model/tensorflow.rb
 - lib/rbbt/vector/model/util.rb
+- python/rbbt_dm/huggingface.py
 - share/R/MA.R
 - share/R/barcode.R
 - share/R/heatmap.3.R