pygms 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. pygms-0.4.0/.gitignore +71 -0
  2. pygms-0.4.0/PKG-INFO +45 -0
  3. pygms-0.4.0/README.md +28 -0
  4. pygms-0.4.0/pygms/.RData +0 -0
  5. pygms-0.4.0/pygms/.Rhistory +8 -0
  6. pygms-0.4.0/pygms/.swo +0 -0
  7. pygms-0.4.0/pygms/__init__.py +28 -0
  8. pygms-0.4.0/pygms/_notes.txt +130 -0
  9. pygms-0.4.0/pygms/_profile.py +40 -0
  10. pygms-0.4.0/pygms/_rbm.py +321 -0
  11. pygms-0.4.0/pygms/_todo.txt +131 -0
  12. pygms-0.4.0/pygms/causal.py +42 -0
  13. pygms-0.4.0/pygms/csp.py +385 -0
  14. pygms-0.4.0/pygms/data/__init__.py +23 -0
  15. pygms-0.4.0/pygms/data/__init__.py.bak +23 -0
  16. pygms-0.4.0/pygms/data/catalog.py +234 -0
  17. pygms-0.4.0/pygms/data/catalog.py.bak +234 -0
  18. pygms-0.4.0/pygms/data/sources.json +9 -0
  19. pygms-0.4.0/pygms/decisions.py +199 -0
  20. pygms-0.4.0/pygms/development.py +39 -0
  21. pygms-0.4.0/pygms/draw.py +366 -0
  22. pygms-0.4.0/pygms/factor.py +15 -0
  23. pygms-0.4.0/pygms/factorGauss.py +344 -0
  24. pygms-0.4.0/pygms/factorNumpy.py +718 -0
  25. pygms-0.4.0/pygms/factorSparse.py +698 -0
  26. pygms-0.4.0/pygms/factorTorch.py +758 -0
  27. pygms-0.4.0/pygms/filetypes.py +771 -0
  28. pygms-0.4.0/pygms/graphmodel.py +798 -0
  29. pygms-0.4.0/pygms/indexedheap.py +109 -0
  30. pygms-0.4.0/pygms/ising.py +698 -0
  31. pygms-0.4.0/pygms/jupyter.py +223 -0
  32. pygms-0.4.0/pygms/learning.py +246 -0
  33. pygms-0.4.0/pygms/messagepass.py +226 -0
  34. pygms-0.4.0/pygms/misc.py +358 -0
  35. pygms-0.4.0/pygms/montecarlo.py +513 -0
  36. pygms-0.4.0/pygms/regiongraph.py +223 -0
  37. pygms-0.4.0/pygms/search.py +282 -0
  38. pygms-0.4.0/pygms/search1.py +680 -0
  39. pygms-0.4.0/pygms/search2.py +324 -0
  40. pygms-0.4.0/pygms/search_sum.py +402 -0
  41. pygms-0.4.0/pygms/varset_py.py +86 -0
  42. pygms-0.4.0/pygms/varset_py2.py +152 -0
  43. pygms-0.4.0/pygms/weighted.py +380 -0
  44. pygms-0.4.0/pygms/wmb.py +733 -0
  45. pygms-0.4.0/pygms/wogm.py +312 -0
  46. pygms-0.4.0/pyproject.toml +38 -0
pygms-0.4.0/.gitignore ADDED
@@ -0,0 +1,71 @@
1
+ notes.txt
2
+ *.swp
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ env/
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *,cover
49
+ .hypothesis/
50
+
51
+ # Translations
52
+ *.mo
53
+ *.pot
54
+
55
+ # Django stuff:
56
+ *.log
57
+
58
+ # Sphinx documentation
59
+ docs/_build/
60
+
61
+ # PyBuilder
62
+ target/
63
+
64
+ #Ipython Notebook
65
+ .ipynb_checkpoints
66
+
67
+ #VSCode
68
+ .vscode
69
+
70
+ # Mac osX meta data
71
+ .DS_Store
pygms-0.4.0/PKG-INFO ADDED
@@ -0,0 +1,45 @@
1
+ Metadata-Version: 2.4
2
+ Name: pygms
3
+ Version: 0.4.0
4
+ Summary: Python Graphical Models Toolbox
5
+ Author-email: Alexander Ihler <ihler@ics.uci.edu>
6
+ License-Expression: BSD-2-Clause
7
+ Classifier: License :: OSI Approved :: BSD License
8
+ Classifier: Operating System :: OS Independent
9
+ Classifier: Programming Language :: Python :: 3
10
+ Requires-Python: >=3.6
11
+ Requires-Dist: matplotlib>=3.2
12
+ Requires-Dist: networkx>=2.5
13
+ Requires-Dist: numpy>=1.18
14
+ Requires-Dist: scipy>=1.4
15
+ Requires-Dist: sortedcontainers>=1.5.7
16
+ Description-Content-Type: text/markdown
17
+
18
+ pyGMs : A Python toolbox for Graphical Models
19
+ ================
20
+
21
+ This code provides a simple Python-based interface for defining probabilistic
22
+ graphical models (Bayesian networks, factor graphs, etc.) over discrete random
23
+ variables, along with a number of routines for approximate inference. It is
24
+ being developed for use in teaching, as well as prototyping for research.
25
+
26
+ The code currently uses [NumPy](http://www.numpy.org/) for representing and
27
+ operating on the table-based representation of discrete factors, and
28
+ [SortedContainers](https://pypi.python.org/pypi/sortedcontainers) for some
29
+ internal representations. Smaller portions use [networkx](https://networkx.org/)
30
+ and [scipy](https://www.scipy.org/) as well.
31
+
32
+ ## Installation
33
+
34
+ Simply download or clone the repository to a directory *pyGMs*, and add its
35
+ parent directory to your Python path, either:
36
+ ```
37
+ $ export PYTHONPATH=${PYTHONPATH}:/directory/containing/
38
+ ```
39
+ or in Python
40
+ ```
41
+ import sys
42
+ sys.path.append('/directory/containing/')
43
+ ```
44
+
45
+
pygms-0.4.0/README.md ADDED
@@ -0,0 +1,28 @@
1
+ pyGMs : A Python toolbox for Graphical Models
2
+ ================
3
+
4
+ This code provides a simple Python-based interface for defining probabilistic
5
+ graphical models (Bayesian networks, factor graphs, etc.) over discrete random
6
+ variables, along with a number of routines for approximate inference. It is
7
+ being developed for use in teaching, as well as prototyping for research.
8
+
9
+ The code currently uses [NumPy](http://www.numpy.org/) for representing and
10
+ operating on the table-based representation of discrete factors, and
11
+ [SortedContainers](https://pypi.python.org/pypi/sortedcontainers) for some
12
+ internal representations. Smaller portions use [networkx](https://networkx.org/)
13
+ and [scipy](https://www.scipy.org/) as well.
14
+
15
+ ## Installation
16
+
17
+ Simply download or clone the repository to a directory *pyGMs*, and add its
18
+ parent directory to your Python path, either:
19
+ ```
20
+ $ export PYTHONPATH=${PYTHONPATH}:/directory/containing/
21
+ ```
22
+ or in Python
23
+ ```
24
+ import sys
25
+ sys.path.append('/directory/containing/')
26
+ ```
27
+
28
+
Binary file
@@ -0,0 +1,8 @@
1
+ suppressMessages(library('igraph'))
2
+ suppressMessages(library('causaleffect'))
3
+ s <- graph.formula(W -+ X,X -+ W,W -+ Y,Y -+ W,W -+ R,R -+ X,X -+ Y, simplify=FALSE)
4
+ s <- set_edge_attr(s, 'description', 1:4, 'U')
5
+
6
+ causal.effect('Y', c("X"), G = s)
7
+
8
+ quit()
pygms-0.4.0/pygms/.swo ADDED
Binary file
@@ -0,0 +1,28 @@
1
+ """
2
+ pyGMs: Python Graphical Model code
3
+
4
+ A simple graphical model class for learning about, testing, and developing algorithms
5
+ for graphical models.
6
+
7
+ Version 0.4.0 (2026-03-31)
8
+
9
+ (c) 2015-2026 Alexander Ihler under the FreeBSD license; see license.txt for details.
10
+ """
11
+
12
+ from sortedcontainers import SortedSet as sset;
13
+
14
+ from .factor import *
15
+ #from .factorSparse import *
16
+ from .graphmodel import *
17
+ from .filetypes import *
18
+ from .misc import *
19
+ from .draw import *
20
+
21
+
22
+ __title__ = 'pygms'
23
+ __version__ = '0.4.0'
24
+ __author__ = 'Alexander Ihler'
25
+ __license__ = 'BSD'
26
+ __copyright__ = '2015-2026, Alexander Ihler'
27
+
28
+
@@ -0,0 +1,130 @@
1
+
2
+ General principles:
3
+
4
+ (1) GraphModel & its inheritors are containers for collections of factors (functions over a few variables each)
5
+ (2) Variables are assumed to be X0...Xn, although some may be unused. When the dimension of Xi is known/not required,
6
+ the index itself can be used instead.
7
+ (3) Configurations (x0...xn) can be represented in different ways, depending on the circumstances:
8
+ * A map {0:x0, 1:x1, ...}, with unspecified values of a partial configuration left out
9
+ * An nparray, tuple, or list [x0,x1,...xn], with missing values (if allowed) set to NaN
10
+ * A collection of data X=[xa,xb,xc...] with xa=[xa0,...xan], so that X[i] is the ith data point
11
+ * This can be a list of lists or tuples, or a 2D numpy array
12
+ * We can detect whether a single data point or multiple are specified using try: next(iter(next(iter(X))))
13
+ * Some functions only accept single data (should check & error), others expect multiple, or can take either.
14
+ (4) All data are expected to take values 0...d-1, for discrete variables with d states, even e.g. "Ising" models.
15
+ (5) "isLog" indicates whether the model consists of log-factors: G(x) = \sum_a g_a(x_a),
16
+ or exp-factors: F(x) = \sum_a f_a(x_a). Functions which make use of the joint value (log-likelihood, etc.)
17
+ or multiple factors in combination (e.g., variable elimination) use this in computing their quantities.
18
+ The model representation may be switched using "toLog" and "toExp".
19
+ (6) Sampling functions should return config-value pairs: x,lnq where x \sim q(X) and lnq = log(q(x)).
20
+ (7) Training functions take the form "fit_method" for methods that estimate graph structure, and "refit_method"
21
+ for methods that preserve the current factors/cliques and simply update the parameters.
22
+ (8) By default, GraphModel makes a copy of each factor during construction. This is because some functions
23
+ (such as refit or reparameterization inference) may change the values of the factor tables. If this is not
24
+ desired, "gm = GraphModel(factors, copy=False)" will use references to the factors (useful for memory sharing,
25
+ for example), but set gm.lock = True to indicate that functions should not alter these factors.
26
+ The function "gm.copy()" returns a new copy of the model with non-const factors.
27
+ (9) Some functions (conditioning, manual factor additions, etc.) may alter the structure / cliques of the model as
28
+ well. Iterative methods (message passing inference, etc.) can check "gm.sig" to make sure the structure has
29
+ not changed between iterations to ensure validity, and may raise an exception if it does.
30
+ (TODO: add another lock flag for structure; functions should check before altering structure.)
31
+
32
+
33
+
34
+
35
+
36
+
37
+
38
+
39
+
40
+ ===============================================================================
41
+
42
+ *** Version that can use torch tensors?
43
+ * Allow for "direct" optimization of various quantities?
44
+
45
+
46
+ *** Data representation !!!
47
+ * Want x[i] to be data point i?
48
+ * Want to be able to pass lists of tuples, or a single tuple? (No single tuples?) (Convert to arrays?)
49
+ * If "not 2D" convert before operation? (Note: some operations can only take single tuples)
50
+ * Also want to be able to access X[:,j] = {x_j for all data}
51
+ * Want consistent with standard torch forms
52
+
53
+
54
+
55
+ *** Exact inference
56
+ * Sensitivity analysis?
57
+ * Exp family form: dLogZ/dTheta = E[X]
58
+ * "BN" form: specify p(A=a|E=e) and x=p(Xi=xi|Xpai=xpai) ?
59
+ * all-to-one version requires BN form and bnOrder, and "one" p(A=a|E=e)
60
+ * one-to-all version requires p(Xi=xi|Xpai=xpai); BN can be unnormalized
61
+ * BN specialized functions? Evidence pruning?
62
+ * CSP specialized functions?
63
+
64
+
65
+ *** Sampling
66
+ * sampling function returns config, logP = sample( [partial?] )
67
+ * MCMC returns config, logP = sample( startcfg )
68
+ *** Maybe make a generator object? "yield"?
69
+ * Method to "aggregate" samples? ("Query" object...) => generic "estimate marginal" f'ns, etc?
70
+ * Useful for repeated / cts improvement inference (save state)
71
+ * QueryMarginals (list cliques); QueryExpectations (list f'ns); QueryHistory (log all)
72
+
73
+ * Basic Gibbs (two versions)
74
+ ** Structured gibbs? Sets to sample; generate conditioned sub-models; VE-sample? (In-place slices: efficient?)
75
+ * MH: any common proposals?
76
+ TODO?
77
+ * Importance sampling (several? WMB, Tree BP, MF?)
78
+ * Annealed IS
79
+ * Estimators? Discriminance sampling?
80
+
81
+ *** Search
82
+ * Pseudo-tree object
83
+ * Heuristic f'n: takes partial config, returns cost-to-go of internal PTree given config
84
+ * Next variable f'n: given partial config, what are the next vars to condition on?
85
+ * Node priority f'n: when adding node, when should we re-examine?
86
+ * Nodes link to heuristic, priority f'n so they can modify / be dynamic? PFunc can use heuristic?
87
+ * Search algos:
88
+ * Basic DFS
89
+ * A/O DFS? (Rotating?)
90
+ * Best-First / A*
91
+ * A/O A*? Mem limited A*?
92
+
93
+
94
+
95
+ *** Variational
96
+ * Various forms of DD? (SoftArc done/easy; others?)
97
+ * Basic BP algo? Fancier (scheduling, etc)? Region models?
98
+ * NMF done; structured MF?
99
+ * WMB:
100
+ * Basic incremental MB: build; msg pass, merge; msg pass, merge; ...
101
+ * Sholeh algorithm?
102
+
103
+ ??? Iterative algorithms, use yield or other structures to run iterations?
104
+ * Can verify no structural changes, etc;
105
+ * Reparameterization vs message forms
106
+ ** "General" outer loop that checks for timeout, various convergence conditions? Or make for each algo?
107
+
108
+
109
+ *** Learning
110
+ *CRF representation? (More general factor representations? ?)
111
+ * EM? Stochastic EM?
112
+
113
+
114
+ *** Structure learning
115
+ * Several Ising methods (clean up)
116
+ * Independence tests
117
+ * Non-ising group lasso, etc.
118
+ * BN stochastic search
119
+ * BN ILP method
120
+ * ...
121
+
122
+
123
+ *** Special models: Ising models (clean?), RBM/CRBM,
124
+
125
+ *** Other people's algorithms?
126
+ * Vibhav's: sample, assemble sparse JTree until memory limit, solve
127
+ * Need to "decompose" proposal along graph structure?
128
+ * Maua's: solve preserving lists of factors / configs?
129
+ * Model conversions? Binary, pairwise, etc?
130
+
@@ -0,0 +1,40 @@
1
+ ## Useful code for profiling execution speed, etc.
2
+
3
+ try:
4
+ from line_profiler import LineProfiler
5
+
6
+ def do_profile(follow=[]):
7
+ def inner(func):
8
+ def profiled_func(*args, **kwargs):
9
+ try:
10
+ profiler = LineProfiler()
11
+ profiler.add_function(func)
12
+ for f in follow:
13
+ profiler.add_function(f)
14
+ profiler.enable_by_count()
15
+ return func(*args, **kwargs)
16
+ finally:
17
+ profiler.print_stats()
18
+ return profiled_func
19
+ return inner
20
+
21
+ except ImportError:
22
+ def do_profile(follow=[]):
23
+ "Helpful if you accidentally leave in production!"
24
+ def inner(func):
25
+ def nothing(*args, **kwargs):
26
+ return func(*args, **kwargs)
27
+ return nothing
28
+ return inner
29
+
30
+ def get_number():
31
+ for x in xrange(5000000):
32
+ yield x
33
+
34
+
35
+ # To profile the function, decorate e.g.:
36
+
37
+ #@do_profile(follow=[get_number])
38
+ #def __init__(self,model,elimOrder,force_or=False,max_width=None):
39
+
40
+
@@ -0,0 +1,321 @@
1
+ import numpy as np
2
+ import copy
3
+
4
+ from .base import classifier
5
+ from .base import regressor
6
+ from .utils import toIndex, fromIndex, to1ofK, from1ofK
7
+ from numpy import asarray as arr
8
+ from numpy import atleast_2d as twod
9
+ from numpy import asmatrix as mat
10
+
11
+ from scipy.special import expit
12
+ import matplotlib.pyplot as plt
13
+
14
+
15
+ # TODO:
16
+ # (1) Gibbs / CD variants (multiple chains, averaging, rao-blackwell estimates)
17
+ # (2) logZ estimates: brute force enumeration, AIS estimate, loopyBP estimate, others?
18
+ # (3) loss functions: reconstruction errors (mse, logP, etc.); (approx) data likelihood; FE difference; others?
19
+ # (4) helpers? logsumexp? see wei & ruslan's code?
20
+ # (5) deep versions; variable numbers of layers? (specialize first)
21
+ # - simple if use functions taking Wvh, bv+Wvx*X, bh+Whx*X? each layer has Wlx, bl terms, + Wll' terms?
22
+
23
+ ################################################################################
24
+ ## BASIC RBM ################################################################
25
+ ################################################################################
26
+
27
+ def _add1(X):
28
+ return np.hstack( (np.ones((X.shape[0],1)),X) )
29
+
30
+ def _sigma(z):
31
+ return expit(z);
32
+ #return 1.0/(1.0+np.exp(-z))
33
+
34
+
35
+ class crbm(object):
36
+ """A restricted Boltzmann machine
37
+
38
+ Attributes:
39
+
40
+ """
41
+
42
+ def __init__(self, nV,nH,nX, Wvh=None,bh=None,bv=None, Wvx=None,Whx=None):
43
+ """Constructor for a (conditional) restricted Boltzmann machine
44
+ Parameters:
45
+ nV : # of visible nodes (observable data)
46
+ nH : # of hidden nodes (latent variables)
47
+ nX : # of always-observed conditioning variables
48
+ Wvh, Wvx, Whx : pairwise weights (default: initialize randomly)
49
+ bh,bv : bias parameters (default: initialize to zero)
50
+ """
51
+ if Wvh is None: Wvh = np.random.rand(nV,nH) * .001
52
+ if Wvx is None: Wvx = np.random.rand(nV,nX) * .001
53
+ if Whx is None: Whx = np.random.rand(nH,nX) * .001
54
+ if bh is None: bh = np.zeros((nH,))
55
+ if bv is None: bv = np.zeros((nV,))
56
+
57
+ self.Wvh = Wvh
58
+ self.Wvx = Wvx
59
+ self.Whx = Whx
60
+ self.bv = bv
61
+ self.bh = bh
62
+
63
+
64
+
65
+ def __repr__(self):
66
+ to_return = 'Restricted Boltzmann machine, VxH={}x{}'.format(self.Wvh.shape[0],self.Wvh.shape[1])
67
+ return to_return
68
+
69
+
70
+ def __str__(self):
71
+ to_return = 'Restricted Boltzmann machine, VxH={}x{}'.format(self.W.shape[0],self.W.shape[1])
72
+ return to_return
73
+
74
+ def nLayers(self):
75
+ return 1
76
+
77
+ @property
78
+ def layers(self):
79
+ """Return list of layer sizes, [N,H1,H2,...]
80
+ N = # of input features ("V")
81
+ Hi = # of hidden nodes in layer i ("H")
82
+ """
83
+ layers = [self.Wvh.shape[0], self.Wvh.shape[1]]
84
+ #if len(self.wts):
85
+ # layers = [self.W.shape[0], self.W.shape[1]]
86
+ # #layers = [self.wts[l].shape[1] for l in range(len(self.wts))]
87
+ # #layers.append( self.wts[-1].shape[0] )
88
+ #else:
89
+ # layers = []
90
+ return layers
91
+
92
+ @layers.setter
93
+ def layers(self, layers):
94
+ raise NotImplementedError
95
+ # adapt / change size of weight matrices (?)
96
+
97
+
98
+
99
+ ## CORE METHODS ################################################################
100
+ # todo: CD, BP; persistent CD? make BP persistent? others?
101
+ # : estimate marginal likelihood in various ways?
102
+
103
+ def marginals():
104
+ raise NotImplementedError
105
+
106
+ def marg_h(self, v, bh=None):
107
+ if bh is None: bh = self.bh
108
+ th = _sigma( v.dot(self.Wvh) + bh ) ## !!! regular rbm vs crbm?
109
+ return th
110
+
111
+ #@profile
112
+ def marg_bp(self, maxiter=100, bv=None,bh=None,stoptol=1e-6):
113
+ '''Estimate the singleton & pairwise marginals using belief propagation'''
114
+ Wvh = self.Wvh # pass in bv, bh to enable Whx etc?
115
+ if bv is None: bv = self.bv
116
+ if bh is None: bh = self.bh
117
+ Mvh = np.empty(Wvh.shape); Mvh.fill(0.5);
118
+ Mhv = Mvh.T.copy()
119
+ tv, th = _sigma(self.bv), _sigma(self.bh)
120
+ tvOld = 0*tv;
121
+ for t in range(maxiter):
122
+ # h to v:
123
+ Lvh1 = (1 - Mhv).T * th #Lvh1 = (1 - Mhv).T.dot( np.diag(th) )
124
+ Lvh2 = Mhv.T * ( 1-th ) #Lvh2 = Mhv.T.dot( np.diag( 1-th ) )
125
+ Mvh = _sigma( np.log( (np.exp(Wvh)*Lvh1 + Lvh2)/(Lvh1+Lvh2) ) )
126
+ tv = _sigma( bv + np.log( Mvh/(1-Mvh) ).sum(1) )
127
+ if np.max(np.abs(tv-tvOld)) < stoptol: break;
128
+ # v to h:
129
+ Lhv1 = (1 - Mvh).T * tv #Lhv1 = (1 - Mvh).T.dot( np.diag(tv) )
130
+ Lhv2 = Mvh.T * (1-tv) #Lhv2 = Mvh.T.dot( np.diag(1-tv) )
131
+ Mhv = _sigma( np.log( (np.exp(Wvh.T)*Lhv1+Lhv2)/(Lhv1+Lhv2) ) )
132
+ th = _sigma( bh + np.log( Mhv/(1-Mhv) ).sum(1) )
133
+ Gsum = np.outer( 1-tv, 1-th ) * Mvh * Mhv.T
134
+ Gsum+= np.outer( tv, 1-th)*(1-Mvh)*Mhv.T
135
+ Gsum+= np.outer(1-tv,th)*Mvh*(1-Mhv.T)
136
+ G = np.exp(Wvh)*np.outer(tv,th)*(1-Mvh)*(1-Mhv.T)
137
+ G /= (Gsum+G)
138
+ return G,tv,th
139
+
140
+
141
+ def marg_cd(self, nstep=1,vinit=None, bv=None,bh=None, nchains=1):
142
+ '''Estimate the singleton & pairwise marginals using gibbs sampling (for contrastive divergence)'''
143
+ Wvh = self.Wvh # pass in bv, bh to enable Whx etc?
144
+ if bv is None: bv = self.bv
145
+ if bh is None: bh = self.bh
146
+ if vinit is None: raise NotImplementedError; # todo: init using p(v)
147
+ G,tv,th = 0,0,0
148
+ for c in range(nchains):
149
+ v = vinit;
150
+ for s in range(nstep):
151
+ ph = 1 / (1+np.exp(-v.dot(Wvh)-bh));
152
+ h = (np.random.rand(*ph.shape) < ph);
153
+ pv = 1 / (1+np.exp(-Wvh.dot(h)-bv));
154
+ v = (np.random.rand(*pv.shape) < pv);
155
+ tv += v; th += h; G += np.outer(v,h);
156
+ return G/nchains,tv/nchains,th/nchains
157
+ # TODO: variants: use p(h|v), or use all K samples
158
+
159
+
160
+ def nll_gap(self, Xtr,Ytr, Xva,Yva):
161
+ fe = np.mean( np.sum(Ytr*(self.bv + Xtr.dot(self.Wvx.T)),1) +
162
+ np.sum(np.log(1.0+np.exp( Ytr.dot(self.Wvh) + Xtr.dot(self.Whx.T) + self.bh ) ),1) )
163
+ fe-= np.mean( np.sum(Yva*(self.bv + Xva.dot(self.Wvx.T)),1) +
164
+ np.sum(np.log(1.0+np.exp( Yva.dot(self.Wvh) + Xva.dot(self.Whx.T) + self.bh ) ),1) )
165
+ return fe
166
+
167
+ def err(self,X,Y):
168
+ Y = arr( Y )
169
+ Yhat = arr( self.predict(X) )
170
+ return np.mean(Yhat.reshape(Y.shape) != Y)
171
+
172
+ def nll(self,X,Y):
173
+ # TODO: fix; evaluate/estimate actual NLL?
174
+ P = self.predictSoft(X);
175
+ J = -np.mean( Y*np.log(P) + (1-Y)*np.log(1-P) );
176
+ return J
177
+
178
+ def predictLBP(self, X):
179
+ if len(X.shape)==1: X = X.reshape(1,-1)
180
+ Y = np.zeros((X.shape[0],self.Wvx.shape[0]));
181
+ for j in range(X.shape[0]):
182
+ bxh = self.bh + self.Whx.dot(X[j,:].T)
183
+ bxv = self.bv + self.Wvx.dot(X[j,:].T)
184
+ mu = marg_h(self, Y[j,:],bxh)
185
+ G,tv,th = marg_bp(self, 5, bxv, bxh)
186
+ Y[j,:] = tv;
187
+ return Y
188
+
189
+ def predictGibbs(self, X):
190
+ if len(X.shape)==1: X = X.reshape(1,-1)
191
+ Y = np.zeros((X.shape[0],self.Wvx.shape[0]));
192
+ for j in range(X.shape[0]):
193
+ bxh = self.bh + self.Whx.dot(X[j,:].T)
194
+ bxv = self.bv + self.Wvx.dot(X[j,:].T)
195
+ mu = marg_h(self, Y[j,:],bxh)
196
+ G,tv,th = marg_cd(self, 15, np.random.rand(Y[j,:].shape[0]), bxv, bxh)
197
+ Y[j,:] = tv;
198
+ return Y
199
+
200
+ def predict(self, X):
201
+ # Hard prediction. TODO: create sampling function, MAP prediction function
202
+ return self.predictSoft(X) > 0.5;
203
+
204
+ def predictSoft(self, X):
205
+ """Make 'soft' (per-class confidence) predictions of the rbm on data X.
206
+
207
+ Args:
208
+ X : MxN numpy array containing M data points with N features each
209
+
210
+ Returns:
211
+ P : MxC numpy array of C class probabilities for each of the M data
212
+ """
213
+ Y = np.zeros((X.shape[0],self.Wvx.shape[0]));
214
+ for j in range(X.shape[0]):
215
+ bxh = self.bh + self.Whx.dot(X[j,:].T)
216
+ bxv = self.bv + self.Wvx.dot(X[j,:].T)
217
+ mu = self.marg_h(Y[j,:],bxh)
218
+ G,tv,th = self.marg_bp(5, bxv, bxh)
219
+ Y[j,:] = tv;
220
+ return Y
221
+
222
+
223
+ # TODO: add momentum for learning update
224
+ def train(self, X, Y, Xv=None,Yv=None, stepsize=.01, stopGap=0.1, stopEpoch=10):
225
+ """Train the (c)RBM
226
+
227
+ Args:
228
+ X : MxNx numpy array containing M data points with N features each
229
+ Y : MxNv numpy array of targets (visible units) for each data point in X
230
+ stepsize : scalar
231
+ The stepsize for gradient descent (decreases as 1 / iter).
232
+ stopTol : scalar
233
+ Tolerance for stopping criterion.
234
+ stopIter : int
235
+ The maximum number of steps before stopping.
236
+ activation : str
237
+ 'logistic', 'htangent', or 'custom'. Sets the activation functions.
238
+
239
+ """
240
+ # TODO: Shape & argument checking
241
+
242
+ # outer loop of (mini-batch) stochastic gradient descent
243
+ it, j = 1, 0 # iteration number & data index
244
+ nextPrint = 1 # next time to print info
245
+ done = 0 # end of loop flag
246
+ nBatch = 40
247
+
248
+ while not done:
249
+ step_i = 3.0*stepsize / (2.0+it) # step size evolution; classic 1/t decrease
250
+
251
+ dWvh, dWvx, dWhx, dbv, dbh = 0.0, 0.0, 0.0, 0.0, 0.0
252
+ # stochastic gradient update (one pass)
253
+ for jj in range(nBatch):
254
+ #print('j={}; jj={};'.format(j,jj));
255
+ j += 1
256
+ if j >= Y.shape[0]: j=0; it+=1;
257
+ # compute conditional model & required probabilities
258
+ bxh = self.bh + self.Whx.dot(X[j,:].T)
259
+ bxv = self.bv + self.Wvx.dot(X[j,:].T)
260
+ mu = self.marg_h(Y[j,:],bxh)
261
+ G,tv,th = self.marg_cd( 1, Y[j,:], bxv, bxh, 1)
262
+ #G,tv,th = self.marg_bp( min(4+it,50), bxv, bxh )
263
+ if (jj==1): #(np.random.rand() < .1):
264
+ plt.figure(1);
265
+ plt.subplot(221); plt.imshow(X[j,:].reshape(28,28)); plt.title('Observed X'); plt.draw();
266
+ plt.subplot(222); plt.imshow(tv.reshape(28,28)); plt.title('Model Prob'); plt.draw();
267
+ plt.subplot(223); plt.imshow(Y[j,:].reshape(28,28)); plt.title('Visible Y'); plt.draw();
268
+ plt.pause(.01);
269
+ # take gradient step:
270
+ dWvh += (np.outer(Y[j,:], mu) - G)
271
+ dWvx += (np.outer(Y[j,:], X[j,:]) - np.outer(tv,X[j,:]))
272
+ dWhx += (np.outer(mu, X[j,:]) - np.outer(th,X[j,:]))
273
+ dbv += (Y[j,:] - tv)
274
+ dbh += (mu - th)
275
+
276
+ self.Wvh += step_i * dWvh / nBatch
277
+ self.Wvx += step_i * dWvx / nBatch
278
+ self.Whx += step_i * dWhx / nBatch
279
+ self.bv += step_i * dbv / nBatch
280
+ self.bh += step_i * dbh / nBatch
281
+
282
+ print('it {} : Gap = {}'.format(it,self.nll_gap(X,Y,Xv,Yv)));
283
+ print(' {} {} {} {} {}'.format(np.mean(self.Wvx**2),np.mean(self.Whx**2),np.mean(self.Wvh**2),np.mean(self.bv**2),np.mean(self.bh**2)));
284
+
285
+ Jtr,Jva = 0,0 #self.nll(X,Y),self.nll(Xv,Yv);
286
+ if it >= nextPrint:
287
+ print('it {} : Gap = {}'.format(it,self.nll_gap(X,Y,Xv,Yv)));
288
+ print(' {} {} {} {} {}'.format(np.mean(self.Wvx**2),np.mean(self.Whx**2),np.mean(self.Wvh**2),np.mean(self.bv**2),np.mean(self.bh**2)));
289
+ #print('it {} : Jtr = {} / Jva = {}'.format(it,Jtr,Jva))
290
+ nextPrint += 1; #*= 2
291
+
292
+ # check if finished
293
+ done = (it > 1) and ((Jva - Jtr) > stopGap) or it >= stopEpoch
294
+ #it += 1 # counting epochs elsewhere now
295
+
296
+
297
+
298
+
299
+ #def err_k(self, X, Y):
300
+ # """Compute misclassification error rate. Assumes Y in 1-of-k form. """
301
+ # return self.err(X, from1ofK(Y,self.classes).ravel())
302
+ #
303
+ #
304
+ #def mse(self, X, Y):
305
+ # """Compute mean squared error of predictor 'obj' on test data (X,Y). """
306
+ # return mse_k(X, to1ofK(Y))
307
+ #
308
+ #
309
+ #def mse_k(self, X, Y):
310
+ # """Compute mean squared error of predictor; assumes Y is in 1-of-k format. """
311
+ # return np.power(Y - self.predictSoft(X), 2).sum(1).mean(0)
312
+
313
+
314
+ ## MUTATORS ####################################################################
315
+
316
+
317
+
318
+ ################################################################################
319
+ ################################################################################
320
+ ################################################################################
321
+