scout-ai 0.2.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. checksums.yaml +4 -4
  2. data/.vimproject +155 -9
  3. data/README.md +296 -0
  4. data/Rakefile +3 -0
  5. data/VERSION +1 -1
  6. data/bin/scout-ai +2 -0
  7. data/doc/Agent.md +279 -0
  8. data/doc/Chat.md +258 -0
  9. data/doc/LLM.md +446 -0
  10. data/doc/Model.md +513 -0
  11. data/doc/RAG.md +129 -0
  12. data/lib/scout/llm/agent/chat.rb +74 -0
  13. data/lib/scout/llm/agent/delegate.rb +39 -0
  14. data/lib/scout/llm/agent/iterate.rb +44 -0
  15. data/lib/scout/llm/agent.rb +51 -30
  16. data/lib/scout/llm/ask.rb +63 -21
  17. data/lib/scout/llm/backends/anthropic.rb +147 -0
  18. data/lib/scout/llm/backends/bedrock.rb +129 -0
  19. data/lib/scout/llm/backends/huggingface.rb +6 -21
  20. data/lib/scout/llm/backends/ollama.rb +62 -35
  21. data/lib/scout/llm/backends/openai.rb +77 -33
  22. data/lib/scout/llm/backends/openwebui.rb +1 -1
  23. data/lib/scout/llm/backends/relay.rb +3 -2
  24. data/lib/scout/llm/backends/responses.rb +320 -0
  25. data/lib/scout/llm/chat.rb +703 -0
  26. data/lib/scout/llm/embed.rb +4 -4
  27. data/lib/scout/llm/mcp.rb +28 -0
  28. data/lib/scout/llm/parse.rb +71 -13
  29. data/lib/scout/llm/rag.rb +9 -0
  30. data/lib/scout/llm/tools/call.rb +66 -0
  31. data/lib/scout/llm/tools/knowledge_base.rb +158 -0
  32. data/lib/scout/llm/tools/mcp.rb +59 -0
  33. data/lib/scout/llm/tools/workflow.rb +69 -0
  34. data/lib/scout/llm/tools.rb +112 -76
  35. data/lib/scout/llm/utils.rb +17 -10
  36. data/lib/scout/model/base.rb +19 -0
  37. data/lib/scout/model/python/base.rb +25 -0
  38. data/lib/scout/model/python/huggingface/causal/next_token.rb +23 -0
  39. data/lib/scout/model/python/huggingface/causal.rb +29 -0
  40. data/lib/scout/model/python/huggingface/classification +0 -0
  41. data/lib/scout/model/python/huggingface/classification.rb +50 -0
  42. data/lib/scout/model/python/huggingface.rb +112 -0
  43. data/lib/scout/model/python/torch/dataloader.rb +57 -0
  44. data/lib/scout/model/python/torch/helpers.rb +84 -0
  45. data/lib/scout/model/python/torch/introspection.rb +34 -0
  46. data/lib/scout/model/python/torch/load_and_save.rb +47 -0
  47. data/lib/scout/model/python/torch.rb +94 -0
  48. data/lib/scout/model/util/run.rb +181 -0
  49. data/lib/scout/model/util/save.rb +81 -0
  50. data/lib/scout-ai.rb +4 -1
  51. data/python/scout_ai/__init__.py +35 -0
  52. data/python/scout_ai/huggingface/data.py +48 -0
  53. data/python/scout_ai/huggingface/eval.py +60 -0
  54. data/python/scout_ai/huggingface/model.py +29 -0
  55. data/python/scout_ai/huggingface/rlhf.py +83 -0
  56. data/python/scout_ai/huggingface/train/__init__.py +34 -0
  57. data/python/scout_ai/huggingface/train/next_token.py +315 -0
  58. data/python/scout_ai/util.py +32 -0
  59. data/scout-ai.gemspec +143 -0
  60. data/scout_commands/agent/ask +89 -14
  61. data/scout_commands/agent/kb +15 -0
  62. data/scout_commands/documenter +148 -0
  63. data/scout_commands/llm/ask +71 -12
  64. data/scout_commands/llm/process +4 -2
  65. data/scout_commands/llm/server +319 -0
  66. data/share/server/chat.html +138 -0
  67. data/share/server/chat.js +468 -0
  68. data/test/data/cat.jpg +0 -0
  69. data/test/scout/llm/agent/test_chat.rb +14 -0
  70. data/test/scout/llm/backends/test_anthropic.rb +134 -0
  71. data/test/scout/llm/backends/test_bedrock.rb +60 -0
  72. data/test/scout/llm/backends/test_huggingface.rb +3 -3
  73. data/test/scout/llm/backends/test_ollama.rb +48 -10
  74. data/test/scout/llm/backends/test_openai.rb +134 -10
  75. data/test/scout/llm/backends/test_responses.rb +239 -0
  76. data/test/scout/llm/test_agent.rb +0 -70
  77. data/test/scout/llm/test_ask.rb +4 -1
  78. data/test/scout/llm/test_chat.rb +256 -0
  79. data/test/scout/llm/test_mcp.rb +29 -0
  80. data/test/scout/llm/test_parse.rb +81 -2
  81. data/test/scout/llm/tools/test_call.rb +0 -0
  82. data/test/scout/llm/tools/test_knowledge_base.rb +22 -0
  83. data/test/scout/llm/tools/test_mcp.rb +11 -0
  84. data/test/scout/llm/tools/test_workflow.rb +39 -0
  85. data/test/scout/model/python/huggingface/causal/test_next_token.rb +59 -0
  86. data/test/scout/model/python/huggingface/test_causal.rb +33 -0
  87. data/test/scout/model/python/huggingface/test_classification.rb +30 -0
  88. data/test/scout/model/python/test_base.rb +44 -0
  89. data/test/scout/model/python/test_huggingface.rb +9 -0
  90. data/test/scout/model/python/test_torch.rb +71 -0
  91. data/test/scout/model/python/torch/test_helpers.rb +14 -0
  92. data/test/scout/model/test_base.rb +117 -0
  93. data/test/scout/model/util/test_save.rb +31 -0
  94. metadata +113 -7
  95. data/README.rdoc +0 -18
  96. data/questions/coach +0 -2
@@ -1,104 +1,140 @@
1
- require 'scout/workflow'
2
1
  require 'scout/knowledge_base'
2
+ require_relative 'tools/mcp'
3
+ require_relative 'tools/workflow'
4
+ require_relative 'tools/knowledge_base'
5
+ require_relative 'tools/call'
3
6
  module LLM
7
+ def self.call_tools(tool_calls, &block)
8
+ tool_calls.collect{|tool_call|
9
+ response_message = LLM.tool_response(tool_call, &block)
10
+ function_call = tool_call
11
+ function_call['id'] = tool_call.delete('call_id') if tool_call.dig('call_id')
12
+ [
13
+ {role: "function_call", content: tool_call.to_json},
14
+ {role: "function_call_output", content: response_message.to_json},
15
+ ]
16
+ }.flatten
17
+ end
18
+
4
19
  def self.tool_response(tool_call, &block)
5
- tool_call_id = tool_call.dig("id")
6
- function_name = tool_call.dig("function", "name")
7
- function_arguments = tool_call.dig("function", "arguments")
20
+ tool_call_id = tool_call.dig("call_id") || tool_call.dig("id")
21
+ if tool_call['function']
22
+ function_name = tool_call.dig("function", "name")
23
+ function_arguments = tool_call.dig("function", "arguments")
24
+ else
25
+ function_name = tool_call.dig("name")
26
+ function_arguments = tool_call.dig("arguments")
27
+ end
28
+
8
29
  function_arguments = JSON.parse(function_arguments, { symbolize_names: true }) if String === function_arguments
9
- function_response = block.call function_name, function_arguments
10
30
 
11
- #content = String === function_response ? function_response : function_response.to_json,
31
+ Log.high "Calling function #{function_name} with arguments #{Log.fingerprint function_arguments}"
32
+
33
+ function_response = begin
34
+ block.call function_name, function_arguments
35
+ rescue
36
+ $!
37
+ end
38
+
12
39
  content = case function_response
13
40
  when String
14
41
  function_response
15
42
  when nil
16
43
  "success"
44
+ when Exception
45
+ {exception: function_response.message, stack: function_response.backtrace }.to_json
17
46
  else
18
47
  function_response.to_json
19
48
  end
49
+ content = content.to_s if Numeric === content
20
50
  {
21
- tool_call_id: tool_call_id,
51
+ id: tool_call_id,
22
52
  role: "tool",
23
53
  content: content
24
54
  }
25
55
  end
26
56
 
27
- def self.task_tool_definition(workflow, task_name)
28
- task_info = workflow.task_info(task_name)
29
-
30
- properties = task_info[:inputs].inject({}) do |acc,input|
31
- type = task_info[:input_types][input]
32
- description = task_info[:input_descriptions][input]
33
-
34
- type = :string if type == :select
35
-
36
- acc[input] = {
37
- "type": type,
38
- "description": description
39
- }
40
-
41
- if input_options = task_info[:input_options][input]
42
- if select_options = input_options[:select_options]
43
- select_options = select_options.values if Hash === select_options
44
- acc[input]["enum"] = select_options
45
- end
57
+ def self.run_tools(messages)
58
+ messages.collect do |info|
59
+ IndiferentHash.setup(info)
60
+ role = info[:role]
61
+ if role == 'cmd'
62
+ {
63
+ role: 'tool',
64
+ content: CMD.cmd(info[:content]).read
65
+ }
66
+ else
67
+ info
46
68
  end
47
-
48
- acc
49
- end
50
-
51
- required_inputs = task_info[:inputs].select do |input|
52
- task_info[:input_options].include?(input) && task_info[:input_options][:required]
53
69
  end
54
-
55
- {
56
- type: "function",
57
- function: {
58
- name: task_name,
59
- description: task_info[:description],
60
- parameters: {
61
- type: "object",
62
- properties: properties,
63
- required: required_inputs
64
- }
65
- }
66
- }
67
70
  end
68
71
 
69
- def self.workflow_tools(workflow, tasks = nil)
70
- tasks ||= workflow.tasks.keys
71
- tasks.collect{|task_name| self.task_tool_definition(workflow, task_name) }
72
+ def self.tools_to_openai(messages)
73
+ messages.collect do |message|
74
+ if message[:role] == 'function_call'
75
+ tool_call = JSON.parse(message[:content])
76
+ arguments = tool_call.delete('arguments') || {}
77
+ name = tool_call[:name]
78
+ tool_call['type'] = 'function'
79
+ tool_call['function'] ||= {}
80
+ tool_call['function']['name'] ||= name
81
+ tool_call['function']['arguments'] = arguments.to_json
82
+ {role: 'assistant', tool_calls: [tool_call]}
83
+ elsif message[:role] == 'function_call_output'
84
+ info = JSON.parse(message[:content])
85
+ id = info.delete('call_id') || info.dig('id')
86
+ info['role'] = 'tool'
87
+ info['tool_call_id'] = id
88
+ info
89
+ else
90
+ message
91
+ end
92
+ end.flatten
72
93
  end
73
94
 
74
- def self.knowledge_base_tool_definition(knowledge_base)
75
-
76
- databases = knowledge_base.all_databases.collect{|d| d.to_s }
77
-
78
- properties = {
79
- database: {
80
- type: "string",
81
- enum: databases,
82
- description: "Database to traverse"
83
- },
84
- entities: {
85
- type: "array",
86
- items: { type: :string },
87
- description: "Parent entities to find children for"
88
- }
89
- }
95
+ def self.tools_to_anthropic(messages)
96
+ messages.collect do |message|
97
+ if message[:role] == 'function_call'
98
+ tool_call = JSON.parse(message[:content])
99
+ arguments = tool_call.delete('arguments') || {}
100
+ name = tool_call[:name]
101
+ tool_call['type'] = 'tool_use'
102
+ tool_call['name'] ||= name
103
+ tool_call['input'] = arguments
104
+ {role: 'assistant', content: [tool_call]}
105
+ elsif message[:role] == 'function_call_output'
106
+ info = JSON.parse(message[:content])
107
+ id = info.delete('call_id') || info.delete('id')
108
+ info.delete "role"
109
+ info['tool_use_id'] = id
110
+ info['type'] = 'tool_result'
111
+ {role: 'user', content: [info]}
112
+ else
113
+ message
114
+ end
115
+ end.flatten
116
+ end
90
117
 
91
- [{
92
- type: "function",
93
- function: {
94
- name: 'children',
95
- description: "Find the graph children for a list of entities in a format like parent~child. Returns a list.",
96
- parameters: {
97
- type: "object",
98
- properties: properties,
99
- required: ['database', 'entities']
100
- }
101
- }
102
- }]
118
+ def self.tools_to_ollama(messages)
119
+ messages.collect do |message|
120
+ if message[:role] == 'function_call'
121
+ tool_call = JSON.parse(message[:content])
122
+ arguments = tool_call.delete('arguments') || {}
123
+ id = tool_call.delete('id')
124
+ name = tool_call.delete('name')
125
+ tool_call['type'] = 'function'
126
+ tool_call['function'] ||= {}
127
+ tool_call['function']['name'] ||= name
128
+ tool_call['function']['arguments'] ||= arguments
129
+ {role: 'assistant', tool_calls: [tool_call]}
130
+ elsif message[:role] == 'function_call_output'
131
+ info = JSON.parse(message[:content])
132
+ id = info.delete('id') || ''
133
+ info['role'] = 'tool'
134
+ info
135
+ else
136
+ message
137
+ end
138
+ end.flatten
103
139
  end
104
140
  end
@@ -1,4 +1,21 @@
1
1
  module LLM
2
+
3
+ def self.tag(tag, content, name = nil)
4
+ if name
5
+ <<-EOF.strip
6
+ <#{tag} name="#{name}">
7
+ #{content}
8
+ </#{tag}>
9
+ EOF
10
+ else
11
+ <<-EOF.strip
12
+ <#{tag}>
13
+ #{content}
14
+ </#{tag}>
15
+ EOF
16
+ end
17
+ end
18
+
2
19
  def self.get_url_server_tokens(url, prefix=nil)
3
20
  return get_url_server_tokens(url).collect{|e| prefix.to_s + "." + e } if prefix
4
21
 
@@ -22,14 +39,4 @@ module LLM
22
39
  end
23
40
  Scout::Config.get(key, *all_tokens, hash)
24
41
  end
25
-
26
- def self
27
- if workflow.root.etc.AI[@model || 'default'].exists?
28
- workflow.root.etc.AI[@model || 'default'].json
29
- elsif Scout.etc.AI[@model || 'default'].exists?
30
- Scout.etc.AI[@model || 'default'].json
31
- else
32
- {}
33
- end
34
- end
35
42
  end
@@ -0,0 +1,19 @@
1
+ require_relative 'util/save'
2
+ require_relative 'util/run'
3
+
4
+ class ScoutModel
5
+ attr_accessor :directory, :options, :state
6
+
7
+ def initialize(directory = nil, options={})
8
+ @options = options
9
+
10
+ if directory
11
+ directory = Path.setup directory.dup unless Path === directory
12
+ @directory = directory
13
+ restore
14
+ end
15
+
16
+ @features = []
17
+ @labels = []
18
+ end
19
+ end
@@ -0,0 +1,25 @@
1
+ require_relative '../base'
2
+ require 'scout/python'
3
+
4
+ class PythonModel < ScoutModel
5
+ def initialize(dir, python_class = nil, python_module = nil, options = nil)
6
+ options, python_module = python_module, :model if options.nil? && Hash === python_module
7
+ options = {} if options.nil?
8
+
9
+ options[:python_class] = python_class if python_class
10
+ options[:python_module] = python_module if python_module
11
+
12
+ super(dir, options)
13
+
14
+ if options[:python_class]
15
+ self.init do
16
+ ScoutPython.add_path Scout.python.find(:lib)
17
+ ScoutPython.add_path directory
18
+ ScoutPython.init_scout
19
+ ScoutPython.class_new_obj(options[:python_module],
20
+ options[:python_class],
21
+ **options.except(:python_class, :python_module))
22
+ end
23
+ end
24
+ end
25
+ end
@@ -0,0 +1,23 @@
1
+ require_relative '../causal'
2
+
3
+ class NextTokenModel < CausalModel
4
+ def initialize(...)
5
+ super(...)
6
+
7
+ train do |texts|
8
+ model, tokenizer = @state
9
+
10
+ if self.directory
11
+ output_dir = self.directory['output'].find
12
+ else
13
+ output_dir = TmpFile.tmp_file "next_token_model"
14
+ end
15
+ dataset = ScoutPython.call_method(
16
+ "scout_ai.huggingface.data", :list_dataset, tokenizer, texts)
17
+ ScoutPython.call_method(
18
+ "scout_ai.huggingface.train.next_token", :train_next_token,
19
+ model:model, tokenizer:tokenizer, dataset:dataset, output_dir:output_dir, **options[:training_args]
20
+ )
21
+ end
22
+ end
23
+ end
@@ -0,0 +1,29 @@
1
+ require_relative '../huggingface'
2
+
3
+ class CausalModel < HuggingfaceModel
4
+ def initialize(...)
5
+ super("CausalLM", ...)
6
+
7
+ self.eval do |messages,list|
8
+ model, tokenizer = @state
9
+ ScoutPython.call_method(
10
+ "scout_ai.huggingface.eval", :eval_causal_lm_chat,
11
+ model, tokenizer, messages,
12
+ options[:chat_template],
13
+ options[:chat_template_kwargs],
14
+ options[:generation_kwargs]
15
+ )
16
+ end
17
+
18
+ train do |pairs,labels|
19
+ # data: array of [response, reward] or [prompt, response, reward]
20
+ model, tokenizer = @state
21
+
22
+ ScoutPython.call_method(
23
+ "scout_ai.huggingface.rlhf", :train_rlhf,
24
+ self.state_file, tokenizer, pairs, labels, options[:rlhf_config]
25
+ )
26
+ load_state
27
+ end
28
+ end
29
+ end
File without changes
@@ -0,0 +1,50 @@
1
+ require_relative '../huggingface'
2
+
3
+ class SequenceClassificationModel < HuggingfaceModel
4
+ def initialize(...)
5
+ super("SequenceClassification", ...)
6
+
7
+ self.eval do |features,list|
8
+ model, tokenizer = @state
9
+ texts = list ? list : [features]
10
+ res = ScoutPython.call_method("scout_ai.huggingface.eval", :eval_model, model, tokenizer, texts, options[:locate_tokens])
11
+ list ? res : res[0]
12
+ end
13
+
14
+ post_process do |result,list|
15
+ model, tokenizer = @state
16
+
17
+ logit_list = list ? list.logits : result
18
+
19
+ res = ScoutPython.collect(logit_list) do |logits|
20
+ logits = ScoutPython.numpy2ruby logits
21
+ best_class = logits.index logits.max
22
+ best_class = options[:class_labels][best_class] if options[:class_labels]
23
+ best_class
24
+ end
25
+
26
+ list ? res : res[0]
27
+ end
28
+
29
+ train do |texts,labels|
30
+ model, tokenizer = @state
31
+
32
+ if directory
33
+ tsv_file = File.join(directory, 'dataset.tsv')
34
+ checkpoint_dir = File.join(directory, 'checkpoints')
35
+ else
36
+ tmpdir = TmpFile.tmp_file
37
+ Open.mkdir tmpdir
38
+ tsv_file = File.join(tmpdir, 'dataset.tsv')
39
+ checkpoint_dir = File.join(tmpdir, 'checkpoints')
40
+ end
41
+
42
+ training_args_obj = ScoutPython.call_method("scout_ai.huggingface.train", :training_args, checkpoint_dir, options[:training_args])
43
+ dataset_file = HuggingfaceModel.text_dataset(tsv_file, texts, labels, options[:class_labels])
44
+
45
+ ScoutPython.call_method("scout_ai.huggingface.train", :train_model, model, tokenizer, training_args_obj, dataset_file, options[:class_weights])
46
+
47
+ Open.rm_rf tmpdir if tmpdir
48
+ end
49
+ end
50
+ end
@@ -0,0 +1,112 @@
1
+ require_relative 'torch'
2
+
3
+ class HuggingfaceModel < TorchModel
4
+
5
+ def fix_options
6
+ @options[:training_options] = @options.delete(:training_args) if @options.include?(:training_args)
7
+ @options[:training_options] = @options.delete(:training_kwargs) if @options.include?(:training_kwargs)
8
+ training_args = IndiferentHash.pull_keys(@options, :training) || {}
9
+
10
+ @options[:tokenizer_options] = @options.delete(:tokenizer_args) if @options.include?(:tokenizer_args)
11
+ @options[:tokenizer_options] = @options.delete(:tokenizer_kwargs) if @options.include?(:tokenizer_kwargs)
12
+ tokenizer_args = IndiferentHash.pull_keys(@options, :tokenizer) || {}
13
+
14
+ @options[:training_args] = training_args
15
+ @options[:tokenizer_args] = tokenizer_args
16
+ end
17
+
18
+ def initialize(task=nil, checkpoint=nil, dir = nil, options = {})
19
+
20
+ super(dir, nil, nil, options)
21
+
22
+ fix_options
23
+
24
+ options[:checkpoint] = checkpoint
25
+ options[:task] = task
26
+
27
+ init do
28
+ TorchModel.init_python
29
+ checkpoint = state_file && File.directory?(state_file) ? state_file : self.options[:checkpoint]
30
+
31
+ model = ScoutPython.call_method("scout_ai.huggingface.model", :load_model,
32
+ self.options[:task], checkpoint,
33
+ **(IndiferentHash.setup(
34
+ self.options.except(
35
+ :training_args, :tokenizer_args,
36
+ :task, :checkpoint, :class_labels,
37
+ :model_options, :return_logits
38
+ ))))
39
+
40
+ tokenizer_checkpoint = self.options[:tokenizer_args][:checkpoint] || checkpoint
41
+
42
+ tokenizer = ScoutPython.call_method("scout_ai.huggingface.model", :load_tokenizer,
43
+ tokenizer_checkpoint,
44
+ **(IndiferentHash.setup(self.options[:tokenizer_args])))
45
+
46
+ [model, tokenizer]
47
+ end
48
+
49
+ load_state do |state_file|
50
+ model, tokenizer = @state
51
+ TorchModel.init_python
52
+ if state_file && Open.directory?(state_file)
53
+ model.from_pretrained(state_file)
54
+ tokenizer.from_pretrained(state_file)
55
+ end
56
+ end
57
+
58
+ save_state do |state_file,state|
59
+ model, tokenizer = @state
60
+ TorchModel.init_python
61
+ if state_file
62
+ model.save_pretrained(state_file)
63
+ tokenizer.save_pretrained(state_file)
64
+ end
65
+ end
66
+
67
+ #self.eval do |features,list|
68
+ # model, tokenizer = @state
69
+ # res = case options[:task]
70
+ # when "CausalLM"
71
+ # if not list
72
+ # list = [features]
73
+ # end
74
+ # # Allow for options :chat_template, :chat_template_kwargs, :generation_kwargs
75
+ # #options[:generation_kwargs] = {max_new_tokens: 1000}
76
+ # ScoutPython.call_method(
77
+ # "scout_ai.huggingface.eval", :eval_causal_lm_chat,
78
+ # model, tokenizer, list,
79
+ # options[:chat_template],
80
+ # options[:chat_template_kwargs],
81
+ # options[:generation_kwargs]
82
+ # )
83
+ # else
84
+ # texts = list ? list : [features]
85
+ # ScoutPython.call_method("scout_ai.huggingface.eval", :eval_model, model, tokenizer, texts, options[:locate_tokens])
86
+ # end
87
+ # list ? res : res[0]
88
+ #end
89
+
90
+ #train do |texts,labels|
91
+ # model, tokenizer = @state
92
+ #
93
+ # if directory
94
+ # tsv_file = File.join(directory, 'dataset.tsv')
95
+ # checkpoint_dir = File.join(directory, 'checkpoints')
96
+ # else
97
+ # tmpdir = TmpFile.tmp_file
98
+ # Open.mkdir tmpdir
99
+ # tsv_file = File.join(tmpdir, 'dataset.tsv')
100
+ # checkpoint_dir = File.join(tmpdir, 'checkpoints')
101
+ # end
102
+
103
+ # training_args_obj = ScoutPython.call_method("scout_ai.huggingface.train", :training_args, checkpoint_dir, options[:training_args])
104
+ # dataset_file = HuggingfaceModel.text_dataset(tsv_file, texts, labels, options[:class_labels])
105
+
106
+ # ScoutPython.call_method("scout_ai.huggingface.train", :train_model, model, tokenizer, training_args_obj, dataset_file, options[:class_weights])
107
+
108
+ # Open.rm_rf tmpdir if tmpdir
109
+ #end
110
+
111
+ end
112
+ end
@@ -0,0 +1,57 @@
1
+ class TorchModel
2
+ def self.feature_tsv(elements, labels = nil, class_labels = nil)
3
+ tsv = TSV.setup({}, :key_field => "ID", :fields => ["features"], :type => :flat)
4
+ if labels
5
+ tsv.fields = tsv.fields + ["label"]
6
+ labels = case class_labels
7
+ when Array
8
+ labels.collect{|l| class_labels.index l}
9
+ when Hash
10
+ inverse_class_labels = {}
11
+ class_labels.each{|c,l| inverse_class_labels[l] = c }
12
+ labels.collect{|l| inverse_class_labels[l]}
13
+ else
14
+ labels
15
+ end
16
+ elements.zip(labels).each_with_index do |p,i|
17
+ features, label = p
18
+ id = i
19
+ if Array === features
20
+ tsv[id] = features + [label]
21
+ else
22
+ tsv[id] = [features, label]
23
+ end
24
+ end
25
+ else
26
+ elements.each_with_index do |features,i|
27
+ id = i
28
+ if Array === features
29
+ tsv[id] = features
30
+ else
31
+ tsv[id] = [features]
32
+ end
33
+ end
34
+ end
35
+ tsv
36
+ end
37
+
38
+ def self.feature_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
39
+ tsv = feature_tsv(elements, labels, class_labels)
40
+ Open.write(tsv_dataset_file, tsv.to_s)
41
+ tsv_dataset_file
42
+ end
43
+
44
+ def self.text_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
45
+ elements = elements.compact.collect{|e| e.gsub("\n", ' ').gsub('"', '\'') }
46
+ tsv = feature_tsv(elements, labels, class_labels)
47
+ tsv.fields[0] = "text"
48
+ if labels.nil?
49
+ tsv = tsv.to_single
50
+ else
51
+ tsv.type = :list
52
+ end
53
+ Open.write(tsv_dataset_file, tsv.to_s)
54
+ tsv_dataset_file
55
+ end
56
+
57
+ end
@@ -0,0 +1,84 @@
1
+ require 'scout/python'
2
+
3
+ class TorchModel
4
+ module Tensor
5
+ def to_ruby
6
+ ScoutPython.numpy2ruby(self)
7
+ end
8
+
9
+ def to_ruby!
10
+ r = self.to_ruby
11
+ self.del
12
+ r
13
+ end
14
+
15
+ def length
16
+ PyCall.len(self)
17
+ end
18
+
19
+ def self.setup(obj)
20
+ obj.extend Tensor
21
+ end
22
+
23
+ def del
24
+ begin
25
+ self.to("cpu")
26
+ self.detach
27
+ self.grad = nil
28
+ self.untyped_storage.resize_ 0
29
+ rescue Exception
30
+ Log.exception $!
31
+ end
32
+ self
33
+ end
34
+ end
35
+
36
+ def self.init_python
37
+ return if defined?(@@init_python) && @@init_python
38
+ ScoutPython.add_path Scout.python.find(:lib)
39
+ ScoutPython.init_scout
40
+ ScoutPython.pyimport :torch
41
+ ScoutPython.pyimport :scout
42
+ ScoutPython.pyimport :scout_ai
43
+ ScoutPython.pyfrom :scout_ai, import: :util
44
+ ScoutPython.pyfrom :torch, import: :nn
45
+ @@init_python = true
46
+ end
47
+
48
+ def self.optimizer(model, training_args = {})
49
+ begin
50
+ learning_rate = training_args[:learning_rate] || 0.01
51
+ ScoutPython.torch.optim.SGD.new(model.parameters(), lr: learning_rate)
52
+ end
53
+ end
54
+
55
+ def self.criterion(model, training_args = {})
56
+ ScoutPython.torch.nn.MSELoss.new()
57
+ end
58
+
59
+ def self.device(model_options)
60
+ case model_options[:device]
61
+ when String, Symbol
62
+ ScoutPython.torch.device(model_options[:device].to_s)
63
+ when nil
64
+ ScoutPython.scout_ai.util.device()
65
+ else
66
+ model_options[:device]
67
+ end
68
+ end
69
+
70
+ def self.dtype(model_options)
71
+ case model_options[:dtype]
72
+ when String, Symbol
73
+ ScoutPython.torch.call(model_options[:dtype])
74
+ when nil
75
+ nil
76
+ else
77
+ model_options[:dtype]
78
+ end
79
+ end
80
+
81
+ def self.tensor(obj, device, dtype)
82
+ TorchModel::Tensor.setup(ScoutPython.torch.tensor(obj, dtype: dtype, device: device))
83
+ end
84
+ end