scout-ai 0.2.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. checksums.yaml +4 -4
  2. data/.vimproject +155 -9
  3. data/README.md +296 -0
  4. data/Rakefile +3 -0
  5. data/VERSION +1 -1
  6. data/bin/scout-ai +2 -0
  7. data/doc/Agent.md +279 -0
  8. data/doc/Chat.md +258 -0
  9. data/doc/LLM.md +446 -0
  10. data/doc/Model.md +513 -0
  11. data/doc/RAG.md +129 -0
  12. data/lib/scout/llm/agent/chat.rb +74 -0
  13. data/lib/scout/llm/agent/delegate.rb +39 -0
  14. data/lib/scout/llm/agent/iterate.rb +44 -0
  15. data/lib/scout/llm/agent.rb +51 -30
  16. data/lib/scout/llm/ask.rb +63 -21
  17. data/lib/scout/llm/backends/anthropic.rb +147 -0
  18. data/lib/scout/llm/backends/bedrock.rb +129 -0
  19. data/lib/scout/llm/backends/huggingface.rb +6 -21
  20. data/lib/scout/llm/backends/ollama.rb +62 -35
  21. data/lib/scout/llm/backends/openai.rb +77 -33
  22. data/lib/scout/llm/backends/openwebui.rb +1 -1
  23. data/lib/scout/llm/backends/relay.rb +3 -2
  24. data/lib/scout/llm/backends/responses.rb +320 -0
  25. data/lib/scout/llm/chat.rb +703 -0
  26. data/lib/scout/llm/embed.rb +4 -4
  27. data/lib/scout/llm/mcp.rb +28 -0
  28. data/lib/scout/llm/parse.rb +71 -13
  29. data/lib/scout/llm/rag.rb +9 -0
  30. data/lib/scout/llm/tools/call.rb +66 -0
  31. data/lib/scout/llm/tools/knowledge_base.rb +158 -0
  32. data/lib/scout/llm/tools/mcp.rb +59 -0
  33. data/lib/scout/llm/tools/workflow.rb +69 -0
  34. data/lib/scout/llm/tools.rb +112 -76
  35. data/lib/scout/llm/utils.rb +17 -10
  36. data/lib/scout/model/base.rb +19 -0
  37. data/lib/scout/model/python/base.rb +25 -0
  38. data/lib/scout/model/python/huggingface/causal/next_token.rb +23 -0
  39. data/lib/scout/model/python/huggingface/causal.rb +29 -0
  40. data/lib/scout/model/python/huggingface/classification +0 -0
  41. data/lib/scout/model/python/huggingface/classification.rb +50 -0
  42. data/lib/scout/model/python/huggingface.rb +112 -0
  43. data/lib/scout/model/python/torch/dataloader.rb +57 -0
  44. data/lib/scout/model/python/torch/helpers.rb +84 -0
  45. data/lib/scout/model/python/torch/introspection.rb +34 -0
  46. data/lib/scout/model/python/torch/load_and_save.rb +47 -0
  47. data/lib/scout/model/python/torch.rb +94 -0
  48. data/lib/scout/model/util/run.rb +181 -0
  49. data/lib/scout/model/util/save.rb +81 -0
  50. data/lib/scout-ai.rb +4 -1
  51. data/python/scout_ai/__init__.py +35 -0
  52. data/python/scout_ai/huggingface/data.py +48 -0
  53. data/python/scout_ai/huggingface/eval.py +60 -0
  54. data/python/scout_ai/huggingface/model.py +29 -0
  55. data/python/scout_ai/huggingface/rlhf.py +83 -0
  56. data/python/scout_ai/huggingface/train/__init__.py +34 -0
  57. data/python/scout_ai/huggingface/train/next_token.py +315 -0
  58. data/python/scout_ai/util.py +32 -0
  59. data/scout-ai.gemspec +143 -0
  60. data/scout_commands/agent/ask +89 -14
  61. data/scout_commands/agent/kb +15 -0
  62. data/scout_commands/documenter +148 -0
  63. data/scout_commands/llm/ask +71 -12
  64. data/scout_commands/llm/process +4 -2
  65. data/scout_commands/llm/server +319 -0
  66. data/share/server/chat.html +138 -0
  67. data/share/server/chat.js +468 -0
  68. data/test/data/cat.jpg +0 -0
  69. data/test/scout/llm/agent/test_chat.rb +14 -0
  70. data/test/scout/llm/backends/test_anthropic.rb +134 -0
  71. data/test/scout/llm/backends/test_bedrock.rb +60 -0
  72. data/test/scout/llm/backends/test_huggingface.rb +3 -3
  73. data/test/scout/llm/backends/test_ollama.rb +48 -10
  74. data/test/scout/llm/backends/test_openai.rb +134 -10
  75. data/test/scout/llm/backends/test_responses.rb +239 -0
  76. data/test/scout/llm/test_agent.rb +0 -70
  77. data/test/scout/llm/test_ask.rb +4 -1
  78. data/test/scout/llm/test_chat.rb +256 -0
  79. data/test/scout/llm/test_mcp.rb +29 -0
  80. data/test/scout/llm/test_parse.rb +81 -2
  81. data/test/scout/llm/tools/test_call.rb +0 -0
  82. data/test/scout/llm/tools/test_knowledge_base.rb +22 -0
  83. data/test/scout/llm/tools/test_mcp.rb +11 -0
  84. data/test/scout/llm/tools/test_workflow.rb +39 -0
  85. data/test/scout/model/python/huggingface/causal/test_next_token.rb +59 -0
  86. data/test/scout/model/python/huggingface/test_causal.rb +33 -0
  87. data/test/scout/model/python/huggingface/test_classification.rb +30 -0
  88. data/test/scout/model/python/test_base.rb +44 -0
  89. data/test/scout/model/python/test_huggingface.rb +9 -0
  90. data/test/scout/model/python/test_torch.rb +71 -0
  91. data/test/scout/model/python/torch/test_helpers.rb +14 -0
  92. data/test/scout/model/test_base.rb +117 -0
  93. data/test/scout/model/util/test_save.rb +31 -0
  94. metadata +113 -7
  95. data/README.rdoc +0 -18
  96. data/questions/coach +0 -2
@@ -0,0 +1,256 @@
1
+ require File.expand_path(__FILE__).sub(%r(/test/.*), '/test/test_helper.rb')
2
+ require File.expand_path(__FILE__).sub(%r(.*/test/), '').sub(/test_(.*)\.rb/,'\1')
3
+
4
+ class TestMessages < Test::Unit::TestCase
5
+
6
+ def test_short
7
+
8
+ question =<<-EOF
9
+ Hi
10
+ EOF
11
+
12
+ iii LLM.chat(question)
13
+ end
14
+
15
+ def test_inline
16
+ question =<<-EOF
17
+ system:
18
+
19
+ you are a terse assistant that only write in short sentences
20
+
21
+ assistant:
22
+
23
+ Here is some stuff
24
+
25
+ user: feedback
26
+
27
+ that continues here
28
+ EOF
29
+
30
+ iii LLM.chat(question)
31
+ end
32
+
33
+ def test_messages
34
+ question =<<-EOF
35
+ system:
36
+
37
+ you are a terse assistant that only write in short sentences
38
+
39
+ user:
40
+
41
+ What is the capital of France
42
+
43
+ assistant:
44
+
45
+ Paris
46
+
47
+ user:
48
+
49
+ is this the national anthem
50
+
51
+ [[
52
+ corous: Viva Espagna
53
+ ]]
54
+
55
+ assistant:
56
+
57
+ no
58
+
59
+ user:
60
+
61
+ import: math.system
62
+
63
+ consider this file
64
+
65
+ <file name=foo_bar>
66
+ foo: bar
67
+ </file>
68
+
69
+ how many characters does it hold
70
+
71
+ assistant:
72
+
73
+ 8
74
+ EOF
75
+
76
+ messages = LLM.messages question
77
+ refute messages.collect{|i| i[:role] }.include?("corous")
78
+ assert messages.collect{|i| i[:role] }.include?("import")
79
+ end
80
+
81
+ def test_chat_import
82
+ file1 =<<-EOF
83
+ system: You are an assistant
84
+ EOF
85
+
86
+ file2 =<<-EOF
87
+ import: header
88
+ user: say something
89
+ EOF
90
+
91
+ TmpFile.with_path do |tmpdir|
92
+ tmpdir.header.write file1
93
+ tmpdir.chat.write file2
94
+
95
+ chat = LLM.chat tmpdir.chat
96
+ end
97
+ end
98
+
99
+ def test_clear
100
+ question =<<-EOF
101
+ system:
102
+
103
+ you are a terse assistant that only write in short sentences
104
+
105
+ clear:
106
+
107
+ user:
108
+
109
+ What is the capital of France
110
+ EOF
111
+
112
+ TmpFile.with_file question do |file|
113
+ messages = LLM.chat file
114
+ refute messages.collect{|m| m[:role] }.include?('system')
115
+ end
116
+ end
117
+
118
+ def __test_job
119
+ question =<<-EOF
120
+ system:
121
+
122
+ you are a terse assistant that only write in short sentences
123
+
124
+ job: Baking/bake_muffin_tray/Default_08a1812eca3a18dce2232509dabc9b41
125
+
126
+ How are muffins made
127
+
128
+ EOF
129
+
130
+ TmpFile.with_file question do |file|
131
+ messages = LLM.chat file
132
+ ppp LLM.print messages
133
+ end
134
+ end
135
+
136
+
137
+ def test_task
138
+ question =<<-EOF
139
+ system:
140
+
141
+ you are a terse assistant that only write in short sentences
142
+
143
+ user:
144
+
145
+ task: Baking bake_muffin_tray blueberries=true title="This is a title" list=one,two,"and three"
146
+
147
+ How are muffins made?
148
+
149
+ EOF
150
+
151
+ TmpFile.with_file question do |file|
152
+ messages = LLM.chat file
153
+ ppp LLM.print messages
154
+ end
155
+ end
156
+
157
+ def test_structure
158
+ require 'scout/llm/ask'
159
+ sss 0
160
+ question =<<-EOF
161
+ system:
162
+
163
+ Respond in json format with a hash of strings as keys and string arrays as values, at most three in length
164
+
165
+ endpoint: sambanova
166
+
167
+ What other movies have the protagonists of the original gost busters played on, just the top.
168
+
169
+ EOF
170
+
171
+ TmpFile.with_file question do |file|
172
+ ppp LLM.ask file
173
+ end
174
+ end
175
+
176
+ def test_tool
177
+ require 'scout/llm/ask'
178
+
179
+ sss 0
180
+ question =<<-EOF
181
+ user:
182
+
183
+ Use the provided tool to learn the instructions of baking a tray of muffins. Don't
184
+ give me your own recipe, return the one provided by the tool
185
+
186
+ tool: Baking
187
+ EOF
188
+
189
+ TmpFile.with_file question do |file|
190
+ ppp LLM.ask file, endpoint: :nano
191
+ end
192
+ end
193
+
194
+ def test_tools_with_task
195
+ require 'scout/llm/ask'
196
+
197
+ question =<<-EOF
198
+ user:
199
+
200
+ Use the provided tool to learn the instructions of baking a tray of muffins. Don't
201
+ give me your own recipe, return the one provided by the tool
202
+
203
+ tool: Baking bake_muffin_tray
204
+ EOF
205
+
206
+ TmpFile.with_file question do |file|
207
+ ppp LLM.ask file
208
+ end
209
+ end
210
+
211
+ def test_knowledge_base
212
+ require 'scout/llm/ask'
213
+ sss 0
214
+ question =<<-EOF
215
+ system:
216
+
217
+ Query the knowledge base of familiar relationships to answer the question
218
+
219
+ user:
220
+
221
+ Who is Miki's brother in law?
222
+
223
+ association: brothers #{datafile_test(:person).brothers} undirected=true
224
+ association: marriages #{datafile_test(:person).marriages} undirected=true source="=>Alias" target="=>Alias"
225
+ EOF
226
+
227
+ TmpFile.with_file question do |file|
228
+ ppp LLM.ask file
229
+ end
230
+ end
231
+
232
+ def test_previous_response
233
+ require 'scout/llm/ask'
234
+ sss 0
235
+ question =<<-EOF
236
+ user:
237
+
238
+ Say hi
239
+
240
+ assistant:
241
+
242
+ Hi
243
+
244
+ previous_response_id: asdfasdfasdfasdf
245
+
246
+ Bye
247
+
248
+ EOF
249
+
250
+ messages = LLM.messages question
251
+
252
+ iii messages
253
+
254
+ end
255
+ end
256
+
@@ -0,0 +1,29 @@
1
+ require File.expand_path(__FILE__).sub(%r(/test/.*), '/test/test_helper.rb')
2
+ require File.expand_path(__FILE__).sub(%r(.*/test/), '').sub(/test_(.*)\.rb/,'\1')
3
+
4
+ require "scout-ai"
5
+ class TestMCP < Test::Unit::TestCase
6
+ def test_workflow_stdio
7
+ require "mcp/server/transports/stdio_transport"
8
+ wf = Module.new do
9
+ extend Workflow
10
+ self.name = "TestWorkflow"
11
+
12
+ desc "Just say hi to someone"
13
+ input :name, :string, "Name", nil, required: true
14
+ task :hi => :string do |name|
15
+ "Hi #{name}"
16
+ end
17
+
18
+ desc "Just say bye to someone"
19
+ input :name, :string, "Name", nil, required: true
20
+ task :bye => :string do |name|
21
+ "Bye #{name}"
22
+ end
23
+ end
24
+
25
+ transport = MCP::Server::Transports::StdioTransport.new(wf.mcp(:hi))
26
+ transport.open
27
+ end
28
+ end
29
+
@@ -4,8 +4,9 @@ require File.expand_path(__FILE__).sub(%r(.*/test/), '').sub(/test_(.*)\.rb/,'\1
4
4
  class TestLLMParse < Test::Unit::TestCase
5
5
  def test_parse
6
6
  text=<<-EOF
7
+ hi
7
8
  system: you are an asistant
8
- user: Given the contents of this file: [[
9
+ user: Given the contents of this file:[[
9
10
  line 1: 1
10
11
  line 2: 2
11
12
  line 3: 3
@@ -13,7 +14,85 @@ line 3: 3
13
14
  Show me the lines in reverse order
14
15
  EOF
15
16
 
16
- iii LLM.parse(text)
17
+ assert_include LLM.parse(text).first[:content], 'hi'
18
+ assert_include LLM.parse(text).last[:content], 'reverse'
19
+ end
20
+
21
+ def test_code
22
+ text=<<-EOF
23
+ hi
24
+ system: you are an asistant
25
+ user: Given the contents of this file:
26
+ ```yaml
27
+ key: value
28
+ key2: value2
29
+ ```
30
+ Show me the lines in reverse order
31
+ EOF
32
+
33
+ assert_include LLM.parse(text).last[:content], 'key2'
34
+ end
35
+
36
+ def test_lines
37
+ text=<<-EOF
38
+ system: you are an asistant
39
+ user: I have a question
40
+ EOF
41
+
42
+ assert_include LLM.parse(text).last[:content], 'question'
43
+ end
44
+
45
+ def test_blocks
46
+ text=<<-EOF
47
+ system:
48
+
49
+ you are an asistant
50
+
51
+ user:
52
+
53
+ I have a question
54
+
55
+ EOF
56
+
57
+ assert_include LLM.parse(text).last[:content], 'question'
58
+ end
59
+
60
+ def test_no_role
61
+ text=<<-EOF
62
+ I have a question
63
+ EOF
64
+
65
+ assert_include LLM.parse(text).last[:content], 'question'
66
+ end
67
+
68
+
69
+ def test_cmd
70
+ text=<<-EOF
71
+ How many files are there:
72
+
73
+ [[cmd list of files
74
+ echo "file1 file2"
75
+ ]]
76
+ EOF
77
+
78
+ assert_equal :user, LLM.parse(text).last[:role]
79
+ assert_include LLM.parse(text).first[:content], 'file1'
80
+ end
81
+
82
+ def test_directory
83
+ TmpFile.with_path do |tmpdir|
84
+ tmpdir.file1.write "foo"
85
+ tmpdir.file2.write "bar"
86
+ text=<<-EOF
87
+ How many files are there:
88
+
89
+ [[directory DIR
90
+ #{tmpdir}
91
+ ]]
92
+ EOF
93
+
94
+ assert_include LLM.parse(text).first[:content], 'file1'
95
+ end
17
96
  end
18
97
  end
19
98
 
File without changes
@@ -0,0 +1,22 @@
1
+ require File.expand_path(__FILE__).sub(%r(/test/.*), '/test/test_helper.rb')
2
+ require File.expand_path(__FILE__).sub(%r(.*/test/), '').sub(/test_(.*)\.rb/,'\1')
3
+
4
+ class TestLLMToolKB < Test::Unit::TestCase
5
+ def test_knowledbase_definition
6
+ TmpFile.with_dir do |dir|
7
+ kb = KnowledgeBase.new dir
8
+ kb.register :brothers, datafile_test(:person).brothers, undirected: true
9
+ kb.register :parents, datafile_test(:person).parents
10
+
11
+ assert_include kb.all_databases, :brothers
12
+
13
+ assert_equal Person, kb.target_type(:parents)
14
+
15
+ knowledge_base_definition = LLM.knowledge_base_tool_definition(kb)
16
+ ppp JSON.pretty_generate knowledge_base_definition
17
+
18
+ assert_equal ['Isa~Miki', 'Miki~Isa', 'Guille~Clei'], LLM.call_knowledge_base(kb, :brothers, entities: %w(Isa Miki Guille))
19
+ end
20
+ end
21
+ end
22
+
@@ -0,0 +1,11 @@
1
+ require File.expand_path(__FILE__).sub(%r(/test/.*), '/test/test_helper.rb')
2
+ require File.expand_path(__FILE__).sub(%r(.*/test/), '').sub(/test_(.*)\.rb/,'\1')
3
+
4
+ class TestClass < Test::Unit::TestCase
5
+ def test_client
6
+ c = LLM.mcp_tools("https://api.githubcopilot.com/mcp/")
7
+ assert_include c.keys, "get_me"
8
+ end
9
+ end
10
+
11
+
@@ -0,0 +1,39 @@
1
+ require File.expand_path(__FILE__).sub(%r(/test/.*), '/test/test_helper.rb')
2
+ require File.expand_path(__FILE__).sub(%r(.*/test/), '').sub(/test_(.*)\.rb/,'\1')
3
+
4
+ class TestLLMToolWorkflow < Test::Unit::TestCase
5
+ def test_workflow_definition
6
+ m = Module.new do
7
+ extend Workflow
8
+ self.name = "RecipeWorkflow"
9
+
10
+ desc "List the steps to cook a recipe"
11
+ input :recipe, :string, "Recipe for which to extract steps"
12
+ task :recipe_steps => :array do |recipe|
13
+ ["prepare batter", "bake"]
14
+ end
15
+
16
+ desc "Calculate time spent in each step of the recipe"
17
+ input :step, :string, "Cooking step"
18
+ task :step_time => :string do |step|
19
+ case step
20
+ when "prepare batter"
21
+ "2 hours"
22
+ when "bake"
23
+ "30 minutes"
24
+ else
25
+ "1 minute"
26
+ end
27
+ end
28
+ end
29
+
30
+ LLM.task_tool_definition(m, :recipe_steps)
31
+ LLM.task_tool_definition(m, :step_time)
32
+
33
+ tool_definitions = LLM.workflow_tools(m)
34
+ ppp JSON.pretty_generate tool_definitions
35
+
36
+ assert_equal ["prepare batter", "bake"], LLM.call_workflow(m, :recipe_steps)
37
+ end
38
+ end
39
+
@@ -0,0 +1,59 @@
1
+ require File.expand_path(__FILE__).sub(%r(/test/.*), '/test/test_helper.rb')
2
+ require File.expand_path(__FILE__).sub(%r(.*/test/), '').sub(/test_(.*)\.rb/,'\1')
3
+
4
+ require 'scout-ai'
5
+ class TestClass < Test::Unit::TestCase
6
+ def test_main
7
+ model = NextTokenModel.new
8
+ train_texts = [
9
+ "say hi, no!",
10
+ "say hi, no no no",
11
+ "say hi, hi ",
12
+ "say hi, hi how are you ",
13
+ "say hi, hi are you good",
14
+ ]
15
+
16
+ model_name = "distilgpt2" # Replace with your local/other HF Llama checkpoint as needed
17
+
18
+ TmpFile.with_path do |tmp_dir|
19
+ iii tmp_dir
20
+
21
+ sss 0
22
+ model = NextTokenModel.new model_name, tmp_dir, training_num_train_epochs: 1000, training_learning_rate: 0.1
23
+
24
+ iii :new
25
+ chat = Chat.setup []
26
+ chat.user "say hi"
27
+ ppp model.eval chat
28
+
29
+ model.save
30
+ model = PythonModel.new tmp_dir
31
+
32
+ iii :load
33
+ chat = Chat.setup []
34
+ chat.user "say hi"
35
+ ppp model.eval chat
36
+
37
+ iii :training
38
+ state, tokenizer = model.init
39
+ tokenizer.pad_token = tokenizer.eos_token
40
+ model.add_list train_texts.shuffle
41
+ model.train
42
+
43
+ iii :trained
44
+ chat = Chat.setup []
45
+ chat.user "say hi"
46
+ ppp model.eval chat
47
+
48
+ model.save
49
+ model = PythonModel.new tmp_dir
50
+
51
+ iii :load_again
52
+ chat = Chat.setup []
53
+ chat.user "say hi"
54
+ ppp model.eval chat
55
+ end
56
+
57
+ end
58
+ end
59
+
@@ -0,0 +1,33 @@
1
+ require File.expand_path(__FILE__).sub(%r(/test/.*), '/test/test_helper.rb')
2
+ require File.expand_path(__FILE__).sub(%r(.*/test/), '').sub(/test_(.*)\.rb/,'\1')
3
+
4
+ class TestClass < Test::Unit::TestCase
5
+ def test_eval_chat
6
+ #model = CausalModel.new 'BSC-LT/salamandra-2b-instruct'
7
+ model = CausalModel.new 'mistralai/Mistral-7B-Instruct-v0.3'
8
+
9
+ model.init
10
+
11
+ net, tok = model.state
12
+
13
+ iii model.eval([
14
+ {role: :system, content: "You are a calculator, just reply with the answer"},
15
+ {role: :user, content: " 1 + 2 ="}
16
+ ])
17
+ end
18
+
19
+ def test_eval_train
20
+ #model = CausalModel.new 'BSC-LT/salamandra-2b-instruct'
21
+ model = CausalModel.new 'mistralai/Mistral-7B-Instruct-v0.3'
22
+
23
+ model.init
24
+
25
+ net, tok = model.state
26
+
27
+ iii model.eval([
28
+ {role: :system, content: "You are a calculator, just reply with the answer"},
29
+ {role: :user, content: " 1 + 2 ="}
30
+ ])
31
+ end
32
+ end
33
+
@@ -0,0 +1,30 @@
1
+ require File.expand_path(__FILE__).sub(%r(/test/.*), '/test/test_helper.rb')
2
+ require File.expand_path(__FILE__).sub(%r(.*/test/), '').sub(/test_(.*)\.rb/,'\1')
3
+
4
+ class TestSequenceClassification < Test::Unit::TestCase
5
+ def _test_eval_sequence_classification
6
+ model = SequenceClassificationModel.new 'bert-base-uncased', nil,
7
+ class_labels: %w(Bad Good)
8
+
9
+ assert_include ["Bad", "Good"], model.eval("This is dog")
10
+ assert_include ["Bad", "Good"], model.eval_list(["This is dog", "This is cat"]).first
11
+ end
12
+
13
+ def test_train_sequence_classification
14
+ model = SequenceClassificationModel.new 'bert-base-uncased', nil,
15
+ class_labels: %w(Bad Good)
16
+
17
+ model.init
18
+
19
+ 10.times do
20
+ model.add "The dog", 'Bad'
21
+ model.add "The cat", 'Good'
22
+ end
23
+
24
+ model.train
25
+
26
+ assert_equal "Bad", model.eval("This is dog")
27
+ assert_equal "Good", model.eval("This is cat")
28
+ end
29
+ end
30
+
@@ -0,0 +1,44 @@
1
+ require File.expand_path(__FILE__).sub(%r(/test/.*), '/test/test_helper.rb')
2
+ require File.expand_path(__FILE__).sub(%r(.*/test/), '').sub(/test_(.*)\.rb/,'\1')
3
+
4
+ class TestPythonModel < Test::Unit::TestCase
5
+ def test_linear
6
+ model = nil
7
+
8
+ TmpFile.with_path do |dir|
9
+
10
+ dir['model.py'].write <<-EOF
11
+ class TestModel:
12
+ def __init__(self, delta):
13
+ self.delta = delta
14
+
15
+ def eval(self, x):
16
+ return [e + self.delta for e in x]
17
+ EOF
18
+
19
+ model = PythonModel.new dir, 'TestModel', :model, delta: 1
20
+
21
+ model.eval do |sample,list=nil|
22
+ init unless state
23
+ if list
24
+ state.eval(list)
25
+ else
26
+ state.eval([sample])[0]
27
+ end
28
+ end
29
+
30
+ assert_equal 2, model.eval(1)
31
+ assert_equal [4, 6], model.eval_list([3, 5])
32
+
33
+ model.save
34
+
35
+ model = ScoutModel.new dir
36
+
37
+ assert_equal 2, model.eval(1)
38
+
39
+ model = ScoutModel.new dir, delta: 2
40
+
41
+ assert_equal 3, model.eval(1)
42
+ end
43
+ end
44
+ end
@@ -0,0 +1,9 @@
1
+ require File.expand_path(__FILE__).sub(%r(/test/.*), '/test/test_helper.rb')
2
+ require File.expand_path(__FILE__).sub(%r(.*/test/), '').sub(/test_(.*)\.rb/,'\1')
3
+
4
+ class TestHuggingface < Test::Unit::TestCase
5
+ def test_true
6
+ assert true
7
+ end
8
+ end
9
+
@@ -0,0 +1,71 @@
1
+ require File.expand_path(__FILE__).sub(%r(/test/.*), '/test/test_helper.rb')
2
+ require File.expand_path(__FILE__).sub(%r(.*/test/), '').sub(/test_(.*)\.rb/,'\1')
3
+
4
+ class TestTorch < Test::Unit::TestCase
5
+ def test_linear
6
+ model = nil
7
+
8
+ TmpFile.with_dir do |dir|
9
+
10
+ # Create model
11
+
12
+ TorchModel.init_python
13
+
14
+ model = TorchModel.new dir
15
+ model.state = ScoutPython.torch.nn.Linear.new(1, 1)
16
+ model.criterion = ScoutPython.torch.nn.MSELoss.new()
17
+
18
+ model.extract_features do |f|
19
+ [f]
20
+ end
21
+
22
+ model.post_process do |v,list|
23
+ list ? list.collect{|vv| vv.first } : v.first
24
+ end
25
+
26
+ # Train model
27
+
28
+ model.add 5.0, [10.0]
29
+ model.add 10.0, [20.0]
30
+
31
+ model.options[:training_args][:epochs] = 1000
32
+ model.train
33
+
34
+ w = model.get_weights.to_ruby.first.first
35
+
36
+ assert w > 1.8
37
+ assert w < 2.2
38
+
39
+ # Load the model again
40
+
41
+ sss 0
42
+ model.save
43
+
44
+ model = ScoutModel.new dir
45
+
46
+ # Test model
47
+
48
+ y = model.eval_list([100.0, 200.0]).first
49
+
50
+ assert(y > 150.0)
51
+ assert(y < 250.0)
52
+
53
+ y = model.eval(100.0)
54
+
55
+ assert(y > 150.0)
56
+ assert(y < 250.0)
57
+
58
+ test = [1.0, 5.0, 10.0, 20.0]
59
+ input_sum = Misc.sum(test)
60
+ sum = Misc.sum(model.eval_list(test))
61
+ assert sum > 0.8 * input_sum * 2
62
+ assert sum < 1.2 * input_sum * 2
63
+
64
+ w = TorchModel.get_weights(model.state).to_ruby.first.first
65
+
66
+ assert w > 1.8
67
+ assert w < 2.2
68
+ end
69
+ end
70
+ end
71
+