Renyi divergence #769
base: master
Changes from 24 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -100,3 +100,9 @@ docs/*.html | |
| # IDE related | ||
| .idea/ | ||
| .vscode/ | ||
|
|
||
| # data: | ||
| data/* | ||
|
|
||
| # atom | ||
| .remote-sync.json | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,184 @@ | ||
| from __future__ import absolute_import | ||
| from __future__ import division | ||
| from __future__ import print_function | ||
|
|
||
| import six | ||
| import numpy as np | ||
|
Member
As convention, we alphabetize the ordering of the import libraries.
||
| import tensorflow as tf | ||
|
|
||
| from edward.inferences.variational_inference import VariationalInference | ||
| from edward.models import RandomVariable | ||
| from edward.util import copy | ||
|
|
||
| try: | ||
| from edward.models import Normal | ||
|
Member
As convention, we use 2-space indent.
||
| from tensorflow.contrib.distributions import kl_divergence | ||
| except Exception as e: | ||
| raise ImportError( | ||
| "{0}. Your TensorFlow version is not supported.".format(e)) | ||
|
|
||
|
|
||
| class Renyi_divergence(VariationalInference): | ||
|
Member
As convention, we use CamelCase for class names.
||
| """Variational inference with the Renyi divergence | ||
|
|
||
| $ \text{D}_{R}^{(\alpha)}(q(z)||p(z \mid x)) | ||
| = \frac{1}{\alpha-1} \log \int q(z)^{\alpha} p(z \mid x)^{1-\alpha} dz $ | ||
|
|
||
| To perform the optimization, this class uses the techniques from | ||
|
Member
Periods at end of sentences. (If you'd look at the generated API for the class, I recommend compiling the website following instructions from
||
| Renyi Divergence Variational Inference (Y. Li & al, 2016) | ||
|
Member
We use bibtex for handling references in docstrings. This is handled by adding the appropriate bib entry to When using references, you can produce
||
|
|
||
| # Notes: | ||
|
Member
Docstrings are parsed as Markdown and formatted in a somewhat specific way as they appear on the API docs. I recommend following the other classes, where you would denote a subsection as
||
| - Renyi divergence does not have any analytic version. | ||
| - Renyi divergence does not have any version for non reparametrizable | ||
|
Member
It does but the gradient estimator in Also, instead of checking this during
||
| models. | ||
| - backward_pass = 'max': (extreme case $\alpha \rightarrow -\infty$) | ||
| the algorithm chooses the sample that has the maximum unnormalised | ||
| importance weight. This does not minimize the Renyi divergence | ||
| anymore. | ||
| - backward_pass = 'min': (extreme case $\alpha \rightarrow +\infty$) | ||
| the algorithm chooses the sample that has the minimum unnormalised | ||
| importance weight. This does not minimize the Renyi divergence | ||
| anymore. This mode is not described in the paper but is implemented | ||
| in the publicly available implementation of the paper's experiments. | ||
| """ | ||
| def __init__(self, *args, **kwargs): | ||
| super(Renyi_divergence, self).__init__(*args, **kwargs) | ||
|
|
||
| def initialize(self, | ||
| n_samples=32, | ||
| alpha=1., | ||
|
Member
As convention, we append all numerics with 0, e.g.,
||
| backward_pass='full', | ||
| *args, **kwargs): | ||
| """Initialize inference algorithm. It initializes hyperparameters | ||
| and builds ops for the algorithm's computation graph. | ||
|
|
||
| Args: | ||
| n_samples: int, optional. | ||
| Number of samples from variational model for calculating | ||
| stochastic gradients. | ||
| alpha: float, optional. | ||
| Renyi divergence coefficient. | ||
|
Member
Could be useful to specify the domain of the coefficient. E.g.,
||
| backward_pass: str, optional. | ||
| Backward pass mode to be used. | ||
| Options: 'min', 'max', 'full' | ||
| (see Renyi Divergence Variational Inference (Y. Li & al, 2016) | ||
| section 4.2) | ||
| """ | ||
| self.n_samples = n_samples | ||
| self.alpha = alpha | ||
| self.backward_pass = backward_pass | ||
|
|
||
| return super(Renyi_divergence, self).initialize(*args, **kwargs) | ||
|
|
||
| def build_loss_and_gradients(self, var_list): | ||
| """Build the Renyi ELBO function. | ||
|
|
||
| Its automatic differentiation is a stochastic gradient of | ||
|
|
||
| $ \mcalL_{R}^{\alpha}(q; x) = | ||
| \frac{1}{1-\alpha} \log \dsE_{q} \left[ | ||
| \left( \frac{p(x, z)}{q(z)}\right)^{1-\alpha} \right] $ | ||
|
|
||
| It uses: | ||
| 1. Monte Carlo approximation of the ELBO (Y. Li & al, 2016) | ||
| 2. Reparameterization gradients (Kingma & al, 2014) | ||
| 3. Stochastic approximation of the joint distribution (Y. Li & al, 2016) | ||
|
|
||
| # Notes | ||
| If the model is not reparameterizable, it raises a | ||
| NotImplementedError. | ||
| See Renyi Divergence Variational Inference (Y. Li & al, 2016) | ||
| for more details. | ||
| """ | ||
| is_reparameterizable = all([ | ||
| rv.reparameterization_type == | ||
| tf.contrib.distributions.FULLY_REPARAMETERIZED | ||
| for rv in six.itervalues(self.latent_vars)]) | ||
|
|
||
| if is_reparameterizable: | ||
| p_log_prob = [0.0] * self.n_samples | ||
| q_log_prob = [0.0] * self.n_samples | ||
| base_scope = tf.get_default_graph().unique_name("inference") + '/' | ||
| for s in range(self.n_samples): | ||
| # Form dictionary in order to replace conditioning on prior or | ||
| # observed variable with conditioning on a specific value. | ||
| scope = base_scope + tf.get_default_graph().unique_name("sample") | ||
| dict_swap = {} | ||
| for x, qx in six.iteritems(self.data): | ||
| if isinstance(x, RandomVariable): | ||
| if isinstance(qx, RandomVariable): | ||
| qx_copy = copy(qx, scope=scope) | ||
| dict_swap[x] = qx_copy.value() | ||
| else: | ||
| dict_swap[x] = qx | ||
|
|
||
| for z, qz in six.iteritems(self.latent_vars): | ||
| # Copy q(z) to obtain new set of posterior samples. | ||
| qz_copy = copy(qz, scope=scope) | ||
| dict_swap[z] = qz_copy.value() | ||
| q_log_prob[s] += tf.reduce_sum( | ||
| self.scale.get(z, 1.0) * qz_copy.log_prob(dict_swap[z])) | ||
|
|
||
| for z in six.iterkeys(self.latent_vars): | ||
| z_copy = copy(z, dict_swap, scope=scope) | ||
| p_log_prob[s] += tf.reduce_sum( | ||
| self.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z])) | ||
|
|
||
| for x in six.iterkeys(self.data): | ||
| if isinstance(x, RandomVariable): | ||
| x_copy = copy(x, dict_swap, scope=scope) | ||
| p_log_prob[s] += tf.reduce_sum( | ||
| self.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x])) | ||
|
|
||
| logF = [p - q for p, q in zip(p_log_prob, q_log_prob)] | ||
|
Member
Instead of
||
|
|
||
| if self.backward_pass == 'max': | ||
| logF = tf.stack(logF) | ||
| logF = tf.reduce_max(logF, 0) | ||
| loss = tf.reduce_mean(logF) | ||
| elif self.backward_pass == 'min': | ||
| logF = tf.stack(logF) | ||
| logF = tf.reduce_min(logF, 0) | ||
| loss = tf.reduce_mean(logF) | ||
| elif isclose(self.alpha, 1.0, abs_tol=10e-3): | ||
| loss = tf.reduce_mean(logF) | ||
| else: | ||
| logF = tf.stack(logF) | ||
| logF = logF * (1 - self.alpha) | ||
| logF_max = tf.reduce_max(logF, 0) | ||
| logF = tf.log( | ||
| tf.maximum(1e-9, | ||
| tf.reduce_mean(tf.exp(logF - logF_max), 0))) | ||
| logF = (logF + logF_max) / (1 - self.alpha) | ||
| loss = tf.reduce_mean(logF) | ||
| loss = -loss | ||
|
|
||
| if self.logging: | ||
| p_log_prob = tf.reduce_mean(p_log_prob) | ||
| q_log_prob = tf.reduce_mean(q_log_prob) | ||
| tf.summary.scalar("loss/p_log_prob", p_log_prob, | ||
| collections=[self._summary_key]) | ||
| tf.summary.scalar("loss/q_log_prob", q_log_prob, | ||
| collections=[self._summary_key]) | ||
|
|
||
| grads = tf.gradients(loss, var_list) | ||
| grads_and_vars = list(zip(grads, var_list)) | ||
| return loss, grads_and_vars | ||
| else: | ||
| raise NotImplementedError("Variational Renyi inference only works with reparameterizable models") | ||
|
|
||
|
|
||
| ############# | ||
| ### UTILS ### | ||
| ############# | ||
| def isclose(a, b, rel_tol=0.0, abs_tol=1e-3): | ||
| r""" | ||
| Almost equal | ||
|
|
||
| :param a: | ||
| :param b: | ||
| :param rel_tol: | ||
| :param abs_tol: | ||
| :return: Bool | ||
| """ | ||
| return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) | ||
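The "full" branch of `build_loss_and_gradients` above is a log-mean-exp over the scaled log importance weights. The following is a minimal pure-Python sketch of that computation, not the PR's TensorFlow code: the function name `renyi_bound` and its list-based interface are assumptions made for illustration. It also assumes Python 3.5+, where the standard library's `math.isclose` applies the same formula as the `isclose` helper in this file.

```python
import math


def renyi_bound(log_weights, alpha=1.0, backward_pass="full"):
    """Estimate the Renyi bound from log weights log p(x, z_k) - log q(z_k).

    Hypothetical helper for illustration; it mirrors the branches of the
    PR's loss but works on a plain list of floats.
    """
    if not log_weights:
        raise ValueError("need at least one sample")
    if backward_pass == "max":
        return max(log_weights)
    if backward_pass == "min":
        return min(log_weights)
    if backward_pass != "full":
        raise ValueError("unknown backward_pass: %r" % (backward_pass,))
    # alpha close to 1 falls back to the mean log weight (the ELBO estimate).
    if math.isclose(alpha, 1.0, rel_tol=0.0, abs_tol=1e-3):
        return sum(log_weights) / len(log_weights)
    # Log-mean-exp, shifted by the maximum so that exp() cannot overflow
    # and at least one term equals exp(0) = 1.
    scaled = [(1.0 - alpha) * lw for lw in log_weights]
    shift = max(scaled)
    mean_exp = sum(math.exp(s - shift) for s in scaled) / len(scaled)
    return (shift + math.log(mean_exp)) / (1.0 - alpha)


# exp(-1000.0) underflows to 0.0 in double precision, so a naive
# log(mean(exp(...))) would fail here; the shifted version stays finite.
log_w = [-1000.0, -1001.0, -1002.5]
print(renyi_bound(log_w, alpha=0.0))
```

Because the largest shifted term is always 1, the mean inside the logarithm is at least 1/K for K samples, so this sketch needs no lower clamp such as the `tf.maximum(1e-9, ...)` used above.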
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| #!/usr/bin/env python | ||
| """Variational auto-encoder for MNIST data. | ||
|
|
||
| References | ||
| ---------- | ||
| http://edwardlib.org/tutorials/decoder | ||
| http://edwardlib.org/tutorials/inference-networks | ||
| """ | ||
| from __future__ import absolute_import | ||
| from __future__ import division | ||
| from __future__ import print_function | ||
|
|
||
| import edward as ed | ||
| from edward.inferences.renyi_divergence import Renyi_divergence | ||
| import numpy as np | ||
| import os | ||
| import tensorflow as tf | ||
|
|
||
| from edward.models import Bernoulli, Normal | ||
| from edward.util import Progbar | ||
| from keras.layers import Dense | ||
| from scipy.misc import imsave | ||
| from tensorflow.examples.tutorials.mnist import input_data | ||
|
|
||
| DATA_DIR = "data/mnist" | ||
| IMG_DIR = "img" | ||
|
|
||
| if not os.path.exists(DATA_DIR): | ||
| os.makedirs(DATA_DIR) | ||
| if not os.path.exists(IMG_DIR): | ||
| os.makedirs(IMG_DIR) | ||
|
|
||
| ed.set_seed(42) | ||
|
|
||
| M = 100 # batch size during training | ||
| d = 2 # latent dimension | ||
| alpha = 0.5 # alpha value for the Renyi divergence | ||
| n_samples = 5 # number of samples used to estimate the Renyi ELBO | ||
| backward_pass = 'max' # Back propagation style ('min', 'max' or 'full') | ||
|
|
||
| # DATA. MNIST batches are fed at training time. | ||
| mnist = input_data.read_data_sets(DATA_DIR) | ||
|
|
||
| # MODEL | ||
| # Define a subgraph of the full model, corresponding to a minibatch of | ||
| # size M. | ||
| z = Normal(loc=tf.zeros([M, d]), scale=tf.ones([M, d])) | ||
| hidden = Dense(256, activation='relu')(z.value()) | ||
| x = Bernoulli(logits=Dense(28 * 28)(hidden)) | ||
|
|
||
| # INFERENCE | ||
| # Define a subgraph of the variational model, corresponding to a | ||
| # minibatch of size M. | ||
| x_ph = tf.placeholder(tf.int32, [M, 28 * 28]) | ||
| hidden = Dense(256, activation='relu')(tf.cast(x_ph, tf.float32)) | ||
| qz = Normal(loc=Dense(d)(hidden), | ||
| scale=Dense(d, activation='softplus')(hidden)) | ||
|
|
||
| # Bind p(x, z) and q(z | x) to the same TensorFlow placeholder for x. | ||
| inference = Renyi_divergence({z: qz}, data={x: x_ph}) | ||
|
Member
This code looks exactly the same as an older version of Ideally, we'd like a specific application where If you don't have time for this, we can leave it off for now and raise it as a Github issue post-merging this PR.
||
| optimizer = tf.train.RMSPropOptimizer(0.01, epsilon=1.0) | ||
| inference.initialize(optimizer=optimizer, | ||
| n_samples=n_samples, | ||
| alpha=alpha, | ||
| backward_pass=backward_pass) | ||
| sess = ed.get_session() | ||
| tf.global_variables_initializer().run() | ||
|
|
||
| n_epoch = 100 | ||
| n_iter_per_epoch = 1000 | ||
| for epoch in range(n_epoch): | ||
| avg_loss = 0.0 | ||
|
|
||
| pbar = Progbar(n_iter_per_epoch) | ||
| for t in range(1, n_iter_per_epoch + 1): | ||
| pbar.update(t) | ||
| x_train, _ = mnist.train.next_batch(M) | ||
| x_train = np.random.binomial(1, x_train) | ||
| info_dict = inference.update(feed_dict={x_ph: x_train}) | ||
| avg_loss += info_dict['loss'] | ||
|
|
||
| # Print the average loss per image. The loss is the negated Renyi | ||
| # bound estimate, so lower is better. | ||
| avg_loss = avg_loss / n_iter_per_epoch | ||
| avg_loss = avg_loss / M | ||
| print("negative Renyi bound per image: {:0.3f}".format(avg_loss)) | ||
|
|
||
| # Prior predictive check. | ||
| imgs = sess.run(x) | ||
| for m in range(M): | ||
| imsave(os.path.join(IMG_DIR, '%d.png') % m, imgs[m].reshape(28, 28)) | ||
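As a rough illustration of what the `alpha` setting in the example above controls, the sketch below evaluates the Monte Carlo Renyi bound for several values of `alpha` on a toy conjugate Gaussian model where log p(x) is known in closed form. Everything here is an assumption made for illustration (the toy model, the values of `q_mean` and `q_std`, and the helper names); it does not use Edward. For a fixed set of samples the estimate is non-increasing in `alpha` by the power-mean inequality: `alpha = 1` gives the usual ELBO estimate, and `alpha = 0` gives the log of the average importance weight.

```python
import math
import random


def log_normal_pdf(x, mean, std):
    """Log density of a univariate normal distribution."""
    return (-0.5 * math.log(2.0 * math.pi) - math.log(std)
            - 0.5 * ((x - mean) / std) ** 2)


def vr_estimate(log_w, alpha):
    """Renyi bound estimate from log weights; alpha == 1 uses the mean."""
    if alpha == 1.0:
        return sum(log_w) / len(log_w)
    scaled = [(1.0 - alpha) * lw for lw in log_w]
    shift = max(scaled)
    mean_exp = sum(math.exp(s - shift) for s in scaled) / len(scaled)
    return (shift + math.log(mean_exp)) / (1.0 - alpha)


random.seed(42)
x_obs = 1.0
# Toy model: z ~ N(0, 1), x | z ~ N(z, 1), hence x ~ N(0, 2).
# q(z) = N(q_mean, q_std) is a deliberately imperfect approximation.
q_mean, q_std = 0.3, 1.2
samples = [random.gauss(q_mean, q_std) for _ in range(2000)]
log_w = [log_normal_pdf(z, 0.0, 1.0) + log_normal_pdf(x_obs, z, 1.0)
         - log_normal_pdf(z, q_mean, q_std) for z in samples]

exact = log_normal_pdf(x_obs, 0.0, math.sqrt(2.0))
alphas = [-2.0, -0.5, 0.0, 0.5, 1.0, 2.0]
estimates = [vr_estimate(log_w, a) for a in alphas]
for a, e in zip(alphas, estimates):
    print("alpha = %5.1f  estimate = %.4f" % (a, e))
print("exact log p(x) = %.4f" % exact)
```

A comparison of this kind, on a model where the exact marginal likelihood is available, is one possible way to show how the choice of `alpha` moves the estimate relative to log p(x).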
This file was deleted.
Can you remove changes that aren't relevant for this PR? This includes changes to .gitignore here as well as deletion of CSVs.