scout-ai 0.2.0 → 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. checksums.yaml +4 -4
  2. data/.vimproject +91 -10
  3. data/Rakefile +1 -0
  4. data/VERSION +1 -1
  5. data/bin/scout-ai +2 -0
  6. data/lib/scout/llm/agent/chat.rb +24 -0
  7. data/lib/scout/llm/agent.rb +13 -13
  8. data/lib/scout/llm/ask.rb +26 -16
  9. data/lib/scout/llm/backends/bedrock.rb +129 -0
  10. data/lib/scout/llm/backends/huggingface.rb +6 -21
  11. data/lib/scout/llm/backends/ollama.rb +69 -36
  12. data/lib/scout/llm/backends/openai.rb +85 -35
  13. data/lib/scout/llm/backends/openwebui.rb +1 -1
  14. data/lib/scout/llm/backends/relay.rb +3 -2
  15. data/lib/scout/llm/backends/responses.rb +272 -0
  16. data/lib/scout/llm/chat.rb +547 -0
  17. data/lib/scout/llm/parse.rb +70 -13
  18. data/lib/scout/llm/tools.rb +126 -5
  19. data/lib/scout/llm/utils.rb +17 -10
  20. data/lib/scout/model/base.rb +19 -0
  21. data/lib/scout/model/python/base.rb +25 -0
  22. data/lib/scout/model/python/huggingface/causal/next_token.rb +23 -0
  23. data/lib/scout/model/python/huggingface/causal.rb +29 -0
  24. data/lib/scout/model/python/huggingface/classification +0 -0
  25. data/lib/scout/model/python/huggingface/classification.rb +50 -0
  26. data/lib/scout/model/python/huggingface.rb +112 -0
  27. data/lib/scout/model/python/torch/dataloader.rb +57 -0
  28. data/lib/scout/model/python/torch/helpers.rb +84 -0
  29. data/lib/scout/model/python/torch/introspection.rb +34 -0
  30. data/lib/scout/model/python/torch/load_and_save.rb +47 -0
  31. data/lib/scout/model/python/torch.rb +94 -0
  32. data/lib/scout/model/util/run.rb +181 -0
  33. data/lib/scout/model/util/save.rb +81 -0
  34. data/lib/scout-ai.rb +3 -1
  35. data/python/scout_ai/__init__.py +35 -0
  36. data/python/scout_ai/__pycache__/__init__.cpython-310.pyc +0 -0
  37. data/python/scout_ai/__pycache__/__init__.cpython-311.pyc +0 -0
  38. data/python/scout_ai/__pycache__/huggingface.cpython-310.pyc +0 -0
  39. data/python/scout_ai/__pycache__/huggingface.cpython-311.pyc +0 -0
  40. data/python/scout_ai/__pycache__/util.cpython-310.pyc +0 -0
  41. data/python/scout_ai/__pycache__/util.cpython-311.pyc +0 -0
  42. data/python/scout_ai/atcold/__init__.py +0 -0
  43. data/python/scout_ai/atcold/plot_lib.py +141 -0
  44. data/python/scout_ai/atcold/spiral.py +27 -0
  45. data/python/scout_ai/huggingface/data.py +48 -0
  46. data/python/scout_ai/huggingface/eval.py +60 -0
  47. data/python/scout_ai/huggingface/model.py +29 -0
  48. data/python/scout_ai/huggingface/rlhf.py +83 -0
  49. data/python/scout_ai/huggingface/train/__init__.py +34 -0
  50. data/python/scout_ai/huggingface/train/__pycache__/__init__.cpython-310.pyc +0 -0
  51. data/python/scout_ai/huggingface/train/__pycache__/next_token.cpython-310.pyc +0 -0
  52. data/python/scout_ai/huggingface/train/next_token.py +315 -0
  53. data/python/scout_ai/language_model.py +70 -0
  54. data/python/scout_ai/util.py +32 -0
  55. data/scout-ai.gemspec +130 -0
  56. data/scout_commands/agent/ask +133 -15
  57. data/scout_commands/agent/kb +15 -0
  58. data/scout_commands/llm/ask +71 -12
  59. data/scout_commands/llm/process +4 -2
  60. data/test/data/cat.jpg +0 -0
  61. data/test/scout/llm/agent/test_chat.rb +14 -0
  62. data/test/scout/llm/backends/test_bedrock.rb +60 -0
  63. data/test/scout/llm/backends/test_huggingface.rb +3 -3
  64. data/test/scout/llm/backends/test_ollama.rb +48 -10
  65. data/test/scout/llm/backends/test_openai.rb +96 -11
  66. data/test/scout/llm/backends/test_responses.rb +115 -0
  67. data/test/scout/llm/test_ask.rb +1 -0
  68. data/test/scout/llm/test_chat.rb +214 -0
  69. data/test/scout/llm/test_parse.rb +81 -2
  70. data/test/scout/model/python/huggingface/causal/test_next_token.rb +59 -0
  71. data/test/scout/model/python/huggingface/test_causal.rb +33 -0
  72. data/test/scout/model/python/huggingface/test_classification.rb +30 -0
  73. data/test/scout/model/python/test_base.rb +44 -0
  74. data/test/scout/model/python/test_huggingface.rb +9 -0
  75. data/test/scout/model/python/test_torch.rb +71 -0
  76. data/test/scout/model/python/torch/test_helpers.rb +14 -0
  77. data/test/scout/model/test_base.rb +117 -0
  78. data/test/scout/model/util/test_save.rb +31 -0
  79. metadata +72 -5
  80. data/questions/coach +0 -2
@@ -8,7 +8,6 @@ module LLM
  function_arguments = JSON.parse(function_arguments, { symbolize_names: true }) if String === function_arguments
  function_response = block.call function_name, function_arguments

- #content = String === function_response ? function_response : function_response.to_json,
  content = case function_response
  when String
  function_response
@@ -17,21 +16,28 @@ module LLM
  else
  function_response.to_json
  end
+ content = content.to_s if Numeric === content
  {
- tool_call_id: tool_call_id,
+ id: tool_call_id,
  role: "tool",
  content: content
  }
  end

- def self.task_tool_definition(workflow, task_name)
+ def self.task_tool_definition(workflow, task_name, inputs = nil)
  task_info = workflow.task_info(task_name)

+ inputs = inputs.collect{|i| i.to_sym } if inputs
+
  properties = task_info[:inputs].inject({}) do |acc,input|
+ next acc if inputs and not inputs.include?(input)
  type = task_info[:input_types][input]
  description = task_info[:input_descriptions][input]

+ type = :string if type == :text
  type = :string if type == :select
+ type = :string if type == :path
+ type = :number if type == :float

  acc[input] = {
  "type": type,
@@ -49,7 +55,8 @@ module LLM
  end

  required_inputs = task_info[:inputs].select do |input|
- task_info[:input_options].include?(input) && task_info[:input_options][:required]
+ next if inputs and not inputs.include?(input.to_sym)
+ task_info[:input_options].include?(input) && task_info[:input_options][input][:required]
  end

  {
@@ -67,7 +74,7 @@ module LLM
  end

  def self.workflow_tools(workflow, tasks = nil)
- tasks ||= workflow.tasks.keys
+ tasks = workflow.all_exports
  tasks.collect{|task_name| self.task_tool_definition(workflow, task_name) }
  end

@@ -101,4 +108,118 @@ module LLM
  }
  }]
  end
+
+ def self.association_tool_definition(name)
+ properties = {
+ entities: {
+ type: "array",
+ items: { type: :string },
+ description: "Source entities in the association, or target entities if 'reverse' it true."
+ },
+ reverse: {
+ type: "boolean",
+ description: "Look for targets instead of sources, defaults to 'false'."
+ }
+ }
+
+ {
+ type: "function",
+ function: {
+ name: name,
+ description: "Find associations for a list of entities. Returns a list in the format source~target.",
+ parameters: {
+ type: "object",
+ properties: properties,
+ required: ['entities']
+ }
+ }
+ }
+ end
+
+ def self.run_tools(messages)
+ messages.collect do |info|
+ IndiferentHash.setup(info)
+ role = info[:role]
+ if role == 'cmd'
+ {
+ role: 'tool',
+ content: CMD.cmd(info[:content]).read
+ }
+ else
+ info
+ end
+ end
+ end
+
+ def self.tools_to_openai(messages)
+ messages.collect do |message|
+ if message[:role] == 'function_call'
+ tool_call = JSON.parse(message[:content])
+ arguments = tool_call.delete('arguments') || {}
+ name = tool_call.delete('name')
+ tool_call['type'] = 'function'
+ tool_call['function'] ||= {}
+ tool_call['function']['name'] ||= name
+ tool_call['function']['arguments'] = arguments.to_json
+ {role: 'assistant', tool_calls: [tool_call]}
+ elsif message[:role] == 'function_call_output'
+ info = JSON.parse(message[:content])
+ id = info.delete('id') || ''
+ info['role'] = 'tool'
+ info['tool_call_id'] = id
+ info
+ else
+ message
+ end
+ end.flatten
+ end
+
+ def self.tools_to_ollama(messages)
+ messages.collect do |message|
+ if message[:role] == 'function_call'
+ tool_call = JSON.parse(message[:content])
+ arguments = tool_call.delete('arguments') || {}
+ id = tool_call.delete('id')
+ name = tool_call.delete('name')
+ tool_call['type'] = 'function'
+ tool_call['function'] ||= {}
+ tool_call['function']['name'] ||= name
+ tool_call['function']['arguments'] ||= arguments
+ {role: 'assistant', tool_calls: [tool_call]}
+ elsif message[:role] == 'function_call_output'
+ info = JSON.parse(message[:content])
+ id = info.delete('id') || ''
+ info['role'] = 'tool'
+ info
+ else
+ message
+ end
+ end.flatten
+ end
+
+ def self.tools_to_responses(messages)
+ messages.collect do |message|
+ if message[:role] == 'function_call'
+ tool_call = JSON.parse(message[:content])
+ tool_call['function']['arguments'] = (tool_call['function']['arguments'] || {}).to_json
+ {role: 'assistant', tool_calls: [tool_call]}
+ elsif message[:role] == 'function_call_output'
+ info = JSON.parse(message[:content])
+ info["tool_call_id"] = info['id']
+ info
+ else
+ message
+ end
+ end.flatten
+ end
+
+ def self.call_tools(tool_calls, &block)
+ tool_calls.collect{|tool_call|
+ response_message = LLM.tool_response(tool_call, &block)
+ [
+ {role: "function_call", content: tool_call.to_json},
+ {role: "function_call_output", content: response_message.to_json},
+ ]
+ }.flatten
+ end
  end
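Taken together, call_tools and the tools_to_* converters above define an internal convention: each tool invocation is stored as a 'function_call' message and its result as a 'function_call_output' message, which the backend-specific converters then rewrite into their own wire formats. A rough sketch of the round trip, assuming the tool_call hash carries top-level 'id', 'name' and 'arguments' keys (the part of LLM.tool_response that extracts them is outside this hunk, so treat that shape as an assumption; the tool name is made up):

    # Illustrative only: a tool call as a backend might report it
    tool_call = { "id" => "call_1", "name" => "baking_time", "arguments" => { "thing" => "cake" } }

    # Record the call and its result as the internal message pair
    messages = LLM.call_tools([tool_call]) do |name, arguments|
      "60 minutes"  # run the real tool here
    end
    # => [{role: "function_call", content: "{...}"},
    #     {role: "function_call_output", content: "{...}"}]

    # Rewrite the pair into the OpenAI wire format before the next request
    LLM.tools_to_openai(messages)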
@@ -1,4 +1,21 @@
  module LLM
+
+ def self.tag(tag, content, name = nil)
+ if name
+ <<-EOF.strip
+ <#{tag} name="#{name}">
+ #{content}
+ </#{tag}>
+ EOF
+ else
+ <<-EOF.strip
+ <#{tag}>
+ #{content}
+ </#{tag}>
+ EOF
+ end
+ end
+
  def self.get_url_server_tokens(url, prefix=nil)
  return get_url_server_tokens(url).collect{|e| prefix.to_s + "." + e } if prefix

@@ -22,14 +39,4 @@ module LLM
  end
  Scout::Config.get(key, *all_tokens, hash)
  end
-
- def self
- if workflow.root.etc.AI[@model || 'default'].exists?
- workflow.root.etc.AI[@model || 'default'].json
- elsif Scout.etc.AI[@model || 'default'].exists?
- Scout.etc.AI[@model || 'default'].json
- else
- {}
- end
- end
  end
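The new LLM.tag helper wraps content in an XML-like tag, optionally with a name attribute; a minimal usage sketch (the file name is illustrative):

    LLM.tag("file", "some content", "notes.txt")
    # => <file name="notes.txt">
    #    some content
    #    </file>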
@@ -0,0 +1,19 @@
+ require_relative 'util/save'
+ require_relative 'util/run'
+
+ class ScoutModel
+ attr_accessor :directory, :options, :state
+
+ def initialize(directory = nil, options={})
+ @options = options
+
+ if directory
+ directory = Path.setup directory.dup unless Path === directory
+ @directory = directory
+ restore
+ end
+
+ @features = []
+ @labels = []
+ end
+ end
@@ -0,0 +1,25 @@
+ require_relative '../base'
+ require 'scout/python'
+
+ class PythonModel < ScoutModel
+ def initialize(dir, python_class = nil, python_module = nil, options = nil)
+ options, python_module = python_module, :model if options.nil? && Hash === python_module
+ options = {} if options.nil?
+
+ options[:python_class] = python_class if python_class
+ options[:python_module] = python_module if python_module
+
+ super(dir, options)
+
+ if options[:python_class]
+ self.init do
+ ScoutPython.add_path Scout.python.find(:lib)
+ ScoutPython.add_path directory
+ ScoutPython.init_scout
+ ScoutPython.class_new_obj(options[:python_module],
+ options[:python_class],
+ **options.except(:python_class, :python_module))
+ end
+ end
+ end
+ end
@@ -0,0 +1,23 @@
+ require_relative '../causal'
+
+ class NextTokenModel < CausalModel
+ def initialize(...)
+ super(...)
+
+ train do |texts|
+ model, tokenizer = @state
+
+ if self.directory
+ output_dir = self.directory['output'].find
+ else
+ output_dir = TmpFile.tmp_file "next_token_model"
+ end
+ dataset = ScoutPython.call_method(
+ "scout_ai.huggingface.data", :list_dataset, tokenizer, texts)
+ ScoutPython.call_method(
+ "scout_ai.huggingface.train.next_token", :train_next_token,
+ model:model, tokenizer:tokenizer, dataset:dataset, output_dir:output_dir, **options[:training_args]
+ )
+ end
+ end
+ end
@@ -0,0 +1,29 @@
+ require_relative '../huggingface'
+
+ class CausalModel < HuggingfaceModel
+ def initialize(...)
+ super("CausalLM", ...)
+
+ self.eval do |messages,list|
+ model, tokenizer = @state
+ ScoutPython.call_method(
+ "scout_ai.huggingface.eval", :eval_causal_lm_chat,
+ model, tokenizer, messages,
+ options[:chat_template],
+ options[:chat_template_kwargs],
+ options[:generation_kwargs]
+ )
+ end
+
+ train do |pairs,labels|
+ # data: array of [response, reward] or [prompt, response, reward]
+ model, tokenizer = @state
+
+ ScoutPython.call_method(
+ "scout_ai.huggingface.rlhf", :train_rlhf,
+ self.state_file, tokenizer, pairs, labels, options[:rlhf_config]
+ )
+ load_state
+ end
+ end
+ end
File without changes
@@ -0,0 +1,50 @@
+ require_relative '../huggingface'
+
+ class SequenceClassificationModel < HuggingfaceModel
+ def initialize(...)
+ super("SequenceClassification", ...)
+
+ self.eval do |features,list|
+ model, tokenizer = @state
+ texts = list ? list : [features]
+ res = ScoutPython.call_method("scout_ai.huggingface.eval", :eval_model, model, tokenizer, texts, options[:locate_tokens])
+ list ? res : res[0]
+ end
+
+ post_process do |result,list|
+ model, tokenizer = @state
+
+ logit_list = list ? list.logits : result
+
+ res = ScoutPython.collect(logit_list) do |logits|
+ logits = ScoutPython.numpy2ruby logits
+ best_class = logits.index logits.max
+ best_class = options[:class_labels][best_class] if options[:class_labels]
+ best_class
+ end
+
+ list ? res : res[0]
+ end
+
+ train do |texts,labels|
+ model, tokenizer = @state
+
+ if directory
+ tsv_file = File.join(directory, 'dataset.tsv')
+ checkpoint_dir = File.join(directory, 'checkpoints')
+ else
+ tmpdir = TmpFile.tmp_file
+ Open.mkdir tmpdir
+ tsv_file = File.join(tmpdir, 'dataset.tsv')
+ checkpoint_dir = File.join(tmpdir, 'checkpoints')
+ end
+
+ training_args_obj = ScoutPython.call_method("scout_ai.huggingface.train", :training_args, checkpoint_dir, options[:training_args])
+ dataset_file = HuggingfaceModel.text_dataset(tsv_file, texts, labels, options[:class_labels])
+
+ ScoutPython.call_method("scout_ai.huggingface.train", :train_model, model, tokenizer, training_args_obj, dataset_file, options[:class_weights])
+
+ Open.rm_rf tmpdir if tmpdir
+ end
+ end
+ end
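A hedged usage sketch for SequenceClassificationModel: the eval/train entry points that dispatch to the blocks above live in data/lib/scout/model/util/run.rb, which this viewer does not expand, so the call style below is an assumption, and the checkpoint, directory and labels are illustrative:

    model = SequenceClassificationModel.new "distilbert-base-uncased", "/tmp/sentiment_model",
      class_labels: %w(negative positive)
    model.train ["I loved it", "Terrible acting"], %w(positive negative)
    model.eval "Pretty good overall"   # post_process maps the best logit to "negative"/"positive"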
@@ -0,0 +1,112 @@
+ require_relative 'torch'
+
+ class HuggingfaceModel < TorchModel
+
+ def fix_options
+ @options[:training_options] = @options.delete(:training_args) if @options.include?(:training_args)
+ @options[:training_options] = @options.delete(:training_kwargs) if @options.include?(:training_kwargs)
+ training_args = IndiferentHash.pull_keys(@options, :training) || {}
+
+ @options[:tokenizer_options] = @options.delete(:tokenizer_args) if @options.include?(:tokenizer_args)
+ @options[:tokenizer_options] = @options.delete(:tokenizer_kwargs) if @options.include?(:tokenizer_kwargs)
+ tokenizer_args = IndiferentHash.pull_keys(@options, :tokenizer) || {}
+
+ @options[:training_args] = training_args
+ @options[:tokenizer_args] = tokenizer_args
+ end
+
+ def initialize(task=nil, checkpoint=nil, dir = nil, options = {})
+
+ super(dir, nil, nil, options)
+
+ fix_options
+
+ options[:checkpoint] = checkpoint
+ options[:task] = task
+
+ init do
+ TorchModel.init_python
+ checkpoint = state_file && File.directory?(state_file) ? state_file : self.options[:checkpoint]
+
+ model = ScoutPython.call_method("scout_ai.huggingface.model", :load_model,
+ self.options[:task], checkpoint,
+ **(IndiferentHash.setup(
+ self.options.except(
+ :training_args, :tokenizer_args,
+ :task, :checkpoint, :class_labels,
+ :model_options, :return_logits
+ ))))
+
+ tokenizer_checkpoint = self.options[:tokenizer_args][:checkpoint] || checkpoint
+
+ tokenizer = ScoutPython.call_method("scout_ai.huggingface.model", :load_tokenizer,
+ tokenizer_checkpoint,
+ **(IndiferentHash.setup(self.options[:tokenizer_args])))
+
+ [model, tokenizer]
+ end
+
+ load_state do |state_file|
+ model, tokenizer = @state
+ TorchModel.init_python
+ if state_file && Open.directory?(state_file)
+ model.from_pretrained(state_file)
+ tokenizer.from_pretrained(state_file)
+ end
+ end
+
+ save_state do |state_file,state|
+ model, tokenizer = @state
+ TorchModel.init_python
+ if state_file
+ model.save_pretrained(state_file)
+ tokenizer.save_pretrained(state_file)
+ end
+ end
+
+ #self.eval do |features,list|
+ # model, tokenizer = @state
+ # res = case options[:task]
+ # when "CausalLM"
+ # if not list
+ # list = [features]
+ # end
+ # # Allow for options :chat_template, :chat_template_kwargs, :generation_kwargs
+ # #options[:generation_kwargs] = {max_new_tokens: 1000}
+ # ScoutPython.call_method(
+ # "scout_ai.huggingface.eval", :eval_causal_lm_chat,
+ # model, tokenizer, list,
+ # options[:chat_template],
+ # options[:chat_template_kwargs],
+ # options[:generation_kwargs]
+ # )
+ # else
+ # texts = list ? list : [features]
+ # ScoutPython.call_method("scout_ai.huggingface.eval", :eval_model, model, tokenizer, texts, options[:locate_tokens])
+ # end
+ # list ? res : res[0]
+ #end
+
+ #train do |texts,labels|
+ # model, tokenizer = @state
+ #
+ # if directory
+ # tsv_file = File.join(directory, 'dataset.tsv')
+ # checkpoint_dir = File.join(directory, 'checkpoints')
+ # else
+ # tmpdir = TmpFile.tmp_file
+ # Open.mkdir tmpdir
+ # tsv_file = File.join(tmpdir, 'dataset.tsv')
+ # checkpoint_dir = File.join(tmpdir, 'checkpoints')
+ # end
+
+ # training_args_obj = ScoutPython.call_method("scout_ai.huggingface.train", :training_args, checkpoint_dir, options[:training_args])
+ # dataset_file = HuggingfaceModel.text_dataset(tsv_file, texts, labels, options[:class_labels])

+ # ScoutPython.call_method("scout_ai.huggingface.train", :train_model, model, tokenizer, training_args_obj, dataset_file, options[:class_weights])

+ # Open.rm_rf tmpdir if tmpdir
+ #end
+
+ end
+ end
@@ -0,0 +1,57 @@
+ class TorchModel
+ def self.feature_tsv(elements, labels = nil, class_labels = nil)
+ tsv = TSV.setup({}, :key_field => "ID", :fields => ["features"], :type => :flat)
+ if labels
+ tsv.fields = tsv.fields + ["label"]
+ labels = case class_labels
+ when Array
+ labels.collect{|l| class_labels.index l}
+ when Hash
+ inverse_class_labels = {}
+ class_labels.each{|c,l| inverse_class_labels[l] = c }
+ labels.collect{|l| inverse_class_labels[l]}
+ else
+ labels
+ end
+ elements.zip(labels).each_with_index do |p,i|
+ features, label = p
+ id = i
+ if Array === features
+ tsv[id] = features + [label]
+ else
+ tsv[id] = [features, label]
+ end
+ end
+ else
+ elements.each_with_index do |features,i|
+ id = i
+ if Array === features
+ tsv[id] = features
+ else
+ tsv[id] = [features]
+ end
+ end
+ end
+ tsv
+ end
+
+ def self.feature_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
+ tsv = feature_tsv(elements, labels, class_labels)
+ Open.write(tsv_dataset_file, tsv.to_s)
+ tsv_dataset_file
+ end
+
+ def self.text_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
+ elements = elements.compact.collect{|e| e.gsub("\n", ' ').gsub('"', '\'') }
+ tsv = feature_tsv(elements, labels, class_labels)
+ tsv.fields[0] = "text"
+ if labels.nil?
+ tsv = tsv.to_single
+ else
+ tsv.type = :list
+ end
+ Open.write(tsv_dataset_file, tsv.to_s)
+ tsv_dataset_file
+ end
+
+ end
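For reference, a small sketch of what the dataset helpers above produce, assuming the TSV and Open utilities from the scout-gear dependency are loaded; the values are made up:

    tsv = TorchModel.feature_tsv([[1, 2], [3, 4]], ["good", "bad"], ["good", "bad"])
    # Keys are row indices; an Array of class_labels is mapped to label indices:
    #   0 => [1, 2, 0]
    #   1 => [3, 4, 1]

    # text_dataset writes a "text" (plus "label") TSV to disk and returns the path
    TorchModel.text_dataset("dataset.tsv", ["a sentence", "another one"], [0, 1])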
@@ -0,0 +1,84 @@
+ require 'scout/python'
+
+ class TorchModel
+ module Tensor
+ def to_ruby
+ ScoutPython.numpy2ruby(self)
+ end
+
+ def to_ruby!
+ r = self.to_ruby
+ self.del
+ r
+ end
+
+ def length
+ PyCall.len(self)
+ end
+
+ def self.setup(obj)
+ obj.extend Tensor
+ end
+
+ def del
+ begin
+ self.to("cpu")
+ self.detach
+ self.grad = nil
+ self.untyped_storage.resize_ 0
+ rescue Exception
+ Log.exception $!
+ end
+ self
+ end
+ end
+
+ def self.init_python
+ return if defined?(@@init_python) && @@init_python
+ ScoutPython.add_path Scout.python.find(:lib)
+ ScoutPython.init_scout
+ ScoutPython.pyimport :torch
+ ScoutPython.pyimport :scout
+ ScoutPython.pyimport :scout_ai
+ ScoutPython.pyfrom :scout_ai, import: :util
+ ScoutPython.pyfrom :torch, import: :nn
+ @@init_python = true
+ end
+
+ def self.optimizer(model, training_args = {})
+ begin
+ learning_rate = training_args[:learning_rate] || 0.01
+ ScoutPython.torch.optim.SGD.new(model.parameters(), lr: learning_rate)
+ end
+ end
+
+ def self.criterion(model, training_args = {})
+ ScoutPython.torch.nn.MSELoss.new()
+ end
+
+ def self.device(model_options)
+ case model_options[:device]
+ when String, Symbol
+ ScoutPython.torch.device(model_options[:device].to_s)
+ when nil
+ ScoutPython.scout_ai.util.device()
+ else
+ model_options[:device]
+ end
+ end
+
+ def self.dtype(model_options)
+ case model_options[:dtype]
+ when String, Symbol
+ ScoutPython.torch.call(model_options[:dtype])
+ when nil
+ nil
+ else
+ model_options[:dtype]
+ end
+ end
+
+ def self.tensor(obj, device, dtype)
+ TorchModel::Tensor.setup(ScoutPython.torch.tensor(obj, dtype: dtype, device: device))
+ end
+ end
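A short sketch of the device and tensor helpers above, assuming init_python has already set up PyCall and imported torch and scout_ai:

    TorchModel.init_python
    device = TorchModel.device({})            # no :device option, so scout_ai.util.device() decides
    t = TorchModel.tensor([[1.0, 2.0]], device, nil)
    t.to_ruby                                 # Tensor#to_ruby converts through ScoutPython.numpy2ruby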
@@ -0,0 +1,34 @@
+ require_relative 'helpers'
+ class TorchModel
+ def self.get_layer(state, layer = nil)
+ state = state.first if Array === state
+ if layer.nil?
+ state
+ else
+ layer.split(".").inject(state){|acc,l| PyCall.getattr(acc, l.to_sym) }
+ end
+ end
+ def get_layer(...); TorchModel.get_layer(state, ...); end
+
+ def self.get_weights(state, layer = nil)
+ Tensor.setup PyCall.getattr(get_layer(state, layer), :weight)
+ end
+ def get_weights(...); TorchModel.get_weights(state, ...); end
+
+ def self.freeze(layer, requires_grad=false)
+ begin
+ PyCall.getattr(layer, :weight).requires_grad = requires_grad
+ rescue
+ end
+ ScoutPython.iterate(layer.children) do |layer|
+ freeze(layer, requires_grad)
+ end
+ end
+
+ def self.freeze_layer(state, layer, requires_grad = false)
+ layer = get_layer(state, layer)
+ freeze(layer, requires_grad)
+ end
+
+ def freeze_layer(...); TorchModel.freeze_layer(state, ...); end
+ end
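Finally, a hedged sketch of the layer introspection helpers on a model whose state has been initialized (for instance a HuggingfaceModel from the hunk above); the layer paths are illustrative and depend on the underlying PyTorch module names:

    layer   = model.get_layer("bert.encoder.layer.0")   # walks attribute names with PyCall.getattr
    weights = model.get_weights("classifier").to_ruby   # weight tensor converted to Ruby arrays
    model.freeze_layer("bert.embeddings")               # sets requires_grad = false recursively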