scout-ai 0.2.0 → 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. checksums.yaml +4 -4
  2. data/.vimproject +91 -10
  3. data/Rakefile +1 -0
  4. data/VERSION +1 -1
  5. data/bin/scout-ai +2 -0
  6. data/lib/scout/llm/agent/chat.rb +24 -0
  7. data/lib/scout/llm/agent.rb +13 -13
  8. data/lib/scout/llm/ask.rb +26 -16
  9. data/lib/scout/llm/backends/bedrock.rb +129 -0
  10. data/lib/scout/llm/backends/huggingface.rb +6 -21
  11. data/lib/scout/llm/backends/ollama.rb +69 -36
  12. data/lib/scout/llm/backends/openai.rb +85 -35
  13. data/lib/scout/llm/backends/openwebui.rb +1 -1
  14. data/lib/scout/llm/backends/relay.rb +3 -2
  15. data/lib/scout/llm/backends/responses.rb +272 -0
  16. data/lib/scout/llm/chat.rb +547 -0
  17. data/lib/scout/llm/parse.rb +70 -13
  18. data/lib/scout/llm/tools.rb +126 -5
  19. data/lib/scout/llm/utils.rb +17 -10
  20. data/lib/scout/model/base.rb +19 -0
  21. data/lib/scout/model/python/base.rb +25 -0
  22. data/lib/scout/model/python/huggingface/causal/next_token.rb +23 -0
  23. data/lib/scout/model/python/huggingface/causal.rb +29 -0
  24. data/lib/scout/model/python/huggingface/classification +0 -0
  25. data/lib/scout/model/python/huggingface/classification.rb +50 -0
  26. data/lib/scout/model/python/huggingface.rb +112 -0
  27. data/lib/scout/model/python/torch/dataloader.rb +57 -0
  28. data/lib/scout/model/python/torch/helpers.rb +84 -0
  29. data/lib/scout/model/python/torch/introspection.rb +34 -0
  30. data/lib/scout/model/python/torch/load_and_save.rb +47 -0
  31. data/lib/scout/model/python/torch.rb +94 -0
  32. data/lib/scout/model/util/run.rb +181 -0
  33. data/lib/scout/model/util/save.rb +81 -0
  34. data/lib/scout-ai.rb +3 -1
  35. data/python/scout_ai/__init__.py +35 -0
  36. data/python/scout_ai/__pycache__/__init__.cpython-310.pyc +0 -0
  37. data/python/scout_ai/__pycache__/__init__.cpython-311.pyc +0 -0
  38. data/python/scout_ai/__pycache__/huggingface.cpython-310.pyc +0 -0
  39. data/python/scout_ai/__pycache__/huggingface.cpython-311.pyc +0 -0
  40. data/python/scout_ai/__pycache__/util.cpython-310.pyc +0 -0
  41. data/python/scout_ai/__pycache__/util.cpython-311.pyc +0 -0
  42. data/python/scout_ai/atcold/__init__.py +0 -0
  43. data/python/scout_ai/atcold/plot_lib.py +141 -0
  44. data/python/scout_ai/atcold/spiral.py +27 -0
  45. data/python/scout_ai/huggingface/data.py +48 -0
  46. data/python/scout_ai/huggingface/eval.py +60 -0
  47. data/python/scout_ai/huggingface/model.py +29 -0
  48. data/python/scout_ai/huggingface/rlhf.py +83 -0
  49. data/python/scout_ai/huggingface/train/__init__.py +34 -0
  50. data/python/scout_ai/huggingface/train/__pycache__/__init__.cpython-310.pyc +0 -0
  51. data/python/scout_ai/huggingface/train/__pycache__/next_token.cpython-310.pyc +0 -0
  52. data/python/scout_ai/huggingface/train/next_token.py +315 -0
  53. data/python/scout_ai/language_model.py +70 -0
  54. data/python/scout_ai/util.py +32 -0
  55. data/scout-ai.gemspec +130 -0
  56. data/scout_commands/agent/ask +133 -15
  57. data/scout_commands/agent/kb +15 -0
  58. data/scout_commands/llm/ask +71 -12
  59. data/scout_commands/llm/process +4 -2
  60. data/test/data/cat.jpg +0 -0
  61. data/test/scout/llm/agent/test_chat.rb +14 -0
  62. data/test/scout/llm/backends/test_bedrock.rb +60 -0
  63. data/test/scout/llm/backends/test_huggingface.rb +3 -3
  64. data/test/scout/llm/backends/test_ollama.rb +48 -10
  65. data/test/scout/llm/backends/test_openai.rb +96 -11
  66. data/test/scout/llm/backends/test_responses.rb +115 -0
  67. data/test/scout/llm/test_ask.rb +1 -0
  68. data/test/scout/llm/test_chat.rb +214 -0
  69. data/test/scout/llm/test_parse.rb +81 -2
  70. data/test/scout/model/python/huggingface/causal/test_next_token.rb +59 -0
  71. data/test/scout/model/python/huggingface/test_causal.rb +33 -0
  72. data/test/scout/model/python/huggingface/test_classification.rb +30 -0
  73. data/test/scout/model/python/test_base.rb +44 -0
  74. data/test/scout/model/python/test_huggingface.rb +9 -0
  75. data/test/scout/model/python/test_torch.rb +71 -0
  76. data/test/scout/model/python/torch/test_helpers.rb +14 -0
  77. data/test/scout/model/test_base.rb +117 -0
  78. data/test/scout/model/util/test_save.rb +31 -0
  79. metadata +72 -5
  80. data/questions/coach +0 -2
data/lib/scout/model/python/torch/load_and_save.rb ADDED
@@ -0,0 +1,47 @@
+ class TorchModel
+   def self.model_architecture(state_file)
+     state_file + '.architecture'
+   end
+
+   def self.save_state(state, state_file)
+     Log.debug "Saving model state into #{state_file}"
+     ScoutPython.torch.save(state.state_dict(), state_file)
+   end
+
+   def self.load_state(state, state_file)
+     return state unless Open.exists?(state_file)
+     Log.debug "Loading model state from #{state_file}"
+     state.load_state_dict(ScoutPython.torch.load(state_file))
+     state
+   end
+
+   def self.save_architecture(state, state_file)
+     model_architecture = model_architecture(state_file)
+     Log.debug "Saving model architecture into #{model_architecture}"
+     ScoutPython.torch.save(state, model_architecture)
+   end
+
+   def self.load_architecture(state_file)
+     model_architecture = model_architecture(state_file)
+     return unless Open.exists?(model_architecture)
+     Log.debug "Loading model architecture from #{model_architecture}"
+     ScoutPython.torch.load(model_architecture, weights_only: false)
+   end
+
+   def reset_state
+     @trainer = @state = nil
+     Open.rm_rf state_file
+     Open.rm_rf TorchModel.model_architecture(state_file)
+   end
+
+   def self.save(state_file, state)
+     TorchModel.save_architecture(state, state_file)
+     TorchModel.save_state(state, state_file)
+   end
+
+   def self.load(state_file, state = nil)
+     state ||= TorchModel.load_architecture(state_file)
+     TorchModel.load_state(state, state_file)
+     state
+   end
+ end
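TorchModel.save writes two files per model: the raw state_dict at state_file and the pickled module itself at state_file + '.architecture'; TorchModel.load rebuilds the module from the architecture file (or reuses the one passed in) and then restores its weights. A minimal round-trip sketch, assuming ScoutPython is initialized and `net` is a PyTorch module proxied through it; the path is illustrative:

  state_file = "/tmp/net.pt"               # hypothetical location
  TorchModel.save(state_file, net)         # writes net.pt and net.pt.architecture
  restored = TorchModel.load(state_file)   # rebuilds the module, then loads its state_dict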
data/lib/scout/model/python/torch.rb ADDED
@@ -0,0 +1,94 @@
+ require_relative 'base'
+
+ class TorchModel < PythonModel
+   attr_accessor :criterion, :optimizer, :device, :dtype
+
+   def fix_options
+     @options[:training_options] = @options.delete(:training_args) if @options.include?(:training_args)
+     training_args = IndiferentHash.pull_keys(@options, :training) || {}
+     @options[:training_args] = training_args
+   end
+
+   def initialize(...)
+
+     super(...)
+
+     fix_options
+
+     load_state do |state_file|
+       @state = TorchModel.load(state_file, @state)
+     end
+
+     save_state do |state_file,state|
+       TorchModel.save(state_file, state)
+     end
+
+     train do |features,labels|
+       TorchModel.init_python
+       device ||= TorchModel.device(options)
+       dtype ||= TorchModel.dtype(options)
+       state.to(device)
+       @optimizer ||= TorchModel.optimizer(state, options[:training_args] || {})
+       @criterion ||= TorchModel.optimizer(state, options[:training_args] || {})
+
+       epochs = options[:training_args][:epochs] || 3
+       batch_size = options[:batch_size]
+       batch_size ||= options[:training_args][:batch_size]
+       batch_size ||= 1
+
+       inputs = TorchModel.tensor(features, device, dtype)
+       #target = TorchModel.tensor(labels.collect{|v| [v] }, @device, @dtype)
+       target = TorchModel.tensor(labels, device, dtype)
+
+       Log::ProgressBar.with_bar epochs, :desc => "Training" do |bar|
+         epochs.times do |i|
+           optimizer.zero_grad()
+           outputs = state.call(inputs)
+           outputs = outputs.squeeze() if target.dim() == 1
+           loss = criterion.call(outputs, target)
+           loss.backward()
+           optimizer.step
+           Log.debug "Epoch #{i}, loss #{loss}"
+           bar.tick
+         end
+       end
+     end
+
+     self.eval do |features,list|
+       TorchModel.init_python
+       device ||= TorchModel.device(options)
+       dtype ||= TorchModel.dtype(options)
+       state.to(device)
+       state.eval
+
+       list = [features] if features
+
+       batch_size = options[:batch_size]
+       batch_size ||= options[:training_args][:batch_size]
+       batch_size ||= 1
+
+       res = Misc.chunk(list, batch_size).inject(nil) do |acc,batch|
+         tensor = TorchModel.tensor(batch, device, dtype)
+
+         loss, chunk_res = state.call(tensor)
+         tensor.del
+
+         chunk_res = loss if chunk_res.nil?
+
+         TorchModel::Tensor.setup(chunk_res)
+         chunk_res = chunk_res.to_ruby!
+
+         acc = acc.nil? ? chunk_res : acc + chunk_res
+
+         acc
+       end
+
+       features ? res[0] : res
+     end
+   end
+ end
+
+ require_relative 'torch/helpers'
+ require_relative 'torch/dataloader'
+ require_relative 'torch/load_and_save'
+ require_relative 'torch/introspection'
data/lib/scout/model/util/run.rb ADDED
@@ -0,0 +1,181 @@
+ class ScoutModel
+   def execute(method, *args)
+     case method
+     when Proc
+       instance_exec *args, &method
+     when nil
+       args.first
+     end
+   end
+
+   def save_state(&block)
+     if block_given?
+       @save_state = block
+     else
+       return @state unless @save_state
+       execute @save_state, state_file, @state
+     end
+   end
+
+   def load_state(&block)
+     if block_given?
+       @load_state = block
+     else
+       return @state unless @load_state
+       execute @load_state, state_file
+     end
+   end
+
+   def init(&block)
+     return @state if @state
+     if block_given?
+       @init = block
+     else
+       @state = execute @init
+       load_state
+       @state
+     end
+   end
+
+   def eval(sample = nil, &block)
+     if block_given?
+       @eval = block
+     else
+       features = extract_features sample
+
+       init unless @state
+       result = if @eval.arity == 2
+
+                  execute @eval, features, nil
+                else
+                  execute @eval, features
+                end
+
+       post_process result
+     end
+   end
+
+   def eval_list(list = nil, &block)
+     if block_given?
+       @eval_list = block
+     else
+       list = extract_features_list list
+
+       init unless @state
+       result = if @eval_list
+                  execute @eval_list, list
+                elsif @eval
+
+                  if @eval.arity == 2
+                    execute @eval, nil, list
+                  else
+                    list.collect{|features| execute @eval, features }
+                  end
+                end
+
+       post_process_list result
+     end
+   end
+
+   def post_process(result = nil, &block)
+     if block_given?
+       @post_process = block
+     else
+       return result if @post_process.nil?
+
+       if @post_process.arity == 2
+         execute @post_process, result, nil
+       else
+         execute @post_process, result
+       end
+     end
+   end
+
+   def post_process_list(list = nil, &block)
+     if block_given?
+       @post_process_list = block
+     else
+
+       if @post_process_list
+         execute @post_process_list, list
+       elsif @post_process
+         if @post_process.arity == 2
+           execute @post_process, nil, list
+         else
+           list.collect{|result| execute @post_process, result }
+         end
+       else
+         return list
+       end
+     end
+   end
+
+   def train(&block)
+     if block_given?
+       @train = block
+     else
+       init unless @state
+       execute @train, @features, @labels
+       save_state
+     end
+   end
+
+   def extract_features(sample = nil, &block)
+     if block_given?
+       @extract_features = block
+     else
+       return sample if @extract_features.nil?
+
+       if @extract_features.arity == 2
+         execute @extract_features, sample, nil
+       else
+         execute @extract_features, sample
+       end
+     end
+   end
+
+   def extract_features_list(list = nil, &block)
+     if block_given?
+       @extract_features_list = block
+     else
+       return list if @extract_features.nil?
+
+       if @extract_features_list
+         execute @extract_features_list, list
+       elsif @extract_features
+         if @extract_features.arity == 2
+           execute @extract_features, nil, list
+         else
+           list.collect{|sample| execute @extract_features, sample }
+         end
+       else
+         return list
+       end
+     end
+   end
+
+   def add(sample, label = nil)
+     features = extract_features sample
+     @features << features
+     @labels << label
+   end
+
+   def add_list(list, labels = nil)
+     if Hash === list
+       list.each do |sample,label|
+         add sample, label
+       end
+     else
+       list = extract_features_list list
+       @features.concat list
+
+       if Hash === labels
+         list.each do |sample|
+           @labels << labels[sample]
+         end
+       elsif labels
+         @labels.concat labels
+       end
+     end
+   end
+ end
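Each hook above (init, train, eval, extract_features, post_process, and their _list variants) doubles as a setter and an executor: called with a block it stores the block, called without one it instance_execs the stored block, so instance variables such as @state are shared between hooks. A hedged sketch of wiring a trivial lookup model through this interface; the ScoutModel constructor and the initialization of @features/@labels are not shown in this diff and are assumed here:

  model = ScoutModel.new                      # assumed constructor
  model.init  { {} }                          # @state becomes an empty Hash
  model.train { |features, labels| features.zip(labels) { |f, l| @state[f] = l } }
  model.eval  { |features| @state[features] }

  model.add "a", 1
  model.add "b", 2
  model.train                                 # runs the stored train block over @features/@labels
  model.eval "a"                              # => 1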
data/lib/scout/model/util/save.rb ADDED
@@ -0,0 +1,81 @@
+ class ScoutModel
+   def state_file
+     return nil unless directory
+     directory.state
+   end
+
+   def save_options
+     file = directory['options.json']
+     file.write(options.to_json)
+   end
+
+   def load_options
+     file = directory['options.json']
+     if file.exists?
+       IndiferentHash.setup(JSON.parse(file.read)).merge @options
+     else
+       @options
+     end
+   end
+
+   def load_ruby_code(file)
+     Log.debug "Loading ruby file #{file}"
+     code = Open.read(file)
+     code.sub!(/.*(\sdo\b|{)/, 'Proc.new\1')
+     instance_eval code, file
+   end
+
+   def load_method(name)
+     file = directory[name.to_s]
+
+     if file.exists?
+       file.read
+     elsif file.set_extension('rb').exists?
+       load_ruby_code file.set_extension('rb')
+     end
+   end
+
+   def save_method(name, value)
+     file = directory[name.to_s]
+
+     Log.debug "Saving #{file}"
+     case
+     when Proc === value
+       require 'method_source'
+       Open.write(file.set_extension('rb'), value.source)
+     when String === train_model
+       Open.write(file, @train_model)
+     end
+   end
+
+   def save
+     save_options if @options
+
+     save_method(:eval, @eval) if @eval
+     save_method(:eval_list, @eval_list) if @eval_list
+     save_method(:extract_features, @extract_features) if @extract_features
+     save_method(:extract_features_list, @extract_features_list) if @extract_features_list
+     save_method(:post_process, @post_process) if @post_process
+     save_method(:post_process_list, @post_process_list) if @post_process_list
+     save_method(:train, @train) if @train
+     save_method(:init, @init) if @init
+     save_method(:load_state, @load_state) if @load_state
+     save_method(:save_state, @save_state) if @save_state
+
+     save_state if @state
+   end
+
+   def restore
+     @eval = load_method :eval
+     @eval_list = load_method :eval_list
+     @extract_features = load_method :extract_features
+     @extract_features_list = load_method :extract_features_list
+     @post_process = load_method :post_process
+     @post_process_list = load_method :post_process_list
+     @train = load_method :train
+     @init = load_method :init
+     @load_state = load_method :load_state
+     @save_state = load_method :save_state
+     @options = load_options
+   end
+ end
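save persists each registered Proc hook as a Ruby source file (via the method_source gem) alongside options.json under the model directory, and restore re-evaluates those files back into Procs. A hedged round-trip sketch, assuming the constructor accepts the directory path (not shown in this diff):

  model = ScoutModel.new "/tmp/reverse_model"   # assumed constructor signature
  model.eval { |features| features.reverse }
  model.save                                    # writes eval.rb and options.json

  fresh = ScoutModel.new "/tmp/reverse_model"
  fresh.restore                                 # reloads the eval Proc from eval.rb
  fresh.eval "abc"                              # => "cba"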
data/lib/scout-ai.rb CHANGED
@@ -1,7 +1,9 @@
  require 'scout'
  require 'scout/path'
  require 'scout/resource'
- Path.add_path :scout_ai, File.join(Path.caller_lib_dir(__FILE__), "{TOPLEVEL}/{SUBPATH}")
+
+ Path.add_path :scout_ai_lib, File.join(Path.caller_lib_dir(__FILE__), "{TOPLEVEL}/{SUBPATH}")

  require 'scout/llm/ask'
  require 'scout/llm/embed'
+ require 'scout/llm/agent'
data/python/scout_ai/__init__.py ADDED
@@ -0,0 +1,35 @@
+ import scout
+ import torch
+ from .util import *
+
+ class TSVDataset(torch.utils.data.Dataset):
+     def __init__(self, tsv):
+         self.tsv = tsv
+
+     def __getitem__(self, key):
+         if (type(key) == int):
+             row = self.tsv.iloc[key]
+         else:
+             row = self.tsv.loc[key]
+
+         row = row.to_numpy()
+         features = row[:-1]
+         label = row[-1]
+
+         return features, label
+
+     def __len__(self):
+         return len(self.tsv)
+
+ def tsv_dataset(filename, *args, **kwargs):
+     return TSVDataset(scout.tsv(filename, *args, **kwargs))
+
+ def tsv(*args, **kwargs):
+     return tsv_dataset(*args, **kwargs)
+
+ def tsv_loader(*args, **kwargs):
+     dataset = tsv(*args, kwargs)
+     return torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
+
+ def data_dir():
+     return scout.path('var/scout_dm/data')
data/python/scout_ai/atcold/plot_lib.py ADDED
@@ -0,0 +1,141 @@
+ from matplotlib import pyplot as plt
+ import numpy as np
+ import torch
+ from IPython.display import HTML, display
+
+
+ def set_default(figsize=(10, 10), dpi=100):
+     plt.style.use(['dark_background', 'bmh'])
+     plt.rc('axes', facecolor='k')
+     plt.rc('figure', facecolor='k')
+     plt.rc('figure', figsize=figsize, dpi=dpi)
+
+
+ def plot_data(X, y, d=0, auto=False, zoom=1):
+     X = X.cpu()
+     y = y.cpu()
+     plt.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y, s=20, cmap=plt.cm.Spectral)
+     plt.axis('square')
+     plt.axis(np.array((-1.1, 1.1, -1.1, 1.1)) * zoom)
+     if auto is True: plt.axis('equal')
+     plt.axis('off')
+
+     _m, _c = 0, '.15'
+     plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0)
+     plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0)
+
+
+ def plot_model(X, y, model):
+     model.cpu()
+     mesh = np.arange(-1.1, 1.1, 0.01)
+     xx, yy = np.meshgrid(mesh, mesh)
+     with torch.no_grad():
+         data = torch.from_numpy(np.vstack((xx.reshape(-1), yy.reshape(-1))).T).float()
+         Z = model(data).detach()
+     Z = np.argmax(Z, axis=1).reshape(xx.shape)
+     plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.3)
+     plot_data(X, y)
+
+
+ def show_scatterplot(X, colors, title=''):
+     colors = colors.cpu().numpy()
+     X = X.cpu().numpy()
+     plt.figure()
+     plt.axis('equal')
+     plt.scatter(X[:, 0], X[:, 1], c=colors, s=30)
+     # plt.grid(True)
+     plt.title(title)
+     plt.axis('off')
+
+
+ def plot_bases(bases, width=0.04):
+     bases = bases.cpu()
+     bases[2:] -= bases[:2]
+     plt.arrow(*bases[0], *bases[2], width=width, color=(1,0,0), zorder=10, alpha=1., length_includes_head=True)
+     plt.arrow(*bases[1], *bases[3], width=width, color=(0,1,0), zorder=10, alpha=1., length_includes_head=True)
+
+
+ def show_mat(mat, vect, prod, threshold=-1):
+     # Subplot grid definition
+     fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=False, sharey=True,
+                                         gridspec_kw={'width_ratios':[5,1,1]})
+     # Plot matrices
+     cax1 = ax1.matshow(mat.numpy(), clim=(-1, 1))
+     ax2.matshow(vect.numpy(), clim=(-1, 1))
+     cax3 = ax3.matshow(prod.numpy(), clim=(threshold, 1))
+
+     # Set titles
+     ax1.set_title(f'A: {mat.size(0)} \u00D7 {mat.size(1)}')
+     ax2.set_title(f'a^(i): {vect.numel()}')
+     ax3.set_title(f'p: {prod.numel()}')
+
+     # Remove xticks for vectors
+     ax2.set_xticks(tuple())
+     ax3.set_xticks(tuple())
+
+     # Plot colourbars
+     fig.colorbar(cax1, ax=ax2)
+     fig.colorbar(cax3, ax=ax3)
+
+     # Fix y-axis limits
+     ax1.set_ylim(bottom=max(len(prod), len(vect)) - 0.5)
+
+
+ colors = dict(
+     aqua='#8dd3c7',
+     yellow='#ffffb3',
+     lavender='#bebada',
+     red='#fb8072',
+     blue='#80b1d3',
+     orange='#fdb462',
+     green='#b3de69',
+     pink='#fccde5',
+     grey='#d9d9d9',
+     violet='#bc80bd',
+     unk1='#ccebc5',
+     unk2='#ffed6f',
+ )
+
+
+ def _cstr(s, color='black'):
+     if s == ' ':
+         return f'<text style=color:#000;padding-left:10px;background-color:{color}> </text>'
+     else:
+         return f'<text style=color:#000;background-color:{color}>{s} </text>'
+
+ # print html
+ def _print_color(t):
+     display(HTML(''.join([_cstr(ti, color=ci) for ti, ci in t])))
+
+ # get appropriate color for value
+ def _get_clr(value):
+     colors = ('#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
+               '#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
+               '#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
+               '#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e')
+     value = int((value * 100) / 5)
+     if value == len(colors): value -= 1  # fixing bugs...
+     return colors[value]
+
+ def _visualise_values(output_values, result_list):
+     text_colours = []
+     for i in range(len(output_values)):
+         text = (result_list[i], _get_clr(output_values[i]))
+         text_colours.append(text)
+     _print_color(text_colours)
+
+ def print_colourbar():
+     color_range = torch.linspace(-2.5, 2.5, 20)
+     to_print = [(f'{x:.2f}', _get_clr((x+2.5)/5)) for x in color_range]
+     _print_color(to_print)
+
+
+ # Let's only focus on the last time step for now
+ # First, the cell state (Long term memory)
+ def plot_state(data, state, b, decoder):
+     actual_data = decoder(data[b, :, :].numpy())
+     seq_len = len(actual_data)
+     seq_len_w_pad = len(state)
+     for s in range(state.size(2)):
+         states = torch.sigmoid(state[:, b, s])
+         _visualise_values(states[seq_len_w_pad - seq_len:], list(actual_data))
data/python/scout_ai/atcold/spiral.py ADDED
@@ -0,0 +1,27 @@
+ import torch
+ import math
+ def spiral_data(N=1000, D=2, C=3):
+     X = torch.zeros(N * C, D)
+     y = torch.zeros(N * C, dtype=torch.long)
+     for c in range(C):
+         index = 0
+         t = torch.linspace(0, 1, N)
+         # When c = 0 and t = 0: start of linspace
+         # When c = 0 and t = 1: end of linpace
+         # This inner_var is for the formula inside sin() and cos() like sin(inner_var) and cos(inner_Var)
+         inner_var = torch.linspace(
+             # When t = 0
+             (2 * math.pi / C) * (c),
+             # When t = 1
+             (2 * math.pi / C) * (2 + c),
+             N
+         ) + torch.randn(N) * 0.2
+
+         for ix in range(N * c, N * (c + 1)):
+             X[ix] = t[index] * torch.FloatTensor((
+                 math.sin(inner_var[index]), math.cos(inner_var[index])
+             ))
+             y[ix] = c
+             index += 1
+
+     return (X, y)
data/python/scout_ai/huggingface/data.py ADDED
@@ -0,0 +1,48 @@
+ import scout
+ import pandas as pd
+ import datasets
+ from typing import Any, Dict, List
+
+ def load_tsv(tsv_file):
+     tsv = scout.tsv(tsv_file)
+     ds = datasets.Dataset.from_pandas(tsv)
+     d = datasets.DatasetDict()
+     d["train"] = ds
+     return d
+
+ def load_json(json_file):
+     return datasets.load_dataset('json', data_files=[json_file])
+
+ def tokenize_dataset(tokenizer, dataset, max_length=32):
+     def preprocess_function(examples):
+         return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length)
+     if isinstance(dataset, datasets.DatasetDict):
+         for split in dataset:
+             dataset[split] = dataset[split].map(preprocess_function, batched=True)
+         return dataset
+     else:
+         return dataset.map(preprocess_function, batched=True)
+
+ def tsv_dataset(tokenizer, tsv_file):
+     dataset = load_tsv(tsv_file)
+     return tokenize_dataset(tokenizer, dataset)
+
+ def json_dataset(tokenizer, json_file):
+     dataset = load_json(json_file)
+     return tokenize_dataset(tokenizer, dataset)
+
+ def list_dataset(tokenizer, texts, labels=None, max_length=32):
+     data_dict = {"text": texts}
+     if labels is not None:
+         data_dict["label"] = labels
+     ds = datasets.Dataset.from_dict(data_dict)
+
+     def preprocess_function(examples):
+         output = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length)
+         if "label" in examples:
+             output["label"] = examples["label"]
+         return output
+
+     tokenized_ds = ds.map(preprocess_function, batched=True)
+     return tokenized_ds
+