hmm 0.0.2 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/VERSION +1 -1
- data/hmm.gemspec +2 -2
- data/lib/hmm.rb +303 -9
- data/test/test_hmm.rb +48 -14
- metadata +2 -2
    
        data/VERSION
    CHANGED
    
    | @@ -1 +1 @@ | |
| 1 | 
            -
            0.0.2
         | 
| 1 | 
            +
            0.1.0
         | 
    
        data/hmm.gemspec
    CHANGED
    
    | @@ -5,11 +5,11 @@ | |
| 5 5 |  | 
| 6 6 | 
             
            Gem::Specification.new do |s|
         | 
| 7 7 | 
             
              s.name = %q{hmm}
         | 
| 8 | 
            -
              s.version = "0.0.2"
         | 
| 8 | 
            +
              s.version = "0.1.0"
         | 
| 9 9 |  | 
| 10 10 | 
             
              s.required_rubygems_version = Gem::Requirement.new(">= 0") if s.respond_to? :required_rubygems_version=
         | 
| 11 11 | 
             
              s.authors = ["David Tresner-Kirsch"]
         | 
| 12 | 
            -
              s.date = %q{2009- | 
| 12 | 
            +
              s.date = %q{2009-12-02}
         | 
| 13 13 | 
             
              s.description = %q{This project is a Ruby gem ('hmm') for machine learning that natively implements a (somewhat) generalized Hidden Markov Model classifier.}
         | 
| 14 14 | 
             
              s.email = %q{dwkirsch@gmail.com}
         | 
| 15 15 | 
             
              s.extra_rdoc_files = [
         | 
    
        data/lib/hmm.rb
    CHANGED
    
    | @@ -5,11 +5,14 @@ | |
| 5 5 | 
             
            #	-computing token level accuracy across a list of observation sequences
         | 
| 6 6 | 
             
            #		against a provided gold standard
         | 
| 7 7 |  | 
| 8 | 
            -
             | 
| 9 8 | 
             
            require 'rubygems'
         | 
| 10 9 | 
             
            require 'narray'
         | 
| 11 10 |  | 
| 11 | 
            +
            class Array; def sum; inject( nil ) { |sum,x| sum ? sum+x : x }; end; end
         | 
| 12 | 
            +
             | 
| 12 13 | 
             
            class HMM
         | 
| 14 | 
            +
            	
         | 
| 15 | 
            +
            	Infinity = 1.0/0
         | 
| 13 16 |  | 
| 14 17 | 
             
            	class Classifier
         | 
| 15 18 | 
             
            		attr_accessor :a, :b, :pi, :o_lex, :q_lex, :debug, :train
         | 
| @@ -47,13 +50,296 @@ class HMM | |
| 47 50 | 
             
            				end
         | 
| 48 51 | 
             
            			end
         | 
| 49 52 |  | 
| 53 | 
            +
            			# smooth to allow unobserved cases
         | 
| 54 | 
            +
            			@pi += 0.1
         | 
| 55 | 
            +
            			@a += 0.1
         | 
| 56 | 
            +
            			@b += 0.1
         | 
| 57 | 
            +
            			
         | 
| 50 58 | 
             
            			# normalize frequencies into probabilities
         | 
| 51 59 | 
             
            			@pi /= @pi.sum
         | 
| 52 60 | 
             
            			@a /= @a.sum(1)
         | 
| 53 61 | 
             
            			@b /= @b.sum(1)
         | 
| 62 | 
            +
            		end	
         | 
| 63 | 
            +
            		
         | 
| 64 | 
            +
            		def train_unsupervised2(sequences)
         | 
| 65 | 
            +
            			# for debugging ONLY
         | 
| 66 | 
            +
            			orig_sequences = sequences.clone
         | 
| 67 | 
            +
            			sequences = [sequences.sum]
         | 
| 68 | 
            +
            			
         | 
| 69 | 
            +
            			# initialize model parameters if we don't already have an estimate
         | 
| 70 | 
            +
            			@pi ||= NArray.float(@q_lex.length).fill(1)/@q_lex.length			
         | 
| 71 | 
            +
            			@a ||= NArray.float(@q_lex.length, @q_lex.length).fill(1)/@q_lex.length
         | 
| 72 | 
            +
            			@b ||= NArray.float(@q_lex.length, @o_lex.length).fill(1)/@q_lex.length
         | 
| 73 | 
            +
            			puts @pi.inspect, @a.inspect, @b.inspect if debug
         | 
| 74 | 
            +
            			
         | 
| 75 | 
            +
            			max_iterations = 1 #1000 #kwargs.get('max_iterations', 1000)
         | 
| 76 | 
            +
            			epsilon = 1e-6 # kwargs.get('convergence_logprob', 1e-6)
         | 
| 77 | 
            +
            			
         | 
| 78 | 
            +
            			max_iterations.times do |iteration|
         | 
| 79 | 
            +
            				puts "iteration ##{iteration}" #if debug
         | 
| 80 | 
            +
            				logprob = 0.0
         | 
| 81 | 
            +
            				
         | 
| 82 | 
            +
            				sequences.each do |sequence|
         | 
| 83 | 
            +
            					# just in case, skip if sequence contains unrecognized tokens
         | 
| 84 | 
            +
            					next unless (sequence-o_lex).empty?
         | 
| 85 | 
            +
            					
         | 
| 86 | 
            +
            					# compute forward and backward probabilities
         | 
| 87 | 
            +
            					alpha = forward_probability(sequence)
         | 
| 88 | 
            +
            					beta = backward_probability(sequence)
         | 
| 89 | 
            +
            					lpk = log_add(alpha[-1, true]) #sum of last alphas. divide by this to get probs
         | 
| 90 | 
            +
            					logprob += lpk
         | 
| 91 | 
            +
            					
         | 
| 92 | 
            +
            					xi = xi(sequence)
         | 
| 93 | 
            +
            					gamma = gamma(xi)
         | 
| 94 | 
            +
            					
         | 
| 95 | 
            +
            					localA = NArray.float(q_lex.length,q_lex.length)
         | 
| 96 | 
            +
            					localB = NArray.float(q_lex.length,o_lex.length)
         | 
| 97 | 
            +
            					
         | 
| 98 | 
            +
            					q_lex.each_index do |i|
         | 
| 99 | 
            +
            						q_lex.each_index do |j|
         | 
| 100 | 
            +
            							numA = -Infinity
         | 
| 101 | 
            +
            							denomA = -Infinity
         | 
| 102 | 
            +
            							sequence.each_index do |t|
         | 
| 103 | 
            +
            								break if t >= sequence.length-1
         | 
| 104 | 
            +
            								numA = log_add([numA, xi[t, i, j]])
         | 
| 105 | 
            +
            								denomA = log_add([denomA, gamma[t, i]])
         | 
| 106 | 
            +
            							end
         | 
| 107 | 
            +
            							localA[i,j] = numA - denomA
         | 
| 108 | 
            +
            						end
         | 
| 109 | 
            +
            						
         | 
| 110 | 
            +
            						o_lex.each_index do |k|
         | 
| 111 | 
            +
            							numB = -Infinity
         | 
| 112 | 
            +
            							denomB = -Infinity
         | 
| 113 | 
            +
            							sequence.each_index do |t|
         | 
| 114 | 
            +
            								break if t >= sequence.length-1
         | 
| 115 | 
            +
            								denomB = log_add([denomB, gamma[t, i]])
         | 
| 116 | 
            +
            								next unless k == index(sequence[t], o_lex)
         | 
| 117 | 
            +
            								numB = log_add([numB, gamma[t, i]])
         | 
| 118 | 
            +
            							end
         | 
| 119 | 
            +
            							localB[i, k] = numB - denomB
         | 
| 120 | 
            +
            						end
         | 
| 121 | 
            +
            						
         | 
| 122 | 
            +
            					end
         | 
| 123 | 
            +
            					
         | 
| 124 | 
            +
            					puts "LogProb: #{logprob}"
         | 
| 125 | 
            +
            					
         | 
| 126 | 
            +
            					@a = localA.collect{|x| Math::E**x}
         | 
| 127 | 
            +
            					@b = localB.collect{|x| Math::E**x}
         | 
| 128 | 
            +
            					#@pi = gamma[0, true] / gamma[0, true].sum
         | 
| 129 | 
            +
            					
         | 
| 130 | 
            +
            				end
         | 
| 131 | 
            +
            			end
         | 
| 132 | 
            +
            		end
         | 
| 133 | 
            +
            		
         | 
| 134 | 
            +
            		
         | 
| 135 | 
            +
            		def train_unsupervised(sequences, max_iterations = 10)
         | 
| 136 | 
            +
            			# initialize model parameters if we don't already have an estimate
         | 
| 137 | 
            +
            			@pi ||= NArray.float(@q_lex.length).fill(1)/@q_lex.length			
         | 
| 138 | 
            +
            			@a ||= NArray.float(@q_lex.length, @q_lex.length).fill(1)/@q_lex.length
         | 
| 139 | 
            +
            			@b ||= NArray.float(@q_lex.length, @o_lex.length).fill(1)/@q_lex.length
         | 
| 140 | 
            +
            			puts @pi.inspect, @a.inspect, @b.inspect if debug
         | 
| 141 | 
            +
            			
         | 
| 142 | 
            +
            			converged = false
         | 
| 143 | 
            +
            			last_logprob = 0
         | 
| 144 | 
            +
            			iteration = 0
         | 
| 145 | 
            +
            			#max_iterations = 10 #1000 #kwargs.get('max_iterations', 1000)
         | 
| 146 | 
            +
            			epsilon = 1e-6 # kwargs.get('convergence_logprob', 1e-6)
         | 
| 147 | 
            +
            			
         | 
| 148 | 
            +
            			max_iterations.times do |iteration|
         | 
| 149 | 
            +
            				puts "iteration ##{iteration}" #if debug
         | 
| 150 | 
            +
             | 
| 151 | 
            +
            				_A_numer = NArray.float(q_lex.length,q_lex.length).fill(-Infinity)
         | 
| 152 | 
            +
            				_B_numer = NArray.float(q_lex.length, o_lex.length).fill(-Infinity)
         | 
| 153 | 
            +
            				_A_denom = NArray.float(q_lex.length).fill(-Infinity)
         | 
| 154 | 
            +
            				_B_denom = NArray.float(q_lex.length).fill(-Infinity)
         | 
| 155 | 
            +
            				_Pi = NArray.float(q_lex.length)
         | 
| 156 | 
            +
            				
         | 
| 157 | 
            +
            				logprob = 0.0
         | 
| 158 | 
            +
            				
         | 
| 159 | 
            +
            				#logprob = last_logprob + 1 # take this out
         | 
| 160 | 
            +
            				
         | 
| 161 | 
            +
            				sequences.each do |sequence|
         | 
| 162 | 
            +
            					# just in case, skip if sequence contains unrecognized tokens
         | 
| 163 | 
            +
            					next unless (sequence-o_lex).empty?
         | 
| 164 | 
            +
            					
         | 
| 165 | 
            +
            					# compute forward and backward probabilities
         | 
| 166 | 
            +
            					alpha = forward_probability(sequence)
         | 
| 167 | 
            +
            					beta = backward_probability(sequence)
         | 
| 168 | 
            +
            					lpk = log_add(alpha[-1, true]) #sum of last alphas. divide by this to get probs
         | 
| 169 | 
            +
            					logprob += lpk
         | 
| 170 | 
            +
            					
         | 
| 171 | 
            +
            					local_A_numer = NArray.float(q_lex.length,q_lex.length).fill(-Infinity)
         | 
| 172 | 
            +
            					local_B_numer = NArray.float(q_lex.length, o_lex.length).fill(-Infinity)
         | 
| 173 | 
            +
            					local_A_denom = NArray.float(q_lex.length).fill(-Infinity)
         | 
| 174 | 
            +
            					local_B_denom = NArray.float(q_lex.length).fill(-Infinity)
         | 
| 175 | 
            +
            					local_Pi = NArray.float(q_lex.length)
         | 
| 176 | 
            +
            					
         | 
| 177 | 
            +
            					sequence.each_with_index do |o, t|
         | 
| 178 | 
            +
            						o_next = index(sequence[t+1], o_lex) if t < sequence.length-1
         | 
| 179 | 
            +
            						
         | 
| 180 | 
            +
            						q_lex.each_index do |i|
         | 
| 181 | 
            +
            							if t < sequence.length-1
         | 
| 182 | 
            +
            								q_lex.each_index do |j|
         | 
| 183 | 
            +
            									local_A_numer[i, j] =  \
         | 
| 184 | 
            +
            										log_add([local_A_numer[i, j], \
         | 
| 185 | 
            +
            										alpha[t, i] + \
         | 
| 186 | 
            +
            											log(@a[i,j]) + \
         | 
| 187 | 
            +
            											log(@b[j,o_next]) + \
         | 
| 188 | 
            +
            											beta[t+1, j]])
         | 
| 189 | 
            +
            								end
         | 
| 190 | 
            +
            								local_A_denom[i] = log_add([local_A_denom[i],
         | 
| 191 | 
            +
            											alpha[t, i] + beta[t, i]])
         | 
| 192 | 
            +
            	
         | 
| 193 | 
            +
            							else
         | 
| 194 | 
            +
            								local_B_denom[i] = log_add([local_A_denom[i],
         | 
| 195 | 
            +
            											alpha[t, i] + beta[t, i]])
         | 
| 196 | 
            +
            							end
         | 
| 197 | 
            +
            							local_B_numer[i, index(o,o_lex)] = log_add([local_B_numer[i, index(o, o_lex)],
         | 
| 198 | 
            +
            								alpha[t, i] + beta[t, i]])
         | 
| 199 | 
            +
             | 
| 200 | 
            +
            						end
         | 
| 201 | 
            +
            						
         | 
| 202 | 
            +
            						puts local_A_numer.inspect if debug
         | 
| 203 | 
            +
            						
         | 
| 204 | 
            +
            						q_lex.each_index do |i|
         | 
| 205 | 
            +
            							q_lex.each_index do |j|
         | 
| 206 | 
            +
            								_A_numer[i, j] = log_add([_A_numer[i, j],
         | 
| 207 | 
            +
            									local_A_numer[i, j] - lpk])
         | 
| 208 | 
            +
            							end
         | 
| 209 | 
            +
            							o_lex.each_index do |k|	
         | 
| 210 | 
            +
            								_B_numer[i, k] = log_add([_B_numer[i, k], local_B_numer[i, k] - lpk])
         | 
| 211 | 
            +
            							end
         | 
| 212 | 
            +
            							_A_denom[i] = log_add([_A_denom[i], local_A_denom[i] - lpk])
         | 
| 213 | 
            +
            							_B_denom[i] = log_add([_B_denom[i], local_B_denom[i] - lpk])
         | 
| 214 | 
            +
            						end
         | 
| 215 | 
            +
            						
         | 
| 216 | 
            +
            					end
         | 
| 217 | 
            +
            				
         | 
| 218 | 
            +
            					puts alpha.collect{|x| Math::E**x}.inspect if debug
         | 
| 219 | 
            +
            				end		
         | 
| 220 | 
            +
             | 
| 221 | 
            +
            				puts _A_denom.inspect if debug
         | 
| 222 | 
            +
             | 
| 223 | 
            +
            				q_lex.each_index do |i|
         | 
| 224 | 
            +
            					q_lex.each_index do |j|
         | 
| 225 | 
            +
            						#puts 2**(_A_numer[i,j] - _A_denom[i]), _A_numer[i,j], _A_denom[i]
         | 
| 226 | 
            +
            						@a[i, j] = Math::E**(_A_numer[i,j] - _A_denom[i])
         | 
| 227 | 
            +
            					end
         | 
| 228 | 
            +
            					o_lex.each_index do |k|
         | 
| 229 | 
            +
            						@b[i, k] = Math::E**(_B_numer[i,k] - _B_denom[i])
         | 
| 230 | 
            +
            					end
         | 
| 231 | 
            +
            					# This comment appears in NLTK:
         | 
| 232 | 
            +
            					# Rabiner says the priors don't need to be updated. I don't
         | 
| 233 | 
            +
            					# believe him. FIXME
         | 
| 234 | 
            +
            				end
         | 
| 235 | 
            +
            					
         | 
| 236 | 
            +
             | 
| 237 | 
            +
            				if iteration > 0 and (logprob - last_logprob).abs < epsilon
         | 
| 238 | 
            +
            					puts "CONVERGED: #{(logprob - last_logprob).abs}" if debug
         | 
| 239 | 
            +
            					puts "epsilon: #{epsilon}" if debug
         | 
| 240 | 
            +
            					break
         | 
| 241 | 
            +
            				end
         | 
| 242 | 
            +
            				
         | 
| 243 | 
            +
            				puts "LogProb: #{logprob}" #if debug
         | 
| 244 | 
            +
            				
         | 
| 245 | 
            +
            				last_logprob = logprob
         | 
| 246 | 
            +
            			end
         | 
| 247 | 
            +
            		end
         | 
| 248 | 
            +
            		
         | 
| 249 | 
            +
            		def xi(sequence)
         | 
| 250 | 
            +
            			xi = NArray.float(sequence.length-1, q_lex.length, q_lex.length)
         | 
| 251 | 
            +
            			
         | 
| 252 | 
            +
            			alpha = forward_probability(sequence)
         | 
| 253 | 
            +
            			beta = backward_probability(sequence)
         | 
| 254 | 
            +
            			
         | 
| 255 | 
            +
            			0.upto sequence.length-2 do |t|
         | 
| 256 | 
            +
            				denom = 0
         | 
| 257 | 
            +
            				q_lex.each_index do |i|
         | 
| 258 | 
            +
            					q_lex.each_index do |j|
         | 
| 259 | 
            +
            						x = alpha[t, i] + log(@a[i,j]) + \
         | 
| 260 | 
            +
            							log(@b[j,index(sequence[t+1], o_lex)]) + \
         | 
| 261 | 
            +
            							beta[t+1, j]
         | 
| 262 | 
            +
            						denom = log_add([denom, x])
         | 
| 263 | 
            +
            					end
         | 
| 264 | 
            +
            				end
         | 
| 265 | 
            +
            				
         | 
| 266 | 
            +
            				q_lex.each_index do |i|
         | 
| 267 | 
            +
            					q_lex.each_index do |j|
         | 
| 268 | 
            +
            						numer = alpha[t, i] + log(@a[i,j]) + \
         | 
| 269 | 
            +
            							log(@b[j,index(sequence[t+1], o_lex)]) + \
         | 
| 270 | 
            +
            							beta[t+1, j]
         | 
| 271 | 
            +
            						xi[t, i, j] = numer - denom
         | 
| 272 | 
            +
            					end
         | 
| 273 | 
            +
            				end
         | 
| 274 | 
            +
            			end
         | 
| 275 | 
            +
            			
         | 
| 276 | 
            +
            			puts "Xi: #{xi.inspect}" if debug
         | 
| 277 | 
            +
            			xi
         | 
| 278 | 
            +
            		end
         | 
| 279 | 
            +
            		
         | 
| 280 | 
            +
            		def gamma(xi)
         | 
| 281 | 
            +
            			gamma = NArray.float(xi.shape[0], xi.shape[1]).fill(-Infinity)
         | 
| 282 | 
            +
            			
         | 
| 283 | 
            +
            			0.upto gamma.shape[0] - 1 do |t|
         | 
| 284 | 
            +
            				q_lex.each_index do |i|
         | 
| 285 | 
            +
            					q_lex.each_index do |j|
         | 
| 286 | 
            +
            						gamma[t, i] = log_add([gamma[t, i], xi[t, i, j]])
         | 
| 287 | 
            +
            					end
         | 
| 288 | 
            +
            				end
         | 
| 289 | 
            +
            			end
         | 
| 290 | 
            +
            			
         | 
| 291 | 
            +
            			puts "Gamma: #{gamma.inspect}" if debug
         | 
| 292 | 
            +
            			gamma
         | 
| 293 | 
            +
            		end
         | 
| 294 | 
            +
            		
         | 
| 295 | 
            +
            		def forward_probability(sequence)
         | 
| 296 | 
            +
            			alpha = NArray.float(sequence.length, q_lex.length).fill(-Infinity)
         | 
| 297 | 
            +
            			
         | 
| 298 | 
            +
            			alpha[0, true] = log(@pi) + log(@b[true, index(sequence.first, o_lex)])
         | 
| 299 | 
            +
            			
         | 
| 300 | 
            +
            			sequence.each_with_index do |o, t|
         | 
| 301 | 
            +
            				next if t==0
         | 
| 302 | 
            +
            				q_lex.each_index do |i|
         | 
| 303 | 
            +
            					q_lex.each_index do |j|
         | 
| 304 | 
            +
            						alpha[t, i] = log_add([alpha[t, i], alpha[t-1, j]+log(@a[j, i])])
         | 
| 305 | 
            +
            					end
         | 
| 306 | 
            +
            					alpha[t, i] += log(b[i, index(o, o_lex)])
         | 
| 307 | 
            +
            				end
         | 
| 308 | 
            +
            			end
         | 
| 309 | 
            +
            			alpha
         | 
| 54 310 | 
             
            		end
         | 
| 55 | 
            -
             | 
| 311 | 
            +
            		
         | 
| 312 | 
            +
            		def log_add(values)
         | 
| 313 | 
            +
            			x = values.max
         | 
| 314 | 
            +
            			if x > -Infinity
         | 
| 315 | 
            +
            				sum_diffs = 0
         | 
| 316 | 
            +
            				values.each do |value|
         | 
| 317 | 
            +
            					sum_diffs += Math::E**(value - x)
         | 
| 318 | 
            +
            				end
         | 
| 319 | 
            +
            				return x + log(sum_diffs)
         | 
| 320 | 
            +
            			else
         | 
| 321 | 
            +
            				return x
         | 
| 322 | 
            +
            			end
         | 
| 323 | 
            +
            		end
         | 
| 324 | 
            +
            		
         | 
| 325 | 
            +
            		def backward_probability(sequence)
         | 
| 326 | 
            +
            			beta = NArray.float(sequence.length, q_lex.length).fill(-Infinity)
         | 
| 327 | 
            +
            			
         | 
| 328 | 
            +
            			beta[-1, true] = log(1)
         | 
| 329 | 
            +
            			
         | 
| 330 | 
            +
            			(sequence.length-2).downto(0) do |t|
         | 
| 331 | 
            +
            				q_lex.each_index do |i|
         | 
| 332 | 
            +
            					q_lex.each_index do |j|
         | 
| 333 | 
            +
            						beta[t, i] = log_add([beta[t,i], log(@a[i, j]) \
         | 
| 334 | 
            +
            							+ log(@b[j, index(sequence[t+1], o_lex)]) \
         | 
| 335 | 
            +
            							+ beta[t+1, j]])
         | 
| 336 | 
            +
            					end
         | 
| 337 | 
            +
            				end
         | 
| 338 | 
            +
            			end
         | 
| 56 339 |  | 
| 340 | 
            +
            			beta
         | 
| 341 | 
            +
            		end
         | 
| 342 | 
            +
            		
         | 
| 57 343 | 
             
            		def decode(o_sequence)
         | 
| 58 344 | 
             
            			# Viterbi!  with log probability math to avoid underflow
         | 
| 59 345 |  | 
| @@ -98,19 +384,27 @@ class HMM | |
| 98 384 |  | 
| 99 385 | 
             
            		# index and deindex map between labels and the ordinals of those labels.
         | 
| 100 386 | 
             
            		# the ordinals map the labels to rows and columns of Pi, A, and B
         | 
| 101 | 
            -
            		def index( | 
| 102 | 
            -
             | 
| 103 | 
            -
             | 
| 387 | 
            +
            		def index(subject, lexicon)
         | 
| 388 | 
            +
            			if subject.is_a?(Array) or subject.is_a?(NArray)
         | 
| 389 | 
            +
            				return subject.collect{|x| lexicon.rindex(x)}
         | 
| 390 | 
            +
            			else
         | 
| 391 | 
            +
            				return index(Array[subject], lexicon)[0]
         | 
| 392 | 
            +
            			end
         | 
| 104 393 | 
             
            		end
         | 
| 105 394 |  | 
| 395 | 
            +
            		#private
         | 
| 396 | 
            +
            		
         | 
| 106 397 | 
             
            		def deindex(sequence, lexicon)
         | 
| 107 398 | 
             
            			sequence.collect{|i| lexicon[i]}
         | 
| 108 399 | 
             
            		end
         | 
| 109 400 |  | 
| 110 401 | 
             
            		# abstracting out some array element operations for readability
         | 
| 111 | 
            -
            		def log( | 
| 112 | 
            -
            			 | 
| 113 | 
            -
             | 
| 402 | 
            +
            		def log(subject)
         | 
| 403 | 
            +
            			if subject.is_a?(Array) or subject.is_a?(NArray)
         | 
| 404 | 
            +
            				return subject.collect{|n| NMath::log n}
         | 
| 405 | 
            +
            			else
         | 
| 406 | 
            +
            				return log(Array[subject])[0]
         | 
| 407 | 
            +
            			end
         | 
| 114 408 | 
             
            		end
         | 
| 115 409 |  | 
| 116 410 | 
             
            		def exp(array)
         | 
| @@ -132,4 +426,4 @@ class HMM | |
| 132 426 | 
             
            	      	  @o, @q = o, q
         | 
| 133 427 | 
             
            	      end
         | 
| 134 428 | 
             
            	end
         | 
| 135 | 
            -
            end
         | 
| 429 | 
            +
            end
         | 
    
        data/test/test_hmm.rb
    CHANGED
    
    | @@ -1,28 +1,62 @@ | |
| 1 1 | 
             
            require 'helper'
         | 
| 2 | 
            +
            require 'narray'
         | 
| 2 3 |  | 
| 3 4 | 
             
            class TestHmm < Test::Unit::TestCase
         | 
| 4 | 
            -
            	 | 
| 5 | 
            -
            		 | 
| 6 | 
            -
            		 | 
| 7 | 
            -
            	end
         | 
| 8 | 
            -
            	
         | 
| 9 | 
            -
            	should "decode using hand-built model" do
         | 
| 10 | 
            -
            		model = HMM::Classifier.new
         | 
| 11 | 
            -
             | 
| 5 | 
            +
            	def setup
         | 
| 6 | 
            +
            		@simple_model = HMM::Classifier.new
         | 
| 7 | 
            +
            		
         | 
| 12 8 | 
             
            		# manually build a classifier
         | 
| 13 | 
            -
            		 | 
| 14 | 
            -
            		 | 
| 15 | 
            -
            		 | 
| 9 | 
            +
            		@simple_model.o_lex = ["A", "B"]
         | 
| 10 | 
            +
            		@simple_model.q_lex = ["X", "Y", "Z"]
         | 
| 11 | 
            +
            		@simple_model.a = NArray[[0.8, 0.1, 0.1],
         | 
| 16 12 | 
             
            					[0.2, 0.5, 0.3],
         | 
| 17 13 | 
             
            					[0.9, 0.1, 0.0]].transpose(1,0)
         | 
| 18 | 
            -
            		 | 
| 14 | 
            +
            		@simple_model.b = NArray[ [0.2, 0.8],
         | 
| 19 15 | 
             
            					[0.7, 0.3],
         | 
| 20 16 | 
             
            					[0.9, 0.1]].transpose(1,0)
         | 
| 21 | 
            -
            		 | 
| 17 | 
            +
            		@simple_model.pi = NArray[0.5, 0.3, 0.2]
         | 
| 22 18 |  | 
| 19 | 
            +
            	end
         | 
| 20 | 
            +
            	
         | 
| 21 | 
            +
            	should "create new classifier" do
         | 
| 22 | 
            +
            		model = HMM::Classifier.new
         | 
| 23 | 
            +
            		assert model.class == HMM::Classifier
         | 
| 24 | 
            +
            	end
         | 
| 25 | 
            +
            	
         | 
| 26 | 
            +
            	should "decode using hand-built model" do
         | 
| 23 27 | 
             
            		# apply classifier to a sample observation string
         | 
| 24 | 
            -
            		q_star =  | 
| 28 | 
            +
            		q_star = @simple_model.decode(["A","B","A"])
         | 
| 25 29 | 
             
            		assert q_star == ["Z", "X", "X"]
         | 
| 26 30 | 
             
            	end
         | 
| 27 31 |  | 
| 32 | 
            +
            	should "compute forward probabilities" do
         | 
| 33 | 
            +
            		expected_alpha = NArray[ [ 0.1, 0.2272, 0.039262 ], 
         | 
| 34 | 
            +
            						[ 0.21, 0.0399, 0.03038 ], 
         | 
| 35 | 
            +
            						[ 0.18, 0.0073, 0.031221 ] ]
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            		assert close_enough(expected_alpha, \
         | 
| 38 | 
            +
            			@simple_model.forward_probability(["A","B","A"]).collect{|x| Math::E**x})
         | 
| 39 | 
            +
            	end
         | 
| 40 | 
            +
            		
         | 
| 41 | 
            +
            	should "compute backward probabilities" do
         | 
| 42 | 
            +
            		expected_beta = NArray[ [ 0.2271, 0.32, 1.0 ], 
         | 
| 43 | 
            +
            						[ 0.1577, 0.66, 1.0 ], 
         | 
| 44 | 
            +
            						[ 0.2502, 0.25, 1.0 ] ]
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            		assert close_enough(expected_beta, \
         | 
| 47 | 
            +
            			@simple_model.backward_probability(["A","B","A"]).collect{|x| Math::E**x})
         | 
| 48 | 
            +
            	end
         | 
| 49 | 
            +
            		
         | 
| 50 | 
            +
            	should "compute xi" do
         | 
| 51 | 
            +
            		@simple_model.gamma(@simple_model.xi(["A","B","A"]))
         | 
| 52 | 
            +
            	end
         | 
| 53 | 
            +
            		
         | 
| 54 | 
            +
            	
         | 
| 55 | 
            +
            	
         | 
| 56 | 
            +
            	def close_enough(a, b)
         | 
| 57 | 
            +
            		# since we're dealing with some irrational values from logs, some checks
         | 
| 58 | 
            +
            		# need to be "good enough" rather than a perfect ==
         | 
| 59 | 
            +
            		(a-b).abs < 1e-10
         | 
| 60 | 
            +
            	end
         | 
| 61 | 
            +
             | 
| 28 62 | 
             
            end
         | 
    
        metadata
    CHANGED
    
    | @@ -1,7 +1,7 @@ | |
| 1 1 | 
             
            --- !ruby/object:Gem::Specification 
         | 
| 2 2 | 
             
            name: hmm
         | 
| 3 3 | 
             
            version: !ruby/object:Gem::Version 
         | 
| 4 | 
            -
              version: 0.0.2
         | 
| 4 | 
            +
              version: 0.1.0
         | 
| 5 5 | 
             
            platform: ruby
         | 
| 6 6 | 
             
            authors: 
         | 
| 7 7 | 
             
            - David Tresner-Kirsch
         | 
| @@ -9,7 +9,7 @@ autorequire: | |
| 9 9 | 
             
            bindir: bin
         | 
| 10 10 | 
             
            cert_chain: []
         | 
| 11 11 |  | 
| 12 | 
            -
            date: 2009- | 
| 12 | 
            +
            date: 2009-12-02 00:00:00 -05:00
         | 
| 13 13 | 
             
            default_executable: 
         | 
| 14 14 | 
             
            dependencies: 
         | 
| 15 15 | 
             
            - !ruby/object:Gem::Dependency 
         |