Skip to content
Open
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
5a5f8a9
trying ab_divergence
jb-regli Sep 16, 2017
3d65841
adding renyi as special case
jb-regli Sep 16, 2017
523210a
trying ab_divergence
jb-regli Sep 16, 2017
4682fa7
trying ab_divergence
jb-regli Sep 16, 2017
0e30c18
sign error ?
jb-regli Sep 16, 2017
cc63477
ignore data + add renyi divergence
jb-regli Sep 17, 2017
c79dbb2
cleaning
jb-regli Sep 17, 2017
397eb71
docstring
jb-regli Sep 17, 2017
005dd03
renyi divergence
jb-regli Sep 20, 2017
a38fa82
renyi examples in notebook
jb-regli Sep 20, 2017
102453c
renyi examples in notebook
jb-regli Sep 20, 2017
8fdbe52
branch
jb-regli Sep 21, 2017
f85d97d
hard reset
jb-regli Sep 21, 2017
b146080
hard reset
jb-regli Sep 21, 2017
d618579
renyi divergence
jb-regli Sep 27, 2017
9f9a889
renyi divergence improvement
jb-regli Sep 27, 2017
66b8e87
Renyi divergence improvement
jb-regli Sep 27, 2017
17bdb8b
Error
jb-regli Sep 27, 2017
377ff9c
Moved build_loss and gradient in the class
jb-regli Sep 27, 2017
4c67eed
Renyi exqmple + docstring
jb-regli Sep 27, 2017
5aa9a25
Merge branch 'master' of https://github.com/blei-lab/edward
jb-regli Sep 27, 2017
d4f98b0
Merge remote-tracking branch 'origin/master' into renyi_divergence
jb-regli Sep 27, 2017
c340e46
testing
jb-regli Sep 27, 2017
18fd32f
test
jb-regli Sep 27, 2017
8623ebc
Pep8 correction
jb-regli Sep 27, 2017
57c5ba0
remove irrelevant file from PR
jb-regli Sep 27, 2017
6fc9b8a
2-space indent
jb-regli Sep 27, 2017
0df0215
Markdown formated docstring
jb-regli Sep 27, 2017
671541b
Edited docstring
jb-regli Sep 27, 2017
9e0a3b7
Correct order of call
jb-regli Sep 27, 2017
821d102
Updated docstrings
jb-regli Sep 27, 2017
6de5523
Testing for Renyi VI
jb-regli Sep 28, 2017
9d58f4f
Add Renyi_div to shortcut
jb-regli Sep 29, 2017
bdcaa8f
Correct init
jb-regli Sep 29, 2017
dfad744
Allow renyidivergence
jb-regli Sep 29, 2017
e5c4867
Allow quick call
jb-regli Sep 29, 2017
89fe5cd
Debug shortcut
jb-regli Sep 29, 2017
da83c97
restore from edward
jb-regli Sep 29, 2017
a9139ed
Call shortcut for Renyi divergence
jb-regli Sep 29, 2017
717d236
Correct style
jb-regli Sep 29, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,9 @@ docs/*.html
# IDE related
.idea/
.vscode/

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you remove changes that aren't relevant for this PR? This includes changes to .gitignore here as well as deletion of CSVs.

# data:
data/*

# atom
.remote-sync.json
184 changes: 184 additions & 0 deletions edward/inferences/renyi_divergence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import numpy as np

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As convention, we alphabetize the ordering of the import libraries.

import tensorflow as tf

from edward.inferences.variational_inference import VariationalInference
from edward.models import RandomVariable
from edward.util import copy

try:
from edward.models import Normal

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As convention, we use 2-space indent.

from tensorflow.contrib.distributions import kl_divergence
except Exception as e:
raise ImportError(
"{0}. Your TensorFlow version is not supported.".format(e))


class Renyi_divergence(VariationalInference):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As convention, we use CamelCase for class names.

"""Variational inference with the Renyi divergence

$ \text{D}_{R}^{(\alpha)}(q(z)||p(z \mid x))
= \frac{1}{\alpha-1} \log \int q(z)^{\alpha} p(z \mid x)^{1-\alpha} dz $

To perform the optimization, this class uses the techniques from

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Periods at end of sentences. (If you'd look at the generated API for the class, I recommend compiling the website following instructions from docs/.)

Renyi Divergence Variational Inference (Y. Li & al, 2016)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use bibtex for handling references in docstrings. This is handled by adding the appropriate bib entry to docs/tex/bib.bib; make sure it's also written in the right order: we sort bib entries by their year, then alphabetically according to their citekey within each year.

When using references, you can produce (Li et al., 2016) and Li et al. (2016) by writing [@li2016renyi] and @li2016renyirespectively, assuming thatli2016renyi` is the citekey.


# Notes:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstrings are parsed as Markdown and formatted in a somewhat specific way as they appear on the API docs. I recommend following the other classes, where you would denote a subsection as #### Notes and when writing bullet points, do, e.g.,

#### Notes

+ bullet 1
+ bullet 2
  + maybe bulleted list in a bullet

- Renyi divergence does not have any analytic version.
- Renyi divergence does not have any version for non reparametrizable

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does but the gradient estimator in @li2016variational doesn't. I recommend just stating that this inference algorithm is restricted to variational approximations whose random variables all satisfy rv.reparameterization_type == tf.contrib.distributions.FULLY_REPARAMETERIZED.

Also, instead of checking this during build_loss_and_gradients I recommend checking this during the __init__. This sort of check is done statically any graph construction similar to how we check for compatible shapes in all latent variables and data during __init__.

models.
- backward_pass = 'max': (extreme case $\alpha \rightarrow -\infty$)
the algorithm chooses the sample that has the maximum unnormalised
importance weight. This does not minimize the Renyi divergence
anymore.
- backward_pass = 'min': (extreme case $\alpha \rightarrow +\infty$)
the algorithm chooses the sample that has the minimum unnormalised
importance weight. This does not minimize the Renyi divergence
anymore. This mode is not describe in the paper but implemented
in the publicly available implementation of the paper's experiments.
"""
def __init__(self, *args, **kwargs):
super(Renyi_divergence, self).__init__(*args, **kwargs)

def initialize(self,
n_samples=32,
alpha=1.,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As convention, we append all numerics with 0, e.g., 1.0.

backward_pass='full',
*args, **kwargs):
"""Initialize inference algorithm. It initializes hyperparameters
and builds ops for the algorithm's computation graph.

Args:
n_samples: int, optional.
Number of samples from variational model for calculating
stochastic gradients.
alpha: float, optional.
Renyi divergence coefficient.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be useful to specify the domain of the coefficient. E.g., Must be greater than 0. or etc.

backward_pass: str, optional.
Backward pass mode to be used.
Options: 'min', 'max', 'full'
(see Renyi Divergence Variational Inference (Y. Li & al, 2016)
section 4.2)
"""
self.n_samples = n_samples
self.alpha = alpha
self.backward_pass = backward_pass

return super(Renyi_divergence, self).initialize(*args, **kwargs)

def build_loss_and_gradients(self, var_list):
"""Build the Renyi ELBO function.

Its automatic differentiation is a stochastic gradient of

$ \mcalL_{R}^{\alpha}(q; x) =
\frac{1}{1-\alpha} \log \dsE_{q} \left[
\left( \frac{p(x, z)}{q(z)}\right)^{1-\alpha} \right] $

It uses:
1. Monte Carlo approximation of the ELBO (Y. Li & al, 2016)
2. Reparameterization gradients (Kingma & al, 2014)
3. Stochastic approximation of the joint distribution (Y. Li & al, 2016)

# Notes
If the model is not reparameterizable, it returns a
NotImplementedError.
See Renyi Divergence Variational Inference (Y. Li & al, 2016)
for more details.
"""
is_reparameterizable = all([
rv.reparameterization_type ==
tf.contrib.distributions.FULLY_REPARAMETERIZED
for rv in six.itervalues(self.latent_vars)])

if is_reparameterizable:
p_log_prob = [0.0] * self.n_samples
q_log_prob = [0.0] * self.n_samples
base_scope = tf.get_default_graph().unique_name("inference") + '/'
for s in range(self.n_samples):
# Form dictionary in order to replace conditioning on prior or
# observed variable with conditioning on a specific value.
scope = base_scope + tf.get_default_graph().unique_name("sample")
dict_swap = {}
for x, qx in six.iteritems(self.data):
if isinstance(x, RandomVariable):
if isinstance(qx, RandomVariable):
qx_copy = copy(qx, scope=scope)
dict_swap[x] = qx_copy.value()
else:
dict_swap[x] = qx

for z, qz in six.iteritems(self.latent_vars):
# Copy q(z) to obtain new set of posterior samples.
qz_copy = copy(qz, scope=scope)
dict_swap[z] = qz_copy.value()
q_log_prob[s] += tf.reduce_sum(
self.scale.get(z, 1.0) * qz_copy.log_prob(dict_swap[z]))

for z in six.iterkeys(self.latent_vars):
z_copy = copy(z, dict_swap, scope=scope)
p_log_prob[s] += tf.reduce_sum(
self.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

for x in six.iterkeys(self.data):
if isinstance(x, RandomVariable):
x_copy = copy(x, dict_swap, scope=scope)
p_log_prob[s] += tf.reduce_sum(
self.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

logF = [p - q for p, q in zip(p_log_prob, q_log_prob)]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of logF, what about something like log_ratios, which is more Pythonic in snake_case and also more semantically meaningful?


if self.backward_pass == 'max':
logF = tf.stack(logF)
logF = tf.reduce_max(logF, 0)
loss = tf.reduce_mean(logF)
elif self.backward_pass == 'min':
logF = tf.stack(logF)
logF = tf.reduce_min(logF, 0)
loss = tf.reduce_mean(logF)
elif isclose(self.alpha, 1.0, abs_tol=10e-3):
loss = tf.reduce_mean(logF)
else:
logF = tf.stack(logF)
logF = logF * (1 - self.alpha)
logF_max = tf.reduce_max(logF, 0)
logF = tf.log(
tf.maximum(1e-9,
tf.reduce_mean(tf.exp(logF - logF_max), 0)))
logF=(logF + logF_max) / (1 - self.alpha)
loss=tf.reduce_mean(logF)
loss=-loss

if self.logging:
p_log_prob=tf.reduce_mean(p_log_prob)
q_log_prob=tf.reduce_mean(q_log_prob)
tf.summary.scalar("loss/p_log_prob", p_log_prob,
collections=[self._summary_key])
tf.summary.scalar("loss/q_log_prob", q_log_prob,
collections=[self._summary_key])

grads=tf.gradients(loss, var_list)
grads_and_vars=list(zip(grads, var_list))
return loss, grads_and_vars
else:
raise NotImplementedError("Variational Renyi inference only works with reparameterizable models")


#############
### UTILS ###
#############
def isclose(a, b, rel_tol=0.0, abs_tol=1e-3):
r"""
Almost equal

:param a:
:param b:
:param rel_tol:
:param abs_tol:
:return: Bool
"""
return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
91 changes: 91 additions & 0 deletions examples/vae_renyi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#!/usr/bin/env python
"""Variational auto-encoder for MNIST data.

References
----------
http://edwardlib.org/tutorials/decoder
http://edwardlib.org/tutorials/inference-networks
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import edward as ed
from edward.inferences.renyi_divergence import Renyi_divergence
import numpy as np
import os
import tensorflow as tf

from edward.models import Bernoulli, Normal
from edward.util import Progbar
from keras.layers import Dense
from scipy.misc import imsave
from tensorflow.examples.tutorials.mnist import input_data

DATA_DIR = "data/mnist"
IMG_DIR = "img"

if not os.path.exists(DATA_DIR):
os.makedirs(DATA_DIR)
if not os.path.exists(IMG_DIR):
os.makedirs(IMG_DIR)

ed.set_seed(42)

M = 100 # batch size during training
d = 2 # latent dimension
alpha = 0.5 # alpha values for reny divergence
n_samples = 5 # number of samples used to estimate the Renyi ELBO
backward_pass = 'max' # Back propagation style ('min', 'max' or 'full')

# DATA. MNIST batches are fed at training time.
mnist = input_data.read_data_sets(DATA_DIR)

# MODEL
# Define a subgraph of the full model, corresponding to a minibatch of
# size M.
z = Normal(loc=tf.zeros([M, d]), scale=tf.ones([M, d]))
hidden = Dense(256, activation='relu')(z.value())
x = Bernoulli(logits=Dense(28 * 28)(hidden))

# INFERENCE
# Define a subgraph of the variational model, corresponding to a
# minibatch of size M.
x_ph = tf.placeholder(tf.int32, [M, 28 * 28])
hidden = Dense(256, activation='relu')(tf.cast(x_ph, tf.float32))
qz = Normal(loc=Dense(d)(hidden),
scale=Dense(d, activation='softplus')(hidden))

# Bind p(x, z) and q(z | x) to the same TensorFlow placeholder for x.
inference = Renyi_divergence({z: qz}, data={x: x_ph})

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code looks exactly the same as an older version of vae.py but only differs in this line. To keep the VAE versions better synced, could you add a comment suggesting that this is also an alternative in the existing vae.py?

Ideally, we'd like a specific application where ed.RenyiDivergence produces better results by some metric than alternatives. IIRC, the paper had some interesting results for a Bayesian neural net on some specific UCI data sets. That would be great to have and reproduce some of their results.

If you don't have time for this, we can leave it off for now and raise it as a Github issue post-merging this PR.

optimizer = tf.train.RMSPropOptimizer(0.01, epsilon=1.0)
inference.initialize(optimizer=optimizer,
n_samples=n_samples,
alpha=alpha,
backward_pass=backward_pass)
sess = ed.get_session()
tf.global_variables_initializer().run()

n_epoch = 100
n_iter_per_epoch = 1000
for epoch in range(n_epoch):
avg_loss = 0.0

pbar = Progbar(n_iter_per_epoch)
for t in range(1, n_iter_per_epoch + 1):
pbar.update(t)
x_train, _ = mnist.train.next_batch(M)
x_train = np.random.binomial(1, x_train)
info_dict = inference.update(feed_dict={x_ph: x_train})
avg_loss += info_dict['loss']

# Print a lower bound to the average marginal likelihood for an
# image.
avg_loss = avg_loss / n_iter_per_epoch
avg_loss = avg_loss / M
print("log p(x) >= {:0.3f}".format(avg_loss))

# Prior predictive check.
imgs = sess.run(x)
for m in range(M):
imsave(os.path.join(IMG_DIR, '%d.png') % m, imgs[m].reshape(28, 28))
15 changes: 0 additions & 15 deletions notebooks/data/insteval_dept_ranefs_r.csv

This file was deleted.

Loading