diff --git a/docs/tex/bib.bib b/docs/tex/bib.bib index fb1284f1c..b075f49a1 100644 --- a/docs/tex/bib.bib +++ b/docs/tex/bib.bib @@ -712,6 +712,14 @@ @article{johnson2016composing year = {2016}, } +@article{li2016renyi, + title={R{\'e}nyi divergence variational inference}, + author={Li, Yingzhen and Turner, Richard E}, + booktitle={Advances in Neural Information Processing Systems}, + pages={1073--1081}, + year={2016} +} + @article{mohamed2016learning, author = {Mohamed, Shakir and Lakshminarayanan, Balaji}, title = {{Learning in Implicit Generative Models}}, @@ -801,4 +809,3 @@ @inproceedings{tran2017deep booktitle = {International Conference on Learning Representations}, year = {2017} } - diff --git a/edward/__init__.py b/edward/__init__.py index ce3305795..ba8e6d746 100644 --- a/edward/__init__.py +++ b/edward/__init__.py @@ -15,7 +15,7 @@ KLpq, KLqp, ReparameterizationKLqp, ReparameterizationKLKLqp, \ ReparameterizationEntropyKLqp, ScoreKLqp, ScoreKLKLqp, ScoreEntropyKLqp, \ ScoreRBKLqp, WakeSleep, GANInference, BiGANInference, WGANInference, \ - ImplicitKLqp, MAP, Laplace, complete_conditional, Gibbs + ImplicitKLqp, MAP, Laplace, complete_conditional, Gibbs, RenyiDivergence from edward.models import RandomVariable from edward.util import check_data, check_latent_vars, copy, dot, \ get_ancestors, get_blanket, get_children, get_control_variate_coef, \ @@ -56,6 +56,7 @@ 'BiGANInference', 'WGANInference', 'ImplicitKLqp', + 'RenyiDivergence', 'MAP', 'Laplace', 'complete_conditional', diff --git a/edward/inferences/__init__.py b/edward/inferences/__init__.py index 38262fcb7..2aed3fd1f 100644 --- a/edward/inferences/__init__.py +++ b/edward/inferences/__init__.py @@ -22,6 +22,7 @@ from edward.inferences.variational_inference import * from edward.inferences.wake_sleep import * from edward.inferences.wgan_inference import * +from edward.inferences.renyi_divergence import * from tensorflow.python.util.all_util import remove_undocumented @@ -51,6 +52,7 @@ 'VariationalInference', 'WakeSleep', 'WGANInference', + 'RenyiDivergence', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/edward/inferences/renyi_divergence.py b/edward/inferences/renyi_divergence.py new file mode 100644 index 000000000..447e457f0 --- /dev/null +++ b/edward/inferences/renyi_divergence.py @@ -0,0 +1,174 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import six +import tensorflow as tf + +from edward.inferences.variational_inference import VariationalInference +from edward.models import RandomVariable +from edward.util import copy + +try: + from edward.models import Normal + from tensorflow.contrib.distributions import kl_divergence +except Exception as e: + raise ImportError("{0}. Your TensorFlow version is not supported.".format(e)) + + +class RenyiDivergence(VariationalInference): + """Variational inference with the Renyi divergence [@li2016renyi]. + + It minimizes the Renyi divergence + + $ \text{D}_{R}^{(\alpha)}(q(z)||p(z \mid x)) + = \frac{1}{\alpha-1} \log \int q(z)^{\alpha} p(z \mid x)^{1-\alpha} dz.$ + + The optimization is performed using the gradient estimator as defined in + @li2016renyi. + + #### Notes + + The gradient estimator used here does not have any analytic version. + + The gradient estimator used here does not have any version for non + reparametrizable models. + + backward_pass = 'max': (extreme case $\alpha \rightarrow -\infty$) + the algorithm chooses the sample that has the maximum unnormalised + importance weight. This does not minimize the Renyi divergence + anymore. + + backward_pass = 'min': (extreme case $\alpha \rightarrow +\infty$) + the algorithm chooses the sample that has the minimum unnormalised + importance weight. This does not minimize the Renyi divergence + anymore. This mode is not describe in the paper but implemented + in the publicly available implementation of the paper's experiments. + """ + + def __init__(self, *args, **kwargs): + + super(RenyiDivergence, self).__init__(*args, **kwargs) + + is_reparameterizable = all([ + rv.reparameterization_type == + tf.contrib.distributions.FULLY_REPARAMETERIZED + for rv in six.itervalues(self.latent_vars)]) + + if not is_reparameterizable: + raise NotImplementedError( + "Variational Renyi inference only works with reparameterizable" + " models.") + + def initialize(self, + n_samples=32, + alpha=1.0, + backward_pass='full', + *args, **kwargs): + """Initialize inference algorithm. It initializes hyperparameters + and builds ops for the algorithm's computation graph. + + Args: + n_samples: int, optional. + Number of samples from variational model for calculating + stochastic gradients. + alpha: float, optional. + Renyi divergence coefficient. $\alpha \in \mathbb{R}$. + When $\alpha < 0$, the algorithm still does something sensible but + does not minimize the Renyi divergence anymore. + (see [@li2016renyi] - section 4.2) + backward_pass: str, optional. + Backward pass mode to be used. + Options: 'min', 'max', 'full' + (see [@li2016renyi] - section 4.2) + """ + self.n_samples = n_samples + self.alpha = alpha + self.backward_pass = backward_pass + + return super(RenyiDivergence, self).initialize(*args, **kwargs) + + def build_loss_and_gradients(self, var_list): + """Build the Renyi ELBO function. + + Its automatic differentiation is a stochastic gradient of + + $ \mcalL_{R}^{\alpha}(q; x) = + \frac{1}{1-\alpha} \log \dsE_{q} \left[ + \left( \frac{p(x, z)}{q(z)}\right)^{1-\alpha} \right].$ + + It uses: + + + Monte Carlo approximation of the ELBO [@li2016renyi]. + + Reparameterization gradients [@kingma2014auto]. + + Stochastic approximation of the joint distribution [@li2016renyi]. + + #### Notes + + + If the model is not reparameterizable, it returns a + NotImplementedError. + + See Renyi Divergence Variational Inference [@li2016renyi] for + more details. + """ + p_log_prob = [0.0] * self.n_samples + q_log_prob = [0.0] * self.n_samples + base_scope = tf.get_default_graph().unique_name("inference") + '/' + for s in range(self.n_samples): + # Form dictionary in order to replace conditioning on prior or + # observed variable with conditioning on a specific value. + scope = base_scope \ + + tf.get_default_graph().unique_name("sample") + dict_swap = {} + for x, qx in six.iteritems(self.data): + if isinstance(x, RandomVariable): + if isinstance(qx, RandomVariable): + qx_copy = copy(qx, scope=scope) + dict_swap[x] = qx_copy.value() + else: + dict_swap[x] = qx + + for z, qz in six.iteritems(self.latent_vars): + # Copy q(z) to obtain new set of posterior samples. + qz_copy = copy(qz, scope=scope) + dict_swap[z] = qz_copy.value() + q_log_prob[s] += tf.reduce_sum( + self.scale.get(z, 1.0) * qz_copy.log_prob(dict_swap[z])) + + for z in six.iterkeys(self.latent_vars): + z_copy = copy(z, dict_swap, scope=scope) + p_log_prob[s] += tf.reduce_sum( + self.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z])) + + for x in six.iterkeys(self.data): + if isinstance(x, RandomVariable): + x_copy = copy(x, dict_swap, scope=scope) + p_log_prob[s] += tf.reduce_sum( + self.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x])) + + log_ratios = [p - q for p, q in zip(p_log_prob, q_log_prob)] + + if self.backward_pass == 'max': + loss = tf.reduce_max(log_ratios, 0) + elif self.backward_pass == 'min': + loss = tf.reduce_min(log_ratios, 0) + elif np.abs(self.alpha - 1.0) < 10e-3: + loss = tf.reduce_mean(log_ratios) + else: + log_ratios = tf.stack(log_ratios) + log_ratios = log_ratios * (1 - self.alpha) + log_ratios_max = tf.reduce_max(log_ratios, 0) + log_ratios = tf.log( + tf.maximum(1e-9, + tf.reduce_mean(tf.exp(log_ratios - log_ratios_max), 0))) + log_ratios = (log_ratios + log_ratios_max) / (1 - self.alpha) + loss = tf.reduce_mean(log_ratios) + loss = -loss + + if self.logging: + p_log_prob = tf.reduce_mean(p_log_prob) + q_log_prob = tf.reduce_mean(q_log_prob) + tf.summary.scalar("loss/p_log_prob", p_log_prob, + collections=[self._summary_key]) + tf.summary.scalar("loss/q_log_prob", q_log_prob, + collections=[self._summary_key]) + + grads = tf.gradients(loss, var_list) + grads_and_vars = list(zip(grads, var_list)) + return loss, grads_and_vars diff --git a/examples/vae.py b/examples/vae.py index 0550e5d3c..0262fd147 100644 --- a/examples/vae.py +++ b/examples/vae.py @@ -1,5 +1,13 @@ #!/usr/bin/env python -"""Variational auto-encoder for MNIST data. +"""Renyi Variational auto-encoder for MNIST data. + +We here use the Renyi variational objective [@li2016renyi]. +This objective allows to vary the divergence measured used for optimization, by +tuning the meta-parameter $\alpha$. + +#### Notes +The Renyi variational objective reduces down to the classic kl_divergence for +$\alpha=1$ and the "standard" VAE is obtained. References ---------- @@ -18,39 +26,26 @@ from edward.models import Bernoulli, Normal from edward.util import Progbar from keras.layers import Dense -from observations import mnist +from tensorflow.examples.tutorials.mnist import input_data from scipy.misc import imsave - -def generator(array, batch_size): - """Generate batch with respect to array's first axis.""" - start = 0 # pointer to where we are in iteration - while True: - stop = start + batch_size - diff = stop - array.shape[0] - if diff <= 0: - batch = array[start:stop] - start += batch_size - else: - batch = np.concatenate((array[start:], array[:diff])) - start = diff - batch = batch.astype(np.float32) / 255.0 # normalize pixel intensities - batch = np.random.binomial(1, batch) # binarize images - yield batch - - ed.set_seed(42) -data_dir = "/tmp/data" -out_dir = "/tmp/out" -if not os.path.exists(out_dir): - os.makedirs(out_dir) +DATA_DIR = "data/mnist" +IMG_DIR = "img" +if not os.path.exists(DATA_DIR): + os.makedirs(DATA_DIR) +if not os.path.exists(IMG_DIR): + os.makedirs(IMG_DIR) + M = 100 # batch size during training d = 2 # latent dimension +alpha = 0.5 # alpha values for renyi divergence +n_samples = 5 # number of samples used to estimate the Renyi ELBO +backward_pass = 'full' # Back propagation style ('min', 'max' or 'full') # DATA. MNIST batches are fed at training time. -(x_train, _), (x_test, _) = mnist(data_dir) -x_train_generator = generator(x_train, M) +mnist = input_data.read_data_sets(DATA_DIR) # MODEL # Define a subgraph of the full model, corresponding to a minibatch of @@ -68,32 +63,39 @@ def generator(array, batch_size): scale=Dense(d, activation='softplus')(hidden)) # Bind p(x, z) and q(z | x) to the same TensorFlow placeholder for x. -inference = ed.KLqp({z: qz}, data={x: x_ph}) +inference = ed.RenyiDivergence({z: qz}, data={x: x_ph}) optimizer = tf.train.RMSPropOptimizer(0.01, epsilon=1.0) -inference.initialize(optimizer=optimizer) - +inference.initialize(optimizer=optimizer, + n_samples=n_samples, + alpha=alpha, + backward_pass=backward_pass) +# inference = ed.KLqp({z: qz}, data={x: x_ph}) +# optimizer = tf.train.RMSPropOptimizer(0.01, epsilon=1.0) +# inference.initialize(optimizer=optimizer) + +sess = ed.get_session() tf.global_variables_initializer().run() n_epoch = 100 -n_iter_per_epoch = x_train.shape[0] // M -for epoch in range(1, n_epoch + 1): - print("Epoch: {0}".format(epoch)) +n_iter_per_epoch = 1000 +for epoch in range(n_epoch): avg_loss = 0.0 pbar = Progbar(n_iter_per_epoch) for t in range(1, n_iter_per_epoch + 1): pbar.update(t) - x_batch = next(x_train_generator) - info_dict = inference.update(feed_dict={x_ph: x_batch}) + x_train, _ = mnist.train.next_batch(M) + x_train = np.random.binomial(1, x_train) + info_dict = inference.update(feed_dict={x_ph: x_train}) avg_loss += info_dict['loss'] # Print a lower bound to the average marginal likelihood for an # image. avg_loss = avg_loss / n_iter_per_epoch avg_loss = avg_loss / M - print("-log p(x) <= {:0.3f}".format(avg_loss)) + print("log p(x) >= {:0.3f}".format(avg_loss)) # Prior predictive check. - images = x.eval() + imgs = sess.run(x) for m in range(M): - imsave(os.path.join(out_dir, '%d.png') % m, images[m].reshape(28, 28)) + imsave(os.path.join(IMG_DIR, '%d.png') % m, imgs[m].reshape(28, 28)) diff --git a/tests/inferences/test_renyi_divergence.py b/tests/inferences/test_renyi_divergence.py new file mode 100644 index 000000000..7dab7c2d1 --- /dev/null +++ b/tests/inferences/test_renyi_divergence.py @@ -0,0 +1,110 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import edward as ed +import numpy as np +import tensorflow as tf + +from edward.models import Bernoulli, Normal + + +class test_renyi_divergence_class(tf.test.TestCase): + + def _test_normal_normal(self, *args, **kwargs): + with self.test_session() as sess: + x_data = np.array([0.0] * 50, dtype=np.float32) + + mu = Normal(loc=0.0, scale=1.0) + x = Normal(loc=mu, scale=1.0, sample_shape=50) + + qmu_loc = tf.Variable(tf.random_normal([])) + qmu_scale = tf.nn.softplus(tf.Variable(tf.random_normal([]))) + qmu = Normal(loc=qmu_loc, scale=qmu_scale) + + # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140) + inference = ed.RenyiDivergence({mu: qmu}, data={x: x_data}) + inference.run(*args, **kwargs) + + self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1) + self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51), + rtol=1e-1, atol=1e-1) + + variables = tf.get_collection( + tf.GraphKeys.GLOBAL_VARIABLES, scope='optimizer') + old_t, old_variables = sess.run([inference.t, variables]) + self.assertEqual(old_t, inference.n_iter) + sess.run(inference.reset) + new_t, new_variables = sess.run([inference.t, variables]) + self.assertEqual(new_t, 0) + self.assertNotEqual(old_variables, new_variables) + + def _test_model_parameter(self, *args, **kwargs): + with self.test_session() as sess: + x_data = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 1]) + + p = tf.sigmoid(tf.Variable(0.5)) + x = Bernoulli(probs=p, sample_shape=10) + + inference = ed.RenyiDivergence({}, data={x: x_data}) + inference.run(*args, **kwargs) + + self.assertAllClose(p.eval(), 0.2, rtol=5e-2, atol=5e-2) + + def test_renyi_divergence(self): + # normal-normal - special case - KL: + self._test_normal_normal(n_samples=5, + n_iter=200, + alpha=1.0, + backward_pass='full') + # normal-normal - special case - Max: + self._test_normal_normal(n_samples=1, + n_iter=200, + alpha=2.0, + backward_pass='max') + # normal-normal - special case - Min: + self._test_normal_normal(n_samples=5, + n_iter=200, + alpha=2.0, + backward_pass='min') + # normal-normal - normal case - alpha < 0: + self._test_normal_normal(n_samples=1, + n_iter=200, + alpha=-0.5, + backward_pass='full') + # normal-normal - normal case - alpha > 0: + self._test_normal_normal(n_samples=1, + n_iter=200, + alpha=0.5, + backward_pass='full') + + # model parameter - special case - KL: + self._test_model_parameter(n_samples=5, + n_iter=100, + alpha=1.0, + backward_pass='full') + # model parameter - special case - Max: + self._test_model_parameter(n_samples=5, + n_iter=100, + alpha=1.0, + backward_pass='max') + # model parameter - special case - Min: + self._test_model_parameter(n_samples=5, + n_iter=100, + alpha=1.0, + backward_pass='min') + # model parameter - normal case - alpha < 0: + self._test_model_parameter(n_samples=5, + n_iter=100, + alpha=-0.5, + backward_pass='full') + # model parameter - normal case - alpha > 0: + self._test_model_parameter(n_samples=5, + n_iter=100, + alpha=0.5, + backward_pass='full') + + +if __name__ == '__main__': + ed.set_seed(42) + tf.test.main()