scout-ai 0.2.0 → 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.vimproject +91 -10
- data/Rakefile +1 -0
- data/VERSION +1 -1
- data/bin/scout-ai +2 -0
- data/lib/scout/llm/agent/chat.rb +24 -0
- data/lib/scout/llm/agent.rb +13 -13
- data/lib/scout/llm/ask.rb +26 -16
- data/lib/scout/llm/backends/bedrock.rb +129 -0
- data/lib/scout/llm/backends/huggingface.rb +6 -21
- data/lib/scout/llm/backends/ollama.rb +69 -36
- data/lib/scout/llm/backends/openai.rb +85 -35
- data/lib/scout/llm/backends/openwebui.rb +1 -1
- data/lib/scout/llm/backends/relay.rb +3 -2
- data/lib/scout/llm/backends/responses.rb +272 -0
- data/lib/scout/llm/chat.rb +547 -0
- data/lib/scout/llm/parse.rb +70 -13
- data/lib/scout/llm/tools.rb +126 -5
- data/lib/scout/llm/utils.rb +17 -10
- data/lib/scout/model/base.rb +19 -0
- data/lib/scout/model/python/base.rb +25 -0
- data/lib/scout/model/python/huggingface/causal/next_token.rb +23 -0
- data/lib/scout/model/python/huggingface/causal.rb +29 -0
- data/lib/scout/model/python/huggingface/classification +0 -0
- data/lib/scout/model/python/huggingface/classification.rb +50 -0
- data/lib/scout/model/python/huggingface.rb +112 -0
- data/lib/scout/model/python/torch/dataloader.rb +57 -0
- data/lib/scout/model/python/torch/helpers.rb +84 -0
- data/lib/scout/model/python/torch/introspection.rb +34 -0
- data/lib/scout/model/python/torch/load_and_save.rb +47 -0
- data/lib/scout/model/python/torch.rb +94 -0
- data/lib/scout/model/util/run.rb +181 -0
- data/lib/scout/model/util/save.rb +81 -0
- data/lib/scout-ai.rb +3 -1
- data/python/scout_ai/__init__.py +35 -0
- data/python/scout_ai/__pycache__/__init__.cpython-310.pyc +0 -0
- data/python/scout_ai/__pycache__/__init__.cpython-311.pyc +0 -0
- data/python/scout_ai/__pycache__/huggingface.cpython-310.pyc +0 -0
- data/python/scout_ai/__pycache__/huggingface.cpython-311.pyc +0 -0
- data/python/scout_ai/__pycache__/util.cpython-310.pyc +0 -0
- data/python/scout_ai/__pycache__/util.cpython-311.pyc +0 -0
- data/python/scout_ai/atcold/__init__.py +0 -0
- data/python/scout_ai/atcold/plot_lib.py +141 -0
- data/python/scout_ai/atcold/spiral.py +27 -0
- data/python/scout_ai/huggingface/data.py +48 -0
- data/python/scout_ai/huggingface/eval.py +60 -0
- data/python/scout_ai/huggingface/model.py +29 -0
- data/python/scout_ai/huggingface/rlhf.py +83 -0
- data/python/scout_ai/huggingface/train/__init__.py +34 -0
- data/python/scout_ai/huggingface/train/__pycache__/__init__.cpython-310.pyc +0 -0
- data/python/scout_ai/huggingface/train/__pycache__/next_token.cpython-310.pyc +0 -0
- data/python/scout_ai/huggingface/train/next_token.py +315 -0
- data/python/scout_ai/language_model.py +70 -0
- data/python/scout_ai/util.py +32 -0
- data/scout-ai.gemspec +130 -0
- data/scout_commands/agent/ask +133 -15
- data/scout_commands/agent/kb +15 -0
- data/scout_commands/llm/ask +71 -12
- data/scout_commands/llm/process +4 -2
- data/test/data/cat.jpg +0 -0
- data/test/scout/llm/agent/test_chat.rb +14 -0
- data/test/scout/llm/backends/test_bedrock.rb +60 -0
- data/test/scout/llm/backends/test_huggingface.rb +3 -3
- data/test/scout/llm/backends/test_ollama.rb +48 -10
- data/test/scout/llm/backends/test_openai.rb +96 -11
- data/test/scout/llm/backends/test_responses.rb +115 -0
- data/test/scout/llm/test_ask.rb +1 -0
- data/test/scout/llm/test_chat.rb +214 -0
- data/test/scout/llm/test_parse.rb +81 -2
- data/test/scout/model/python/huggingface/causal/test_next_token.rb +59 -0
- data/test/scout/model/python/huggingface/test_causal.rb +33 -0
- data/test/scout/model/python/huggingface/test_classification.rb +30 -0
- data/test/scout/model/python/test_base.rb +44 -0
- data/test/scout/model/python/test_huggingface.rb +9 -0
- data/test/scout/model/python/test_torch.rb +71 -0
- data/test/scout/model/python/torch/test_helpers.rb +14 -0
- data/test/scout/model/test_base.rb +117 -0
- data/test/scout/model/util/test_save.rb +31 -0
- metadata +72 -5
- data/questions/coach +0 -2
@@ -0,0 +1,47 @@
|
|
1
|
+
# Persistence helpers for TorchModel: the learned parameters are written to
# the state file and the full module object to a sibling ".architecture" file.
class TorchModel
  class << self
    # Path of the architecture file that accompanies +state_file+.
    def model_architecture(state_file)
      state_file + '.architecture'
    end

    # Writes the +state_dict+ of +state+ to +state_file+ through torch.save.
    def save_state(state, state_file)
      Log.debug "Saving model state into #{state_file}"
      weights = state.state_dict()
      ScoutPython.torch.save(weights, state_file)
    end

    # Loads the saved weights into +state+ when +state_file+ exists.
    # Always returns +state+ (untouched when there is nothing to load).
    def load_state(state, state_file)
      if Open.exists?(state_file)
        Log.debug "Loading model state from #{state_file}"
        state.load_state_dict(ScoutPython.torch.load(state_file))
      end
      state
    end

    # Saves the whole +state+ object next to the state file.
    def save_architecture(state, state_file)
      architecture_file = model_architecture(state_file)
      Log.debug "Saving model architecture into #{architecture_file}"
      ScoutPython.torch.save(state, architecture_file)
    end

    # Returns the object stored in the architecture file, or nil when that
    # file does not exist.
    def load_architecture(state_file)
      architecture_file = model_architecture(state_file)
      return unless Open.exists?(architecture_file)

      Log.debug "Loading model architecture from #{architecture_file}"
      # NOTE(review): weights_only: false asks torch.load for full object
      # unpickling; only point this at architecture files you trust.
      ScoutPython.torch.load(architecture_file, weights_only: false)
    end

    # Saves the architecture first and the weights second.
    def save(state_file, state)
      TorchModel.save_architecture(state, state_file)
      TorchModel.save_state(state, state_file)
    end

    # Restores a model: uses +state+ when given, otherwise the saved
    # architecture, and then loads the saved weights into it.
    def load(state_file, state = nil)
      state ||= TorchModel.load_architecture(state_file)
      TorchModel.load_state(state, state_file)
      state
    end
  end

  # Drops the in-memory trainer and state and deletes both files on disk.
  def reset_state
    @trainer = nil
    @state = nil
    Open.rm_rf state_file
    Open.rm_rf TorchModel.model_architecture(state_file)
  end
end
|
@@ -0,0 +1,94 @@
|
|
1
|
+
require_relative 'base'
|
2
|
+
|
3
|
+
# PyTorch-backed model. On construction it registers the hooks for state
# persistence, training and evaluation on top of PythonModel.
class TorchModel < PythonModel
  attr_accessor :criterion, :optimizer, :device, :dtype

  # Normalizes the training settings held in @options.
  #
  # An explicit :training_args entry is first renamed to :training_options;
  # IndiferentHash.pull_keys(@options, :training) then collects the training
  # settings (NOTE(review): presumably :training_options plus any
  # `training_*` keys -- confirm against IndiferentHash) and the result is
  # stored back under :training_args, falling back to an empty hash.
  def fix_options
    @options[:training_options] = @options.delete(:training_args) if @options.include?(:training_args)
    training_args = IndiferentHash.pull_keys(@options, :training) || {}
    @options[:training_args] = training_args
  end

  # Forwards every argument to PythonModel and then installs the default
  # load_state, save_state, train and eval hooks.
  def initialize(...)
    super(...)

    fix_options

    load_state do |state_file|
      @state = TorchModel.load(state_file, @state)
    end

    save_state do |state_file, state|
      TorchModel.save(state_file, state)
    end

    # Training runs over the whole feature set on every epoch.
    train do |features, labels|
      TorchModel.init_python
      # Honour values assigned through the accessors. The previous
      # `device ||= ...` created a fresh local and always recomputed it.
      device = @device || TorchModel.device(options)
      dtype = @dtype || TorchModel.dtype(options)
      state.to(device)

      training_args = options[:training_args] || {}
      @optimizer ||= TorchModel.optimizer(state, training_args)
      # The criterion used to be built with TorchModel.optimizer as well.
      # NOTE(review): assumes TorchModel.criterion is provided by
      # torch/helpers -- confirm.
      @criterion ||= TorchModel.criterion(state, training_args)

      epochs = training_args[:epochs] || 3

      inputs = TorchModel.tensor(features, device, dtype)
      target = TorchModel.tensor(labels, device, dtype)

      Log::ProgressBar.with_bar epochs, :desc => "Training" do |bar|
        epochs.times do |i|
          optimizer.zero_grad()
          outputs = state.call(inputs)
          # Drop singleton dimensions so predictions line up with 1-D targets.
          outputs = outputs.squeeze() if target.dim() == 1
          loss = criterion.call(outputs, target)
          loss.backward()
          optimizer.step
          Log.debug "Epoch #{i}, loss #{loss}"
          bar.tick
        end
      end
    end

    # Evaluates one sample (features) or a list of samples in chunks of
    # batch_size. A single sample is wrapped in a list and unwrapped again.
    self.eval do |features, list|
      TorchModel.init_python
      device = @device || TorchModel.device(options)
      dtype = @dtype || TorchModel.dtype(options)
      state.to(device)
      state.eval

      list = [features] if features

      batch_size = options[:batch_size] ||
                   (options[:training_args] || {})[:batch_size] ||
                   1

      res = Misc.chunk(list, batch_size).inject(nil) do |acc, batch|
        tensor = TorchModel.tensor(batch, device, dtype)

        # The module may answer with a (loss, result) pair or a single value.
        loss, chunk_res = state.call(tensor)
        tensor.del

        chunk_res = loss if chunk_res.nil?

        TorchModel::Tensor.setup(chunk_res)
        chunk_res = chunk_res.to_ruby!

        acc.nil? ? chunk_res : acc + chunk_res
      end

      features ? res[0] : res
    end
  end
end
|
90
|
+
|
91
|
+
require_relative 'torch/helpers'
|
92
|
+
require_relative 'torch/dataloader'
|
93
|
+
require_relative 'torch/load_and_save'
|
94
|
+
require_relative 'torch/introspection'
|
@@ -0,0 +1,181 @@
|
|
1
|
+
# Hook-driven execution layer of ScoutModel.
#
# Most methods here are dual purpose: called with a block they register that
# block as the hook; called without one they run the registered hook.
# Hooks are run with instance_exec, so inside them self is the model.
class ScoutModel
  # Runs a stored hook.
  #
  # method - a Proc (run on the model with +args+) or nil (the first
  #          argument is returned unchanged). Anything else yields nil.
  def execute(method, *args)
    case method
    when Proc
      instance_exec(*args, &method)
    when nil
      args.first
    end
  end

  # With a block: registers the state saver.
  # Without: runs it with (state_file, @state); returns @state when no saver
  # is registered.
  def save_state(&block)
    if block_given?
      @save_state = block
    else
      return @state unless @save_state
      execute @save_state, state_file, @state
    end
  end

  # With a block: registers the state loader.
  # Without: runs it with state_file; returns @state when no loader is
  # registered.
  def load_state(&block)
    if block_given?
      @load_state = block
    else
      return @state unless @load_state
      execute @load_state, state_file
    end
  end

  # With a block: registers the initializer.
  # Without: builds @state from it, runs load_state and returns @state.
  # An already present @state short-circuits both forms.
  def init(&block)
    return @state if @state
    if block_given?
      @init = block
    else
      @state = execute @init
      load_state
      @state
    end
  end

  # With a block: registers the evaluator.
  # Without: extracts features from +sample+, evaluates them and
  # post-processes the result. Two-argument evaluators receive
  # (features, nil).
  def eval(sample = nil, &block)
    if block_given?
      @eval = block
    else
      features = extract_features sample

      init unless @state
      result = if @eval.arity == 2
                 execute @eval, features, nil
               else
                 execute @eval, features
               end

      post_process result
    end
  end

  # List counterpart of #eval. Prefers the eval_list hook; otherwise a
  # two-argument eval hook receives (nil, list) and a one-argument eval hook
  # is applied to each element.
  def eval_list(list = nil, &block)
    if block_given?
      @eval_list = block
    else
      list = extract_features_list list

      init unless @state
      result = if @eval_list
                 execute @eval_list, list
               elsif @eval
                 if @eval.arity == 2
                   execute @eval, nil, list
                 else
                   list.collect { |features| execute @eval, features }
                 end
               end

      post_process_list result
    end
  end

  # With a block: registers the post-processor.
  # Without: applies it to +result+, or returns +result+ untouched when none
  # is registered.
  def post_process(result = nil, &block)
    if block_given?
      @post_process = block
    else
      return result if @post_process.nil?

      if @post_process.arity == 2
        execute @post_process, result, nil
      else
        execute @post_process, result
      end
    end
  end

  # List counterpart of #post_process, with the same fallback rules as
  # #eval_list.
  def post_process_list(list = nil, &block)
    if block_given?
      @post_process_list = block
    else
      if @post_process_list
        execute @post_process_list, list
      elsif @post_process
        if @post_process.arity == 2
          execute @post_process, nil, list
        else
          list.collect { |result| execute @post_process, result }
        end
      else
        return list
      end
    end
  end

  # With a block: registers the trainer.
  # Without: initializes if needed, trains on the accumulated @features and
  # @labels, and saves the state.
  def train(&block)
    if block_given?
      @train = block
    else
      init unless @state
      execute @train, @features, @labels
      save_state
    end
  end

  # With a block: registers the feature extractor.
  # Without: applies it to +sample+, or returns +sample+ when none is
  # registered.
  def extract_features(sample = nil, &block)
    if block_given?
      @extract_features = block
    else
      return sample if @extract_features.nil?

      if @extract_features.arity == 2
        execute @extract_features, sample, nil
      else
        execute @extract_features, sample
      end
    end
  end

  # List counterpart of #extract_features.
  #
  # The list hook is checked first, so it now runs even when no per-sample
  # extractor is registered (an early return on a nil @extract_features used
  # to skip it). With no extractor at all the list is returned unchanged.
  def extract_features_list(list = nil, &block)
    if block_given?
      @extract_features_list = block
    else
      if @extract_features_list
        execute @extract_features_list, list
      elsif @extract_features
        if @extract_features.arity == 2
          execute @extract_features, nil, list
        else
          list.collect { |sample| execute @extract_features, sample }
        end
      else
        return list
      end
    end
  end

  # Adds one training sample: stores its extracted features and its label.
  def add(sample, label = nil)
    features = extract_features sample
    @features << features
    @labels << label
  end

  # Adds several training samples.
  #
  # list   - a Hash of sample => label, or a list of samples.
  # labels - for a list of samples: a parallel list of labels, or a Hash
  #          keyed by sample. The Hash is looked up with the original
  #          samples (it used to be looked up with the extracted features,
  #          which only matched when no extractor was registered).
  def add_list(list, labels = nil)
    if Hash === list
      list.each do |sample, label|
        add sample, label
      end
    else
      samples = list
      @features.concat extract_features_list(samples)

      if Hash === labels
        samples.each do |sample|
          @labels << labels[sample]
        end
      elsif labels
        @labels.concat labels
      end
    end
  end
end
|
@@ -0,0 +1,81 @@
|
|
1
|
+
# Persistence layer of ScoutModel: options, hooks and state are stored as
# files inside the model directory.
class ScoutModel
  # Hooks written by #save and read back by #restore, in that order.
  PERSISTED_METHODS = %i[
    eval eval_list
    extract_features extract_features_list
    post_process post_process_list
    train init load_state save_state
  ].freeze

  # File holding the model state, or nil when the model has no directory.
  def state_file
    return nil unless directory
    directory.state
  end

  # Writes the options as JSON to options.json in the model directory.
  def save_options
    file = directory['options.json']
    file.write(options.to_json)
  end

  # Returns the options saved in options.json overlaid with the in-memory
  # @options (in-memory values win); just @options when the file is missing.
  def load_options
    file = directory['options.json']
    if file.exists?
      # @options may not be set yet; merging nil would raise a TypeError.
      IndiferentHash.setup(JSON.parse(file.read)).merge(@options || {})
    else
      @options
    end
  end

  # Rebuilds a Proc from a saved hook source file.
  #
  # Everything up to the block opener (` do` or `{`) is replaced with
  # `Proc.new`, and the result is evaluated in the context of this model.
  # NOTE(review): this evaluates code read from the model directory; only
  # restore models from directories you trust.
  def load_ruby_code(file)
    Log.debug "Loading ruby file #{file}"
    code = Open.read(file)
    code.sub!(/.*(\sdo\b|{)/, 'Proc.new\1')
    instance_eval code, file
  end

  # Loads a saved hook: the raw content of the file called +name+ when it
  # exists, otherwise the Proc rebuilt from "<name>.rb". Returns nil when
  # neither file exists.
  def load_method(name)
    file = directory[name.to_s]
    ruby_file = file.set_extension('rb')

    if file.exists?
      file.read
    elsif ruby_file.exists?
      load_ruby_code ruby_file
    end
  end

  # Saves a hook: a Proc is stored as source code in "<name>.rb", a String is
  # stored verbatim in "<name>". Other values are ignored.
  def save_method(name, value)
    file = directory[name.to_s]

    Log.debug "Saving #{file}"
    case value
    when Proc
      # Loaded lazily: method_source is only needed to serialize procs.
      require 'method_source'
      Open.write(file.set_extension('rb'), value.source)
    when String
      # This branch used to reference `train_model`, which is not defined.
      Open.write(file, value)
    end
  end

  # Saves options, every registered hook and, when present, the state.
  def save
    save_options if @options

    PERSISTED_METHODS.each do |name|
      value = instance_variable_get("@#{name}")
      save_method(name, value) if value
    end

    save_state if @state
  end

  # Reloads every hook and then the options from the model directory.
  def restore
    PERSISTED_METHODS.each do |name|
      instance_variable_set("@#{name}", load_method(name))
    end
    @options = load_options
  end
end
|
data/lib/scout-ai.rb
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
require 'scout'
|
2
2
|
require 'scout/path'
|
3
3
|
require 'scout/resource'
|
4
|
-
|
4
|
+
|
5
|
+
Path.add_path :scout_ai_lib, File.join(Path.caller_lib_dir(__FILE__), "{TOPLEVEL}/{SUBPATH}")
|
5
6
|
|
6
7
|
require 'scout/llm/ask'
|
7
8
|
require 'scout/llm/embed'
|
9
|
+
require 'scout/llm/agent'
|
@@ -0,0 +1,35 @@
|
|
1
|
+
import scout
|
2
|
+
import torch
|
3
|
+
from .util import *
|
4
|
+
|
5
|
+
class TSVDataset(torch.utils.data.Dataset):
    """Torch dataset over a table whose last column is the label.

    `tsv` must offer `iloc`, `loc` and `len()` (presumably a pandas
    DataFrame -- confirm against `scout.tsv`).
    """

    def __init__(self, tsv):
        self.tsv = tsv

    def __len__(self):
        return len(self.tsv)

    def __getitem__(self, key):
        # Plain ints select by position; any other key selects by row label.
        indexer = self.tsv.iloc if type(key) is int else self.tsv.loc
        values = indexer[key].to_numpy()
        # Everything but the last cell is a feature; the last one is the label.
        return values[:-1], values[-1]
|
23
|
+
|
24
|
+
def tsv_dataset(filename, *args, **kwargs):
    """Load `filename` with `scout.tsv` and wrap the table in a TSVDataset.

    Extra positional and keyword arguments are passed through to `scout.tsv`.
    """
    table = scout.tsv(filename, *args, **kwargs)
    return TSVDataset(table)
|
26
|
+
|
27
|
+
def tsv(*args, **kwargs):
    """Short alias for `tsv_dataset`; all arguments are forwarded as given."""
    dataset = tsv_dataset(*args, **kwargs)
    return dataset
|
29
|
+
|
30
|
+
def tsv_loader(*args, **kwargs):
    """Build a shuffling DataLoader (batches of 2) over a TSV dataset.

    All arguments are forwarded to `tsv`. Keyword arguments are expanded with
    `**`; they used to be passed as a single positional dict, which handed the
    loader an unexpected extra positional argument.
    """
    dataset = tsv(*args, **kwargs)
    return torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
|
33
|
+
|
34
|
+
def data_dir():
    """Return `scout.path` applied to the 'var/scout_dm/data' location."""
    location = 'var/scout_dm/data'
    return scout.path(location)
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
File without changes
|
@@ -0,0 +1,141 @@
|
|
1
|
+
from matplotlib import pyplot as plt
|
2
|
+
import numpy as np
|
3
|
+
import torch
|
4
|
+
from IPython.display import HTML, display
|
5
|
+
|
6
|
+
|
7
|
+
def set_default(figsize=(10, 10), dpi=100):
    """Switch matplotlib to the dark style with black axes and figure.

    figsize and dpi become the default figure size and resolution.
    """
    plt.style.use(['dark_background', 'bmh'])
    # Black background for both the axes and the surrounding figure.
    for group in ('axes', 'figure'):
        plt.rc(group, facecolor='k')
    plt.rc('figure', figsize=figsize, dpi=dpi)
|
12
|
+
|
13
|
+
|
14
|
+
def plot_data(X, y, d=0, auto=False, zoom=1):
    """Scatter the first two columns of X coloured by y, with faint axes.

    X, y -- tensors (moved to the CPU before plotting).
    d    -- unused.
    auto -- when exactly True, switch to 'equal' axis scaling.
    zoom -- multiplier for the default [-1.1, 1.1] view window.
    """
    points = X.cpu().numpy()
    labels = y.cpu()
    plt.scatter(points[:, 0], points[:, 1], c=labels, s=20, cmap=plt.cm.Spectral)

    plt.axis('square')
    window = np.array((-1.1, 1.1, -1.1, 1.1)) * zoom
    plt.axis(window)
    if auto is True:
        plt.axis('equal')
    plt.axis('off')

    # Thin grey lines through the origin stand in for the hidden axes.
    line_style = dict(color='.15', lw=1, zorder=0)
    plt.axvline(0, ymin=0, **line_style)
    plt.axhline(0, xmin=0, **line_style)
|
26
|
+
|
27
|
+
|
28
|
+
def plot_model(X, y, model):
    """Draw the decision regions of `model` on a grid and overlay the data.

    The model is moved to the CPU, evaluated on a 0.01-step grid over
    [-1.1, 1.1) in both directions, and each grid point is coloured by the
    index of its largest output.
    """
    model.cpu()
    ticks = np.arange(-1.1, 1.1, 0.01)
    xx, yy = np.meshgrid(ticks, ticks)
    with torch.no_grad():
        grid = np.vstack((xx.reshape(-1), yy.reshape(-1))).T
        grid_points = torch.from_numpy(grid).float()
        scores = model(grid_points).detach()
    regions = np.argmax(scores, axis=1).reshape(xx.shape)
    plt.contourf(xx, yy, regions, cmap=plt.cm.Spectral, alpha=0.3)
    plot_data(X, y)
|
38
|
+
|
39
|
+
|
40
|
+
def show_scatterplot(X, colors, title=''):
    """Open a new figure and scatter the first two columns of X.

    X and colors are tensors; both are moved to the CPU and converted to
    numpy. The axes are hidden and `title` is shown above the plot.
    """
    point_colors = colors.cpu().numpy()
    points = X.cpu().numpy()
    plt.figure()
    plt.axis('equal')
    plt.scatter(points[:, 0], points[:, 1], c=point_colors, s=30)
    # The grid is intentionally left off.
    plt.title(title)
    plt.axis('off')
|
49
|
+
|
50
|
+
|
51
|
+
def plot_bases(bases, width=0.04):
    """Draw two arrows, red then green, described by the rows of `bases`.

    Rows 0 and 1 are the arrow origins; rows 2 onward hold the arrow tips and
    are converted to offsets from the origins before drawing.

    The conversion runs on a private copy: `.cpu()` hands back the very same
    tensor when it already lives on the CPU, so the in-place subtraction used
    to modify the caller's tensor.
    """
    bases = bases.cpu().clone()
    bases[2:] -= bases[:2]
    plt.arrow(*bases[0], *bases[2], width=width, color=(1,0,0), zorder=10, alpha=1., length_includes_head=True)
    plt.arrow(*bases[1], *bases[3], width=width, color=(0,1,0), zorder=10, alpha=1., length_includes_head=True)
|
56
|
+
|
57
|
+
|
58
|
+
def show_mat(mat, vect, prod, threshold=-1):
    """Show a matrix, a vector and their product side by side.

    The three panels share the y axis and use width ratios 5:1:1. The matrix
    and vector are colour-limited to (-1, 1), the product to (threshold, 1).
    """
    fig, panels = plt.subplots(1, 3, sharex=False, sharey=True,
                               gridspec_kw={'width_ratios':[5,1,1]})
    mat_ax, vect_ax, prod_ax = panels

    # Draw the three images.
    mat_img = mat_ax.matshow(mat.numpy(), clim=(-1, 1))
    vect_ax.matshow(vect.numpy(), clim=(-1, 1))
    prod_img = prod_ax.matshow(prod.numpy(), clim=(threshold, 1))

    # Titles report the sizes of what is shown.
    mat_ax.set_title(f'A: {mat.size(0)} \u00D7 {mat.size(1)}')
    vect_ax.set_title(f'a^(i): {vect.numel()}')
    prod_ax.set_title(f'p: {prod.numel()}')

    # The two vector panels get no x ticks.
    for panel in (vect_ax, prod_ax):
        panel.set_xticks(tuple())

    # Colour bars: the matrix scale next to the vector, the product scale
    # next to the product.
    fig.colorbar(mat_img, ax=vect_ax)
    fig.colorbar(prod_img, ax=prod_ax)

    # Pin the bottom y limit to the longer of the two vectors.
    lowest_row = max(len(prod), len(vect)) - 0.5
    mat_ax.set_ylim(bottom=lowest_row)
|
82
|
+
|
83
|
+
|
84
|
+
# Named pastel palette: colour name -> hex code.
colors = {
    'aqua': '#8dd3c7',
    'yellow': '#ffffb3',
    'lavender': '#bebada',
    'red': '#fb8072',
    'blue': '#80b1d3',
    'orange': '#fdb462',
    'green': '#b3de69',
    'pink': '#fccde5',
    'grey': '#d9d9d9',
    'violet': '#bc80bd',
    'unk1': '#ccebc5',
    'unk2': '#ffed6f',
}
|
98
|
+
|
99
|
+
|
100
|
+
def _cstr(s, color='black'):
|
101
|
+
if s == ' ':
|
102
|
+
return f'<text style=color:#000;padding-left:10px;background-color:{color}> </text>'
|
103
|
+
else:
|
104
|
+
return f'<text style=color:#000;background-color:{color}>{s} </text>'
|
105
|
+
|
106
|
+
# print html
|
107
|
+
def _print_color(t):
|
108
|
+
display(HTML(''.join([_cstr(ti, color=ci) for ti, ci in t])))
|
109
|
+
|
110
|
+
# get appropriate color for value
|
111
|
+
def _get_clr(value):
|
112
|
+
colors = ('#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
|
113
|
+
'#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
|
114
|
+
'#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
|
115
|
+
'#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e')
|
116
|
+
value = int((value * 100) / 5)
|
117
|
+
if value == len(colors): value -= 1 # fixing bugs...
|
118
|
+
return colors[value]
|
119
|
+
|
120
|
+
def _visualise_values(output_values, result_list):
|
121
|
+
text_colours = []
|
122
|
+
for i in range(len(output_values)):
|
123
|
+
text = (result_list[i], _get_clr(output_values[i]))
|
124
|
+
text_colours.append(text)
|
125
|
+
_print_color(text_colours)
|
126
|
+
|
127
|
+
def print_colourbar():
    """Print a 20-step legend for values from -2.5 to 2.5.

    Each value is rescaled to [0, 1] before its colour is looked up.
    """
    legend = []
    for x in torch.linspace(-2.5, 2.5, 20):
        legend.append((f'{x:.2f}', _get_clr((x+2.5)/5)))
    _print_color(legend)
|
131
|
+
|
132
|
+
|
133
|
+
# Let's only focus on the last time step for now
|
134
|
+
# First, the cell state (Long term memory)
|
135
|
+
def plot_state(data, state, b, decoder):
    """Print the decoded sequence of batch item `b` once per state unit.

    data    -- indexed as data[b, :, :] and decoded with `decoder`.
    state   -- indexed as state[:, b, unit]; its length along the first axis
               may exceed the decoded sequence, in which case the leading
               entries are skipped.
    Each unit's values go through a sigmoid before being turned into colours.
    """
    actual_data = decoder(data[b, :, :].numpy())
    # Number of leading state entries that have no decoded counterpart.
    skipped = len(state) - len(actual_data)
    for unit in range(state.size(2)):
        activations = torch.sigmoid(state[:, b, unit])
        _visualise_values(activations[skipped:], list(actual_data))
|
@@ -0,0 +1,27 @@
|
|
1
|
+
import torch
|
2
|
+
import math
|
3
|
+
def spiral_data(N=1000, D=2, C=3):
    """Generate C noisy spiral arms of N points each.

    Returns (X, y): X has shape (N * C, D) and y holds the arm index of every
    row as a long tensor. Each row is written as a 2-component point, so the
    default D=2 is the expected setting.
    """
    X = torch.zeros(N * C, D)
    y = torch.zeros(N * C, dtype=torch.long)
    for c in range(C):
        # Radius grows linearly from 0 to 1 along the arm.
        radius = torch.linspace(0, 1, N)
        # Angle fed to sin() and cos(): it sweeps from the arm's start angle
        # to its end angle, with Gaussian noise added.
        angle = torch.linspace(
            (2 * math.pi / C) * (c),
            (2 * math.pi / C) * (2 + c),
            N
        ) + torch.randn(N) * 0.2

        first_row = N * c
        for step in range(N):
            row = first_row + step
            direction = torch.FloatTensor((
                math.sin(angle[step]), math.cos(angle[step])
            ))
            X[row] = radius[step] * direction
            y[row] = c

    return (X, y)
|
@@ -0,0 +1,48 @@
|
|
1
|
+
import scout
|
2
|
+
import pandas as pd
|
3
|
+
import datasets
|
4
|
+
from typing import Any, Dict, List
|
5
|
+
|
6
|
+
def load_tsv(tsv_file):
    """Read a TSV file with `scout.tsv` into a DatasetDict.

    The table is converted with `Dataset.from_pandas` and stored as the only
    split, "train".
    """
    splits = datasets.DatasetDict()
    splits["train"] = datasets.Dataset.from_pandas(scout.tsv(tsv_file))
    return splits
|
12
|
+
|
13
|
+
def load_json(json_file):
    """Load one JSON file through `datasets.load_dataset('json', ...)`."""
    sources = [json_file]
    return datasets.load_dataset('json', data_files=sources)
|
15
|
+
|
16
|
+
def tokenize_dataset(tokenizer, dataset, max_length=32):
    """Tokenize the "text" column of a dataset or of every split of a DatasetDict.

    Texts are truncated and padded to `max_length`. A DatasetDict is updated
    in place, split by split, and returned; any other dataset is mapped and
    the mapped dataset is returned.
    """
    def encode(batch):
        return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=max_length)

    if not isinstance(dataset, datasets.DatasetDict):
        return dataset.map(encode, batched=True)

    for split in list(dataset):
        dataset[split] = dataset[split].map(encode, batched=True)
    return dataset
|
25
|
+
|
26
|
+
def tsv_dataset(tokenizer, tsv_file, max_length=32):
    """Load a TSV file and tokenize its "text" column.

    max_length is forwarded to `tokenize_dataset`. It defaults to 32, the
    value that was previously applied implicitly, so existing calls behave
    the same.
    """
    dataset = load_tsv(tsv_file)
    return tokenize_dataset(tokenizer, dataset, max_length=max_length)
|
29
|
+
|
30
|
+
def json_dataset(tokenizer, json_file, max_length=32):
    """Load a JSON file and tokenize its "text" column.

    max_length is forwarded to `tokenize_dataset`. It defaults to 32, the
    value that was previously applied implicitly, so existing calls behave
    the same.
    """
    dataset = load_json(json_file)
    return tokenize_dataset(tokenizer, dataset, max_length=max_length)
|
33
|
+
|
34
|
+
def list_dataset(tokenizer, texts, labels=None, max_length=32):
    """Build a tokenized dataset from in-memory texts and optional labels.

    Texts go into a "text" column and labels, when given, into a "label"
    column. Texts are truncated and padded to `max_length`.
    """
    columns = {"text": texts}
    if labels is not None:
        columns["label"] = labels
    raw = datasets.Dataset.from_dict(columns)

    def encode(batch):
        encoded = tokenizer(batch["text"], truncation=True, padding="max_length", max_length=max_length)
        # Carry the labels over explicitly alongside the tokenizer output.
        if "label" in batch:
            encoded["label"] = batch["label"]
        return encoded

    return raw.map(encode, batched=True)
|
48
|
+
|