pygms 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pygms-0.4.0/.gitignore +71 -0
- pygms-0.4.0/PKG-INFO +45 -0
- pygms-0.4.0/README.md +28 -0
- pygms-0.4.0/pygms/.RData +0 -0
- pygms-0.4.0/pygms/.Rhistory +8 -0
- pygms-0.4.0/pygms/.swo +0 -0
- pygms-0.4.0/pygms/__init__.py +28 -0
- pygms-0.4.0/pygms/_notes.txt +130 -0
- pygms-0.4.0/pygms/_profile.py +40 -0
- pygms-0.4.0/pygms/_rbm.py +321 -0
- pygms-0.4.0/pygms/_todo.txt +131 -0
- pygms-0.4.0/pygms/causal.py +42 -0
- pygms-0.4.0/pygms/csp.py +385 -0
- pygms-0.4.0/pygms/data/__init__.py +23 -0
- pygms-0.4.0/pygms/data/__init__.py.bak +23 -0
- pygms-0.4.0/pygms/data/catalog.py +234 -0
- pygms-0.4.0/pygms/data/catalog.py.bak +234 -0
- pygms-0.4.0/pygms/data/sources.json +9 -0
- pygms-0.4.0/pygms/decisions.py +199 -0
- pygms-0.4.0/pygms/development.py +39 -0
- pygms-0.4.0/pygms/draw.py +366 -0
- pygms-0.4.0/pygms/factor.py +15 -0
- pygms-0.4.0/pygms/factorGauss.py +344 -0
- pygms-0.4.0/pygms/factorNumpy.py +718 -0
- pygms-0.4.0/pygms/factorSparse.py +698 -0
- pygms-0.4.0/pygms/factorTorch.py +758 -0
- pygms-0.4.0/pygms/filetypes.py +771 -0
- pygms-0.4.0/pygms/graphmodel.py +798 -0
- pygms-0.4.0/pygms/indexedheap.py +109 -0
- pygms-0.4.0/pygms/ising.py +698 -0
- pygms-0.4.0/pygms/jupyter.py +223 -0
- pygms-0.4.0/pygms/learning.py +246 -0
- pygms-0.4.0/pygms/messagepass.py +226 -0
- pygms-0.4.0/pygms/misc.py +358 -0
- pygms-0.4.0/pygms/montecarlo.py +513 -0
- pygms-0.4.0/pygms/regiongraph.py +223 -0
- pygms-0.4.0/pygms/search.py +282 -0
- pygms-0.4.0/pygms/search1.py +680 -0
- pygms-0.4.0/pygms/search2.py +324 -0
- pygms-0.4.0/pygms/search_sum.py +402 -0
- pygms-0.4.0/pygms/varset_py.py +86 -0
- pygms-0.4.0/pygms/varset_py2.py +152 -0
- pygms-0.4.0/pygms/weighted.py +380 -0
- pygms-0.4.0/pygms/wmb.py +733 -0
- pygms-0.4.0/pygms/wogm.py +312 -0
- pygms-0.4.0/pyproject.toml +38 -0
pygms-0.4.0/.gitignore
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
notes.txt
|
|
2
|
+
*.swp
|
|
3
|
+
|
|
4
|
+
# Byte-compiled / optimized / DLL files
|
|
5
|
+
__pycache__/
|
|
6
|
+
*.py[cod]
|
|
7
|
+
*$py.class
|
|
8
|
+
|
|
9
|
+
# C extensions
|
|
10
|
+
*.so
|
|
11
|
+
|
|
12
|
+
# Distribution / packaging
|
|
13
|
+
.Python
|
|
14
|
+
env/
|
|
15
|
+
build/
|
|
16
|
+
develop-eggs/
|
|
17
|
+
dist/
|
|
18
|
+
downloads/
|
|
19
|
+
eggs/
|
|
20
|
+
.eggs/
|
|
21
|
+
lib/
|
|
22
|
+
lib64/
|
|
23
|
+
parts/
|
|
24
|
+
sdist/
|
|
25
|
+
var/
|
|
26
|
+
*.egg-info/
|
|
27
|
+
.installed.cfg
|
|
28
|
+
*.egg
|
|
29
|
+
|
|
30
|
+
# PyInstaller
|
|
31
|
+
# Usually these files are written by a python script from a template
|
|
32
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
33
|
+
*.manifest
|
|
34
|
+
*.spec
|
|
35
|
+
|
|
36
|
+
# Installer logs
|
|
37
|
+
pip-log.txt
|
|
38
|
+
pip-delete-this-directory.txt
|
|
39
|
+
|
|
40
|
+
# Unit test / coverage reports
|
|
41
|
+
htmlcov/
|
|
42
|
+
.tox/
|
|
43
|
+
.coverage
|
|
44
|
+
.coverage.*
|
|
45
|
+
.cache
|
|
46
|
+
nosetests.xml
|
|
47
|
+
coverage.xml
|
|
48
|
+
*,cover
|
|
49
|
+
.hypothesis/
|
|
50
|
+
|
|
51
|
+
# Translations
|
|
52
|
+
*.mo
|
|
53
|
+
*.pot
|
|
54
|
+
|
|
55
|
+
# Django stuff:
|
|
56
|
+
*.log
|
|
57
|
+
|
|
58
|
+
# Sphinx documentation
|
|
59
|
+
docs/_build/
|
|
60
|
+
|
|
61
|
+
# PyBuilder
|
|
62
|
+
target/
|
|
63
|
+
|
|
64
|
+
#Ipython Notebook
|
|
65
|
+
.ipynb_checkpoints
|
|
66
|
+
|
|
67
|
+
#VSCode
|
|
68
|
+
.vscode
|
|
69
|
+
|
|
70
|
+
# Mac osX meta data
|
|
71
|
+
.DS_Store
|
pygms-0.4.0/PKG-INFO
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pygms
|
|
3
|
+
Version: 0.4.0
|
|
4
|
+
Summary: Python Graphical Models Toolbox
|
|
5
|
+
Author-email: Alexander Ihler <ihler@ics.uci.edu>
|
|
6
|
+
License-Expression: BSD-2-Clause
|
|
7
|
+
Classifier: License :: OSI Approved :: BSD License
|
|
8
|
+
Classifier: Operating System :: OS Independent
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Requires-Python: >=3.6
|
|
11
|
+
Requires-Dist: matplotlib>=3.2
|
|
12
|
+
Requires-Dist: networkx>=2.5
|
|
13
|
+
Requires-Dist: numpy>=1.18
|
|
14
|
+
Requires-Dist: scipy>=1.4
|
|
15
|
+
Requires-Dist: sortedcontainers>=1.5.7
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
|
|
18
|
+
pyGMs : A Python toolbox for Graphical Models
|
|
19
|
+
================
|
|
20
|
+
|
|
21
|
+
This code provides a simple Python-based interface for defining probabilistic
|
|
22
|
+
graphical models (Bayesian networks, factor graphs, etc.) over discrete random
|
|
23
|
+
variables, along with a number of routines for approximate inference. It is
|
|
24
|
+
being developed for use in teaching, as well as prototyping for research.
|
|
25
|
+
|
|
26
|
+
The code currently uses [NumPy](http://www.numpy.org/) for representing and
|
|
27
|
+
operating on the table-based representation of discrete factors, and
|
|
28
|
+
[SortedContainers](https://pypi.python.org/pypi/sortedcontainers) for some
|
|
29
|
+
internal representations. Smaller portions use [networkx](https://networkx.org/)
|
|
30
|
+
and [scipy](https://www.scipy.org/) as well.
|
|
31
|
+
|
|
32
|
+
## Installation
|
|
33
|
+
|
|
34
|
+
Simply download or clone the repository to a directory *pyGMs*, and add its
|
|
35
|
+
parent directory to your Python path, either:
|
|
36
|
+
```
|
|
37
|
+
$ export PYTHONPATH=${PYTHONPATH}:/directory/containing/
|
|
38
|
+
```
|
|
39
|
+
or in Python
|
|
40
|
+
```
|
|
41
|
+
import sys
|
|
42
|
+
sys.path.append('/directory/containing/')
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
|
pygms-0.4.0/README.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
pyGMs : A Python toolbox for Graphical Models
|
|
2
|
+
================
|
|
3
|
+
|
|
4
|
+
This code provides a simple Python-based interface for defining probabilistic
|
|
5
|
+
graphical models (Bayesian networks, factor graphs, etc.) over discrete random
|
|
6
|
+
variables, along with a number of routines for approximate inference. It is
|
|
7
|
+
being developed for use in teaching, as well as prototyping for research.
|
|
8
|
+
|
|
9
|
+
The code currently uses [NumPy](http://www.numpy.org/) for representing and
|
|
10
|
+
operating on the table-based representation of discrete factors, and
|
|
11
|
+
[SortedContainers](https://pypi.python.org/pypi/sortedcontainers) for some
|
|
12
|
+
internal representations. Smaller portions use [networkx](https://networkx.org/)
|
|
13
|
+
and [scipy](https://www.scipy.org/) as well.
|
|
14
|
+
|
|
15
|
+
## Installation
|
|
16
|
+
|
|
17
|
+
Install from PyPI with `pip install pygms`, or download/clone the repository and add
|
|
18
|
+
the directory containing its *pygms* package folder to your Python path, either:
|
|
19
|
+
```
|
|
20
|
+
$ export PYTHONPATH=${PYTHONPATH}:/directory/containing/
|
|
21
|
+
```
|
|
22
|
+
or in Python
|
|
23
|
+
```
|
|
24
|
+
import sys
|
|
25
|
+
sys.path.append('/directory/containing/')
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
|
pygms-0.4.0/pygms/.RData
ADDED
|
Binary file
|
pygms-0.4.0/pygms/.swo
ADDED
|
Binary file
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""
|
|
2
|
+
pyGMs: Python Graphical Model code
|
|
3
|
+
|
|
4
|
+
A simple graphical model class for learning about, testing, and developing algorithms
|
|
5
|
+
for graphical models.
|
|
6
|
+
|
|
7
|
+
Version 0.4.0 (2026-03-31)
|
|
8
|
+
|
|
9
|
+
(c) 2015-2026 Alexander Ihler under the FreeBSD license; see license.txt for details.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from sortedcontainers import SortedSet as sset;
|
|
13
|
+
|
|
14
|
+
from .factor import *
|
|
15
|
+
#from .factorSparse import *
|
|
16
|
+
from .graphmodel import *
|
|
17
|
+
from .filetypes import *
|
|
18
|
+
from .misc import *
|
|
19
|
+
from .draw import *
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Package metadata.
# Keep each value as a simple one-line string literal: packaging tools may read
# __version__ textually. It should match the Version field in PKG-INFO (0.4.0).
__title__ = 'pygms'
__version__ = '0.4.0'
__author__ = 'Alexander Ihler'
__license__ = 'BSD'  # PKG-INFO declares License-Expression: BSD-2-Clause
__copyright__ = '2015-2026, Alexander Ihler'
|
|
27
|
+
|
|
28
|
+
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
|
|
2
|
+
General principles:
|
|
3
|
+
|
|
4
|
+
(1) GraphModel & its inheritors are containers for collections of factors (functions over a few variables each)
|
|
5
|
+
(2) Variables are assumed to be X0...Xn, although some may be unused. When the dimension of Xi is known/not required,
|
|
6
|
+
the index itself can be used instead.
|
|
7
|
+
(3) Configurations (x0...xn) can be represented in different ways, depending on the circumstances:
|
|
8
|
+
* A map {0:x0, 1:x1, ...}, with unspecified values of a partial configuration left out
|
|
9
|
+
* An nparray, tuple, or list [x0,x1,...xn], with missing values (if allowed) set to NaN
|
|
10
|
+
* A collection of data X=[xa,xb,xc...] with xa=[xa0,...xan], so that X[i] is the ith data point
|
|
11
|
+
* This can be a list of lists or tuples, or a 2D numpy array
|
|
12
|
+
* We can detect whether a single data point or multiple are specified using try: next(iter(next(iter(X))))
|
|
13
|
+
* Some functions only accept single data (should check & error), others expect multiple, or can take either.
|
|
14
|
+
(4) All data are expected to take values 0...d-1, for discrete variables with d states, even e.g. "Ising" models.
|
|
15
|
+
(5) "isLog" indicates whether the model consists of log-factors: G(x) = \sum_a g_a(x_a),
|
|
16
|
+
or exp-factors: F(x) = \prod_a f_a(x_a). Functions which make use of the joint value (log-likelihood, etc.)
|
|
17
|
+
or multiple factors in combination (e.g., variable elimination) use this in computing their quantities.
|
|
18
|
+
The model representation may be switched using "toLog" and "toExp".
|
|
19
|
+
(6) Sampling functions should return config-value pairs: x,lnq where x \sim q(X) and lnq = log(q(x)).
|
|
20
|
+
(7) Training functions take the form "fit_method" for methods that estimate graph structure, and "refit_method"
|
|
21
|
+
for methods that preserve the current factors/cliques and simply update the parameters.
|
|
22
|
+
(8) By default, GraphModel makes a copy of each factor during construction. This is because some functions
|
|
23
|
+
(such as refit or reparameterization inference) may change the values of the factor tables. If this is not
|
|
24
|
+
desired, "gm = GraphModel(factors, copy=False)" will use references to the factors (useful for memory sharing,
|
|
25
|
+
for example), but set gm.lock = True to indicate that functions should not alter these factors.
|
|
26
|
+
The function "gm.copy()" returns a new copy of the model with non-const factors.
|
|
27
|
+
(9) Some functions (conditioning, manual factor additions, etc.) may alter the structure / cliques of the model as
|
|
28
|
+
well. Iterative methods (message passing inference, etc.) can check "gm.sig" to make sure the structure has
|
|
29
|
+
not changed between iterations to ensure validity, and may raise an exception if it does.
|
|
30
|
+
(TODO: add another lock flag for structure; functions should check before altering structure.)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
===============================================================================
|
|
41
|
+
|
|
42
|
+
*** Version that can use torch tensors?
|
|
43
|
+
* Allow for "direct" optimization of various quantities?
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
*** Data representation !!!
|
|
47
|
+
* Want x[i] to be data point i?
|
|
48
|
+
* Want to be able to pass lists of tuples, or a single tuple? (No single tuples?) (Convert to arrays?)
|
|
49
|
+
* If "not 2D" convert before operation? (Note: some operations can only take single tuples)
|
|
50
|
+
* Also want to be able to access X[:,j] = {x_j for all data}
|
|
51
|
+
* Want consistent with standard torch forms
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
*** Exact inference
|
|
56
|
+
* Sensitivity analysis?
|
|
57
|
+
* Exp family form: dLogZ/dTheta = E[X]
|
|
58
|
+
* "BN" form: specify p(A=a|E=e) and x=p(Xi=xi|Xpai=xpai) ?
|
|
59
|
+
* all-to-one version requires BN form and bnOrder, and "one" p(A=a|E=e)
|
|
60
|
+
* one-to-all version requires p(Xi=xi|Xpai=xpai); BN can be unnormalized
|
|
61
|
+
* BN specialized functions? Evidence pruning?
|
|
62
|
+
* CSP specialized functions?
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
*** Sampling
|
|
66
|
+
* sampling function returns config, logP = sample( [partial?] )
|
|
67
|
+
* MCMC returns config, logP = sample( startcfg )
|
|
68
|
+
*** Maybe make a generator object? "yield"?
|
|
69
|
+
* Method to "aggregate" samples? ("Query" object...) => generic "estimate marginal" f'ns, etc?
|
|
70
|
+
* Useful for repeated / cts improvement inference (save state)
|
|
71
|
+
* QueryMarginals (list cliques); QueryExpectations (list f'ns); QueryHistory (log all)
|
|
72
|
+
|
|
73
|
+
* Basic Gibbs (two versions)
|
|
74
|
+
** Structured gibbs? Sets to sample; generate conditioned sub-models; VE-sample? (In-place slices: efficient?)
|
|
75
|
+
* MH: any common proposals?
|
|
76
|
+
TODO?
|
|
77
|
+
* Importance sampling (several? WMB, Tree BP, MF?)
|
|
78
|
+
* Annealed IS
|
|
79
|
+
* Estimators? Discriminance sampling?
|
|
80
|
+
|
|
81
|
+
*** Search
|
|
82
|
+
* Pseudo-tree object
|
|
83
|
+
* Heuristic f'n: takes partial config, returns cost-to-go of internal PTree given config
|
|
84
|
+
* Next variable f'n: given partial config, what are the next vars to condition on?
|
|
85
|
+
* Node priority f'n: when adding node, when should we re-examine?
|
|
86
|
+
* Nodes link to heuristic, priority f'n so they can modify / be dynamic? PFunc can use heuristic?
|
|
87
|
+
* Search algos:
|
|
88
|
+
* Basic DFS
|
|
89
|
+
* A/O DFS? (Rotating?)
|
|
90
|
+
* Best-First / A*
|
|
91
|
+
* A/O A*? Mem limited A*?
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
*** Variational
|
|
96
|
+
* Various forms of DD? (SoftArc done/easy; others?)
|
|
97
|
+
* Basic BP algo? Fancier (scheduling, etc)? Region models?
|
|
98
|
+
* NMF done; structured MF?
|
|
99
|
+
* WMB:
|
|
100
|
+
* Basic incremental MB: build; msg pass, merge; msg pass, merge; ...
|
|
101
|
+
* Sholeh algorithm?
|
|
102
|
+
|
|
103
|
+
??? Iterative algorithms, use yield or other structures to run iterations?
|
|
104
|
+
* Can verify no structural changes, etc;
|
|
105
|
+
* Reparameterization vs message forms
|
|
106
|
+
** "General" outer loop that checks for timeout, various convergence conditions? Or make for each algo?
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
*** Learning
|
|
110
|
+
*CRF representation? (More general factor representations? ?)
|
|
111
|
+
* EM? Stochastic EM?
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
*** Structure learning
|
|
115
|
+
* Several Ising methods (clean up)
|
|
116
|
+
* Independence tests
|
|
117
|
+
* Non-ising group lasso, etc.
|
|
118
|
+
* BN stochastic search
|
|
119
|
+
* BN ILP method
|
|
120
|
+
* ...
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
*** Special models: Ising models (clean?), RBM/CRBM,
|
|
124
|
+
|
|
125
|
+
*** Other people's algorithms?
|
|
126
|
+
* Vibhav's: sample, assemble sparse JTree until memory limit, solve
|
|
127
|
+
* Need to "decompose" proposal along graph structure?
|
|
128
|
+
* Maua's: solve preserving lists of factors / configs?
|
|
129
|
+
* Model conversions? Binary, pairwise, etc?
|
|
130
|
+
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
## Useful code for profiling execution speed, etc.
|
|
2
|
+
|
|
3
|
+
import functools

try:
    from line_profiler import LineProfiler

    def do_profile(follow=None):
        """Decorator factory: line-profile the decorated function on every call.

        Parameters:
            follow : optional iterable of extra functions whose lines are also
                     profiled while the decorated function runs (default: none)

        Statistics are printed with ``LineProfiler.print_stats()`` after each
        call, even if the call raises.
        """
        # Avoid a shared mutable default; snapshot the iterable once.
        follow = [] if follow is None else list(follow)

        def inner(func):
            @functools.wraps(func)
            def profiled_func(*args, **kwargs):
                # Build the profiler before the try so ``finally`` never sees
                # an unbound name if construction fails.
                profiler = LineProfiler()
                try:
                    profiler.add_function(func)
                    for f in follow:
                        profiler.add_function(f)
                    profiler.enable_by_count()
                    return func(*args, **kwargs)
                finally:
                    # Balance enable_by_count so tracing does not stay active
                    # after the profiled call returns.
                    profiler.disable_by_count()
                    profiler.print_stats()
            return profiled_func
        return inner

except ImportError:
    def do_profile(follow=None):
        """Fallback when line_profiler is unavailable: calls pass straight through.

        Helpful if you accidentally leave in production!
        """
        def inner(func):
            @functools.wraps(func)
            def nothing(*args, **kwargs):
                return func(*args, **kwargs)
            return nothing
        return inner
|
|
29
|
+
|
|
30
|
+
def get_number(n=5000000):
    """Yield the integers 0, 1, ..., n-1 (default n = 5,000,000).

    Toy generator used in the profiling example below
    (``@do_profile(follow=[get_number])``). Uses Python 3 ``range``; the
    Python 2 ``xrange`` does not exist on supported interpreters.
    """
    yield from range(n)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# To profile the function, decorate e.g.:
|
|
36
|
+
|
|
37
|
+
#@do_profile(follow=[get_number])
|
|
38
|
+
#def __init__(self,model,elimOrder,force_or=False,max_width=None):
|
|
39
|
+
|
|
40
|
+
|
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import copy
|
|
3
|
+
|
|
4
|
+
from .base import classifier
|
|
5
|
+
from .base import regressor
|
|
6
|
+
from .utils import toIndex, fromIndex, to1ofK, from1ofK
|
|
7
|
+
from numpy import asarray as arr
|
|
8
|
+
from numpy import atleast_2d as twod
|
|
9
|
+
from numpy import asmatrix as mat
|
|
10
|
+
|
|
11
|
+
from scipy.special import expit
|
|
12
|
+
import matplotlib.pyplot as plt
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# TODO:
|
|
16
|
+
# (1) Gibbs / CD variants (multiple chains, averaging, rao-blackwell estimates
|
|
17
|
+
# (2) logZ estimates: brute force enumeration, AIS estimate, loopyBP estimate, others?
|
|
18
|
+
# (3) loss functions: reconstruction errors (mse, logP, etc.); (approx) data likelihood; FE difference; others?
|
|
19
|
+
# (4) helpers? logsumexp? see wei & ruslan's code?
|
|
20
|
+
# (5) deep versions; variable numbers of layers? (specialize first)
|
|
21
|
+
# - simple if use functions taking Wvh, bv+Wvx*X, bh+Whx*X? each layer has Wlx, bl terms, + Wll' terms?
|
|
22
|
+
|
|
23
|
+
################################################################################
|
|
24
|
+
## BASIC RBM ################################################################
|
|
25
|
+
################################################################################
|
|
26
|
+
|
|
27
|
+
def _add1(X):
|
|
28
|
+
return np.hstack( (np.ones((X.shape[0],1)),X) )
|
|
29
|
+
|
|
30
|
+
def _sigma(z):
|
|
31
|
+
return expit(z);
|
|
32
|
+
#return 1.0/(1.0+np.exp(-z))
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class crbm(object):
|
|
36
|
+
"""A restricted Boltzmann machine
|
|
37
|
+
|
|
38
|
+
Attributes:
|
|
39
|
+
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
def __init__(self, nV,nH,nX, Wvh=None,bh=None,bv=None, Wvx=None,Whx=None):
|
|
43
|
+
"""Constructor for a (conditional) restricted Boltzmann machine
|
|
44
|
+
Parameters:
|
|
45
|
+
nV : # of visible nodes (observable data)
|
|
46
|
+
nH : # of hidden nodes (latent variables)
|
|
47
|
+
nX : # of always-observed conditioning variables
|
|
48
|
+
Wvh, Wvx, Whx : pairwise weights (default: initialize randomly)
|
|
49
|
+
bh,bv : bias parameters (default: initialize to zero)
|
|
50
|
+
"""
|
|
51
|
+
if Wvh is None: Wvh = np.random.rand(nV,nH) * .001
|
|
52
|
+
if Wvx is None: Wvx = np.random.rand(nV,nX) * .001
|
|
53
|
+
if Whx is None: Whx = np.random.rand(nH,nX) * .001
|
|
54
|
+
if bh is None: bh = np.zeros((nH,))
|
|
55
|
+
if bv is None: bv = np.zeros((nV,))
|
|
56
|
+
|
|
57
|
+
self.Wvh = Wvh
|
|
58
|
+
self.Wvx = Wvx
|
|
59
|
+
self.Whx = Whx
|
|
60
|
+
self.bv = bv
|
|
61
|
+
self.bh = bh
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def __repr__(self):
|
|
66
|
+
to_return = 'Restricted Boltzmann machine, VxH={}x{}'.format(self.Wvh.shape[0],self.Wvh.shape[1])
|
|
67
|
+
return to_return
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def __str__(self):
|
|
71
|
+
to_return = 'Restricted Boltzmann machine, VxH={}x{}'.format(self.W.shape[0],self.W.shape[1])
|
|
72
|
+
return to_return
|
|
73
|
+
|
|
74
|
+
def nLayers(self):
|
|
75
|
+
return 1
|
|
76
|
+
|
|
77
|
+
@property
|
|
78
|
+
def layers(self):
|
|
79
|
+
"""Return list of layer sizes, [N,H1,H2,...]
|
|
80
|
+
N = # of input features ("V")
|
|
81
|
+
Hi = # of hidden nodes in layer i ("H")
|
|
82
|
+
"""
|
|
83
|
+
layers = [self.Wvh.shape[0], self.Wvh.shape[1]]
|
|
84
|
+
#if len(self.wts):
|
|
85
|
+
# layers = [self.W.shape[0], self.W.shape[1]]
|
|
86
|
+
# #layers = [self.wts[l].shape[1] for l in range(len(self.wts))]
|
|
87
|
+
# #layers.append( self.wts[-1].shape[0] )
|
|
88
|
+
#else:
|
|
89
|
+
# layers = []
|
|
90
|
+
return layers
|
|
91
|
+
|
|
92
|
+
@layers.setter
|
|
93
|
+
def layers(self, layers):
|
|
94
|
+
raise NotImplementedError
|
|
95
|
+
# adapt / change size of weight matrices (?)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
## CORE METHODS ################################################################
|
|
100
|
+
# todo: CD, BP; persistent CD? make BP persistent? others?
|
|
101
|
+
# : estimate marginal likelihood in various ways?
|
|
102
|
+
|
|
103
|
+
def marginals():
|
|
104
|
+
raise NotImplementedError
|
|
105
|
+
|
|
106
|
+
def marg_h(self, v, bh=None):
|
|
107
|
+
if bh is None: bh = self.bh
|
|
108
|
+
th = _sigma( v.dot(self.Wvh) + bh ) ## !!! regular rbm vs crbm?
|
|
109
|
+
return th
|
|
110
|
+
|
|
111
|
+
#@profile
|
|
112
|
+
def marg_bp(self, maxiter=100, bv=None,bh=None,stoptol=1e-6):
|
|
113
|
+
'''Estimate the singleton & pairwise marginals using belief propagation'''
|
|
114
|
+
Wvh = self.Wvh # pass in bv, bh to enable Whx etc?
|
|
115
|
+
if bv is None: bv = self.bv
|
|
116
|
+
if bh is None: bh = self.bh
|
|
117
|
+
Mvh = np.empty(Wvh.shape); Mvh.fill(0.5);
|
|
118
|
+
Mhv = Mvh.T.copy()
|
|
119
|
+
tv, th = _sigma(self.bv), _sigma(self.bh)
|
|
120
|
+
tvOld = 0*tv;
|
|
121
|
+
for t in range(maxiter):
|
|
122
|
+
# h to v:
|
|
123
|
+
Lvh1 = (1 - Mhv).T * th #Lvh1 = (1 - Mhv).T.dot( np.diag(th) )
|
|
124
|
+
Lvh2 = Mhv.T * ( 1-th ) #Lvh2 = Mhv.T.dot( np.diag( 1-th ) )
|
|
125
|
+
Mvh = _sigma( np.log( (np.exp(Wvh)*Lvh1 + Lvh2)/(Lvh1+Lvh2) ) )
|
|
126
|
+
tv = _sigma( bv + np.log( Mvh/(1-Mvh) ).sum(1) )
|
|
127
|
+
if np.max(np.abs(tv-tvOld)) < stoptol: break;
|
|
128
|
+
# v to h:
|
|
129
|
+
Lhv1 = (1 - Mvh).T * tv #Lhv1 = (1 - Mvh).T.dot( np.diag(tv) )
|
|
130
|
+
Lhv2 = Mvh.T * (1-tv) #Lhv2 = Mvh.T.dot( np.diag(1-tv) )
|
|
131
|
+
Mhv = _sigma( np.log( (np.exp(Wvh.T)*Lhv1+Lhv2)/(Lhv1+Lhv2) ) )
|
|
132
|
+
th = _sigma( bh + np.log( Mhv/(1-Mhv) ).sum(1) )
|
|
133
|
+
Gsum = np.outer( 1-tv, 1-th ) * Mvh * Mhv.T
|
|
134
|
+
Gsum+= np.outer( tv, 1-th)*(1-Mvh)*Mhv.T
|
|
135
|
+
Gsum+= np.outer(1-tv,th)*Mvh*(1-Mhv.T)
|
|
136
|
+
G = np.exp(Wvh)*np.outer(tv,th)*(1-Mvh)*(1-Mhv.T)
|
|
137
|
+
G /= (Gsum+G)
|
|
138
|
+
return G,tv,th
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def marg_cd(self, nstep=1,vinit=None, bv=None,bh=None, nchains=1):
|
|
142
|
+
'''Estimate the singleton & pairwise marginals using gibbs sampling (for contrastive divergence)'''
|
|
143
|
+
Wvh = self.Wvh # pass in bv, bh to enable Whx etc?
|
|
144
|
+
if bv is None: bv = self.bv
|
|
145
|
+
if bh is None: bh = self.bh
|
|
146
|
+
if vinit is None: raise NotImplementedError; # todo: init using p(v)
|
|
147
|
+
G,tv,th = 0,0,0
|
|
148
|
+
for c in range(nchains):
|
|
149
|
+
v = vinit;
|
|
150
|
+
for s in range(nstep):
|
|
151
|
+
ph = 1 / (1+np.exp(-v.dot(Wvh)-bh));
|
|
152
|
+
h = (np.random.rand(*ph.shape) < ph);
|
|
153
|
+
pv = 1 / (1+np.exp(-Wvh.dot(h)-bv));
|
|
154
|
+
v = (np.random.rand(*pv.shape) < pv);
|
|
155
|
+
tv += v; th += h; G += np.outer(v,h);
|
|
156
|
+
return G/nchains,tv/nchains,th/nchains
|
|
157
|
+
# TODO: variants: use p(h|v), or use all K samples
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def nll_gap(self, Xtr,Ytr, Xva,Yva):
|
|
161
|
+
fe = np.mean( np.sum(Ytr*(self.bv + Xtr.dot(self.Wvx.T)),1) +
|
|
162
|
+
np.sum(np.log(1.0+np.exp( Ytr.dot(self.Wvh) + Xtr.dot(self.Whx.T) + self.bh ) ),1) )
|
|
163
|
+
fe-= np.mean( np.sum(Yva*(self.bv + Xva.dot(self.Wvx.T)),1) +
|
|
164
|
+
np.sum(np.log(1.0+np.exp( Yva.dot(self.Wvh) + Xva.dot(self.Whx.T) + self.bh ) ),1) )
|
|
165
|
+
return fe
|
|
166
|
+
|
|
167
|
+
def err(self,X,Y):
|
|
168
|
+
Y = arr( Y )
|
|
169
|
+
Yhat = arr( self.predict(X) )
|
|
170
|
+
return np.mean(Yhat.reshape(Y.shape) != Y)
|
|
171
|
+
|
|
172
|
+
def nll(self,X,Y):
|
|
173
|
+
# TODO: fix; evaluate/estimate actual NLL?
|
|
174
|
+
P = self.predictSoft(X);
|
|
175
|
+
J = -np.mean( Y*np.log(P) + (1-Y)*np.log(1-P) );
|
|
176
|
+
return J
|
|
177
|
+
|
|
178
|
+
def predictLBP(self, X):
|
|
179
|
+
if len(X.shape)==1: X = X.reshape(1,-1)
|
|
180
|
+
Y = np.zeros((X.shape[0],self.Wvx.shape[0]));
|
|
181
|
+
for j in range(X.shape[0]):
|
|
182
|
+
bxh = self.bh + self.Whx.dot(X[j,:].T)
|
|
183
|
+
bxv = self.bv + self.Wvx.dot(X[j,:].T)
|
|
184
|
+
mu = marg_h(self, Y[j,:],bxh)
|
|
185
|
+
G,tv,th = marg_bp(self, 5, bxv, bxh)
|
|
186
|
+
Y[j,:] = tv;
|
|
187
|
+
return Y
|
|
188
|
+
|
|
189
|
+
def predictGibbs(self, X):
|
|
190
|
+
if len(X.shape)==1: X = X.reshape(1,-1)
|
|
191
|
+
Y = np.zeros((X.shape[0],self.Wvx.shape[0]));
|
|
192
|
+
for j in range(X.shape[0]):
|
|
193
|
+
bxh = self.bh + self.Whx.dot(X[j,:].T)
|
|
194
|
+
bxv = self.bv + self.Wvx.dot(X[j,:].T)
|
|
195
|
+
mu = marg_h(self, Y[j,:],bxh)
|
|
196
|
+
G,tv,th = marg_cd(self, 15, np.random.rand(Y[j,:].shape[0]), bxv, bxh)
|
|
197
|
+
Y[j,:] = tv;
|
|
198
|
+
return Y
|
|
199
|
+
|
|
200
|
+
def predict(self, X):
|
|
201
|
+
# Hard prediction. TODO: create sampling function, MAP prediction function
|
|
202
|
+
return self.predictSoft(X) > 0.5;
|
|
203
|
+
|
|
204
|
+
def predictSoft(self, X):
|
|
205
|
+
"""Make 'soft' (per-class confidence) predictions of the rbm on data X.
|
|
206
|
+
|
|
207
|
+
Args:
|
|
208
|
+
X : MxN numpy array containing M data points with N features each
|
|
209
|
+
|
|
210
|
+
Returns:
|
|
211
|
+
P : MxC numpy array of C class probabilities for each of the M data
|
|
212
|
+
"""
|
|
213
|
+
Y = np.zeros((X.shape[0],self.Wvx.shape[0]));
|
|
214
|
+
for j in range(X.shape[0]):
|
|
215
|
+
bxh = self.bh + self.Whx.dot(X[j,:].T)
|
|
216
|
+
bxv = self.bv + self.Wvx.dot(X[j,:].T)
|
|
217
|
+
mu = self.marg_h(Y[j,:],bxh)
|
|
218
|
+
G,tv,th = self.marg_bp(5, bxv, bxh)
|
|
219
|
+
Y[j,:] = tv;
|
|
220
|
+
return Y
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
# TODO: add momentum for learning update
|
|
224
|
+
def train(self, X, Y, Xv=None,Yv=None, stepsize=.01, stopGap=0.1, stopEpoch=10):
|
|
225
|
+
"""Train the (c)RBM
|
|
226
|
+
|
|
227
|
+
Args:
|
|
228
|
+
X : MxNx numpy array containing M data points with N features each
|
|
229
|
+
Y : MxNv numpy array of targets (visible units) for each data point in X
|
|
230
|
+
stepsize : scalar
|
|
231
|
+
The stepsize for gradient descent (decreases as 1 / iter).
|
|
232
|
+
stopTol : scalar
|
|
233
|
+
Tolerance for stopping criterion.
|
|
234
|
+
stopIter : int
|
|
235
|
+
The maximum number of steps before stopping.
|
|
236
|
+
activation : str
|
|
237
|
+
'logistic', 'htangent', or 'custom'. Sets the activation functions.
|
|
238
|
+
|
|
239
|
+
"""
|
|
240
|
+
# TODO: Shape & argument checking
|
|
241
|
+
|
|
242
|
+
# outer loop of (mini-batch) stochastic gradient descent
|
|
243
|
+
it, j = 1, 0 # iteration number & data index
|
|
244
|
+
nextPrint = 1 # next time to print info
|
|
245
|
+
done = 0 # end of loop flag
|
|
246
|
+
nBatch = 40
|
|
247
|
+
|
|
248
|
+
while not done:
|
|
249
|
+
step_i = 3.0*stepsize / (2.0+it) # step size evolution; classic 1/t decrease
|
|
250
|
+
|
|
251
|
+
dWvh, dWvx, dWhx, dbv, dbh = 0.0, 0.0, 0.0, 0.0, 0.0
|
|
252
|
+
# stochastic gradient update (one pass)
|
|
253
|
+
for jj in range(nBatch):
|
|
254
|
+
#print('j={}; jj={};'.format(j,jj));
|
|
255
|
+
j += 1
|
|
256
|
+
if j >= Y.shape[0]: j=0; it+=1;
|
|
257
|
+
# compute conditional model & required probabilities
|
|
258
|
+
bxh = self.bh + self.Whx.dot(X[j,:].T)
|
|
259
|
+
bxv = self.bv + self.Wvx.dot(X[j,:].T)
|
|
260
|
+
mu = self.marg_h(Y[j,:],bxh)
|
|
261
|
+
G,tv,th = self.marg_cd( 1, Y[j,:], bxv, bxh, 1)
|
|
262
|
+
#G,tv,th = self.marg_bp( min(4+it,50), bxv, bxh )
|
|
263
|
+
if (jj==1): #(np.random.rand() < .1):
|
|
264
|
+
plt.figure(1);
|
|
265
|
+
plt.subplot(221); plt.imshow(X[j,:].reshape(28,28)); plt.title('Observed X'); plt.draw();
|
|
266
|
+
plt.subplot(222); plt.imshow(tv.reshape(28,28)); plt.title('Model Prob'); plt.draw();
|
|
267
|
+
plt.subplot(223); plt.imshow(Y[j,:].reshape(28,28)); plt.title('Visible Y'); plt.draw();
|
|
268
|
+
plt.pause(.01);
|
|
269
|
+
# take gradient step:
|
|
270
|
+
dWvh += (np.outer(Y[j,:], mu) - G)
|
|
271
|
+
dWvx += (np.outer(Y[j,:], X[j,:]) - np.outer(tv,X[j,:]))
|
|
272
|
+
dWhx += (np.outer(mu, X[j,:]) - np.outer(th,X[j,:]))
|
|
273
|
+
dbv += (Y[j,:] - tv)
|
|
274
|
+
dbh += (mu - th)
|
|
275
|
+
|
|
276
|
+
self.Wvh += step_i * dWvh / nBatch
|
|
277
|
+
self.Wvx += step_i * dWvx / nBatch
|
|
278
|
+
self.Whx += step_i * dWhx / nBatch
|
|
279
|
+
self.bv += step_i * dbv / nBatch
|
|
280
|
+
self.bh += step_i * dbh / nBatch
|
|
281
|
+
|
|
282
|
+
print('it {} : Gap = {}'.format(it,self.nll_gap(X,Y,Xv,Yv)));
|
|
283
|
+
print(' {} {} {} {} {}'.format(np.mean(self.Wvx**2),np.mean(self.Whx**2),np.mean(self.Wvh**2),np.mean(self.bv**2),np.mean(self.bh**2)));
|
|
284
|
+
|
|
285
|
+
Jtr,Jva = 0,0 #self.nll(X,Y),self.nll(Xv,Yv);
|
|
286
|
+
if it >= nextPrint:
|
|
287
|
+
print('it {} : Gap = {}'.format(it,self.nll_gap(X,Y,Xv,Yv)));
|
|
288
|
+
print(' {} {} {} {} {}'.format(np.mean(self.Wvx**2),np.mean(self.Whx**2),np.mean(self.Wvh**2),np.mean(self.bv**2),np.mean(self.bh**2)));
|
|
289
|
+
#print('it {} : Jtr = {} / Jva = {}'.format(it,Jtr,Jva))
|
|
290
|
+
nextPrint += 1; #*= 2
|
|
291
|
+
|
|
292
|
+
# check if finished
|
|
293
|
+
done = (it > 1) and ((Jva - Jtr) > stopGap) or it >= stopEpoch
|
|
294
|
+
#it += 1 # counting epochs elsewhere now
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
#def err_k(self, X, Y):
|
|
300
|
+
# """Compute misclassification error rate. Assumes Y in 1-of-k form. """
|
|
301
|
+
# return self.err(X, from1ofK(Y,self.classes).ravel())
|
|
302
|
+
#
|
|
303
|
+
#
|
|
304
|
+
#def mse(self, X, Y):
|
|
305
|
+
# """Compute mean squared error of predictor 'obj' on test data (X,Y). """
|
|
306
|
+
# return mse_k(X, to1ofK(Y))
|
|
307
|
+
#
|
|
308
|
+
#
|
|
309
|
+
#def mse_k(self, X, Y):
|
|
310
|
+
# """Compute mean squared error of predictor; assumes Y is in 1-of-k format. """
|
|
311
|
+
# return np.power(Y - self.predictSoft(X), 2).sum(1).mean(0)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
## MUTATORS ####################################################################
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
################################################################################
|
|
319
|
+
################################################################################
|
|
320
|
+
################################################################################
|
|
321
|
+
|