Deep learning models are widely used in supervised learning problems that involve learning to approximate the mean of the conditional distribution, \(\hat{y}=f_{\theta}(x)\). In this setting the network is usually trained by minimising a sum-of-squares or cross-entropy error function over a set of training data \(\{\mathbf{x}_{1:N},\mathbf{y}_{1:N}\}\). With this approach it is explicitly assumed that there is a deterministic one-to-one mapping between the input variables \(\mathbf{x}=\{x_1, \ldots, x_d\}\) and the target variables \(\mathbf{y}=\{y_1, \ldots,y_c\}\) without any uncertainty.

As a result, the output of a network trained in this way approximates the conditional mean of the targets in the training data, conditioned on the input vector. For classification problems with a well-chosen target coding scheme, these averages represent the posterior probabilities of class membership and can therefore be regarded as optimal. For problems involving the prediction of continuous variables, especially when the mapping to be learned is multi-valued, the conditional average is usually a poor description of the data and lacks the power to model complex, possibly multimodal, output distributions. One way to solve this problem is to model the complete conditional probability density instead, and this is the approach taken by Mixture Density Networks (MDN).
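To make the multi-valued case concrete, consider a toy dataset in the spirit of Bishop's inverse-problem example (this sketch and its constants are illustrative, not part of the original implementation):

    import math
    import torch

    # sample the target first, then generate the input, so a single x can map to several plausible y
    torch.manual_seed(0)
    y = torch.rand(5000)
    x = y + 0.3 * torch.sin(2 * math.pi * y) + 0.1 * (2 * torch.rand(5000) - 1)

    # fitting y from x with a squared-error loss approximates E[y|x], which can fall
    # between the branches of this multi-valued inverse mapping

For some values of x there are several plausible values of y, which is exactly the situation where modelling the full conditional density pays off.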

Mixture Density Network

An MDN, as proposed by Bishop, is a flexible framework for modelling an arbitrary conditional probability distribution \(p(\mathbf{y}|\mathbf{x})\) as a mixture of distributions such that

\[p(\mathbf{y}|\mathbf{x}) = \sum_{k=1}^M \pi _k \phi(\mathbf{y}|\theta _k)\]

where \(M\) is the number of mixture components, \(\phi\) can be any type of parametric distribution with parameters \(\theta_k\), and \(\pi_k\) is the corresponding component’s weight (mixing coefficient). An MDN combines a mixture model with a deep neural network (DNN), which parameterizes the mixture. The mixing weights \(\pi_k(\mathbf{x})\) represent the relative contribution of each mixture component and can be interpreted as the probability of the \(k\)-th component for a given observation \(\mathbf{x}\). If we introduce a latent variable \(\mathbf{z}\) with \(M\) possible states, then \(\pi_k(\mathbf{x})\) represents the probability distribution over these states, \(p(\mathbf{z})\).

As a result, an MDN can handle multimodality better than a standard discriminative neural network. Its mixing coefficients model the probability of the component from which a data point was sampled, allowing the network to encode uncertainty about its prediction.

Considering Gaussian components, the DNN is used to map a set of input features \(\mathbf{x}_{1:d}\) to the parameters of a GMM, i.e. the mixture weights \(\pi_k(\mathbf{x})\), the means \(\mu _k(\mathbf{x})\) and the variances \(\sigma_k^2(\mathbf{x})\), which in turn give the full probability density function of an output feature \(\mathbf{y}\) conditioned on the input features.

\[p(\mathbf{y}|\mathbf{x})=\sum_{k=1}^M \pi_k(\mathbf{x}) \mathcal{N}(\mathbf{y}; \mu_k(\mathbf{x}), \sigma_k^2(\mathbf{x}))\]

where \(M\) is the number of components in the mixture and

\[\mathcal{N}(\mathbf{y}; \mu_k(\mathbf{x}), \sigma_k^2(\mathbf{x})) = \frac{1}{(2\pi\sigma_k^2(\mathbf{x}))^{c/2}}\exp\left[-\frac{||\mathbf{y}-\mu_k(\mathbf{x})||^2}{2\sigma_k^2(\mathbf{x})}\right]\]
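A direct, minimal translation of this kernel into code can make the tensor shapes explicit; the function name and layout below are illustrative assumptions rather than part of the implementation discussed later:

    import math
    import torch

    def isotropic_gaussian(y, mu, sigma2):
        # y: (batch, c) targets; mu: (batch, M, c) component means; sigma2: (batch, M) variances
        c = y.shape[-1]
        sq_dist = ((y.unsqueeze(1) - mu) ** 2).sum(-1)    # ||y - mu_k||^2, shape (batch, M)
        norm = (2 * math.pi * sigma2) ** (c / 2)          # (2 pi sigma_k^2)^(c/2)
        return torch.exp(-sq_dist / (2 * sigma2)) / norm  # N(y; mu_k, sigma_k^2 I)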

In order to guarantee that \(p(\mathbf{y} \mid \mathbf{x})\) is a probability distribution, the outputs of the network need to be constrained such that the variances remain positive and the mixing coefficients lie between zero and one and sum to one. To achieve these constraints (a combined sketch of the three output heads follows the list):

  1. The means of the \(k\)-th kernel are modelled directly as network outputs: \(\mu_{k}^i(\mathbf{x})=z_{k}^{\mu i} \text{ where } i = 1,\ldots, c\). For stability it is recommended to initialize the bias of this linear layer to plausible centres of the Gaussian components.
    self._mu = nn.Linear(self.z_size, self.n_gauss * self.out_size)  # one mean per component and output dimension
    nn.init.uniform_(self._mu.bias, a=-2.0, b=2.0)                   # spread the initial component centres
    
  2. The variances \(\sigma_k\) are represented by an exponential activation function of the corresponding network output: \(\sigma_k(\mathbf{x}) = \exp(z_k^{\sigma})\). One advantage of the exponential is that it always gives a positive output and never quite reaches zero. However, it grows very quickly on datasets with high variance, which can cause instability. Alternatively, a different activation function such as ELUPLUSONE or SIGMOID can be used. The ELUPLUSONE is implemented as
    import torch.nn.functional as F

    def elu_plus_one(z, eps=1e-15):
        return F.elu(z) + 1 + eps  # shifted ELU output is strictly positive
    

    while the modified sigmoid clamps the log-variance and converts it to a standard deviation:

    import math
    import torch
    import torch.nn.functional as F

    def clip_log_sigmoid(z, min_std=0.01):
        log_var = F.logsigmoid(z)                  # log-sigmoid keeps the log-variance <= 0
        log_var = torch.clamp(log_var, math.log(min_std), -math.log(min_std))
        sigma = torch.exp(0.5 * log_var)           # convert the log-variance to a standard deviation
        return sigma
    
  3. The mixing coefficients \(\pi _k(\mathbf{x}) \in [0,1]\) are probabilities such that \(\sum_k \pi _k(\mathbf{x}) = 1\). They are modelled as a softmax transformation of the corresponding outputs: \(\pi_k = \frac{\exp(z_k^{\pi})}{\sum_{j=1}^M \exp(z_j^{\pi})}\).
     self._pi = nn.Linear(self.z_size, self.n_gauss)   # one logit per mixture component
     pi = torch.softmax(self._pi(z_in), -1)            # normalise logits to probabilities
    

    To avoid mode collapse, the Gumbel softmax can be used as an alternative to the softmax, such that \(\begin{aligned} \pi_k &= \frac{\exp((z_k^{\pi} + g_k)/\tau)}{\sum_{j=1}^M \exp((z_j^{\pi} + g_j)/\tau)}\\ g_k &\sim \mathrm{Gumbel}(0, 1) \end{aligned}\) where \(z_k^{\pi}\) are the network logits and \(\tau\) is a temperature parameter.

     self._pi = nn.Linear(self.z_size, self.n_gauss)
     pi = F.gumbel_softmax(self._pi(z_in), tau=1, hard=False)  # soft samples; hard=True would give one-hot weights
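
Putting the three constraints together, the sketch below collects the output heads into a single module. It is a minimal sketch under assumptions: the class name MDNHead and the use of an ELU-plus-one activation for the standard deviations are illustrative choices, not necessarily the exact implementation behind the snippets above.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MDNHead(nn.Module):
        """Maps a feature vector z_in to GMM parameters (pi, mu, sigma)."""
        def __init__(self, z_size, n_gauss, out_dim):
            super().__init__()
            self.n_gauss, self.out_dim = n_gauss, out_dim
            self._pi = nn.Linear(z_size, n_gauss)
            self._mu = nn.Linear(z_size, n_gauss * out_dim)
            self._sigma = nn.Linear(z_size, n_gauss * out_dim)
            nn.init.uniform_(self._mu.bias, a=-2.0, b=2.0)  # spread the initial component centres

        def forward(self, z_in):
            pi = torch.softmax(self._pi(z_in), -1)                       # mixing coefficients, (batch, n_gauss)
            mu = self._mu(z_in).reshape(-1, self.n_gauss, self.out_dim)  # component means
            sigma = F.elu(self._sigma(z_in)) + 1 + 1e-15                 # strictly positive standard deviations
            sigma = sigma.reshape(-1, self.n_gauss, self.out_dim)
            return pi, mu, sigma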
    

Training MDN

The MDN model can be trained using the backpropagation algorithm under the maximum likelihood criterion. Suppose \(\theta\) is the vector of trainable parameters; we can redefine our model as a function of \(\mathbf{x}\) parameterized by \(\theta\)

\[p(\mathbf{y}|\mathbf{x}, \mathbf{\theta})=\sum_{k=1}^M \pi_k(\mathbf{x}, \mathbf{\theta}) \mathcal{N}(\mathbf{y}; \mu_k(\mathbf{x}, \mathbf{\theta}), \sigma_k^2(\mathbf{x}, \mathbf{\theta}))\]

Considering a dataset \(\mathcal{D}= \{ \mathbf{x}_{1:N},\mathbf{y}_{1:N}\}\), we want to maximize

\[p(\mathbf{\theta}|\mathcal{D}) = p(\mathbf{\theta}|\mathbf{Y},\mathbf{X})\]

By Bayes’s theorem (and assuming the prior over \(\mathbf{\theta}\) does not depend on the inputs, so that \(p(\mathbf{\theta}|\mathbf{X})=p(\mathbf{\theta})\)) this is equivalent to

\[p(\mathbf{\theta}|\mathbf{Y},\mathbf{X})\,p(\mathbf{Y}|\mathbf{X}) = p(\mathbf{Y},\mathbf{\theta} |\mathbf{X}) = p(\mathbf{Y}|\mathbf{X},\mathbf{\theta})\,p(\mathbf{\theta})\]

which leads to

\[p(\mathbf{\theta}|\mathbf{Y},\mathbf{X}) = \frac{p(\mathbf{Y}|\mathbf{X},\mathbf{\theta})p(\mathbf{\theta})}{p(\mathbf{Y}|\mathbf{X})} \propto p(\mathbf{Y}|\mathbf{X},\mathbf{\theta})p(\mathbf{\theta})\]

where

\[p(\mathbf{Y}|\mathbf{X},\mathbf{\theta})=\prod_{n=1}^N p(\mathbf{y}_n|\mathbf{x}_n, \mathbf{\theta})\]

which is simply the product of the conditional densities for each pattern.

To define an error function, the standard approach is the maximum likelihood method, which requires maximisation of the log-likelihood function or, equivalently, minimisation of the negative logarithm of the likelihood. Therefore, the error function for the Mixture Density Network is:

\[\begin{aligned} E(\theta, \mathcal{D})&=-\log p(\mathbf{\theta}|\mathbf{Y},\mathbf{X})= -\log p(\mathbf{Y}|\mathbf{X},\mathbf{\theta})p(\mathbf{\theta}) + \text{const}\\ &= -\left(\log \prod_{n=1}^N p(\mathbf{y}_n|\mathbf{x}_n, \mathbf{\theta}) + \log p(\mathbf{\theta})\right) + \text{const}\\ &=-\left(\sum_{n=1}^N \log \sum_{k=1}^M \pi_k(\mathbf{x}_n) \mathcal{N}(\mathbf{y}_n; \mu_k(\mathbf{x}_n), \sigma_k^2(\mathbf{x}_n)) + \log p(\mathbf{\theta})\right) + \text{const} \end{aligned}\]

where the constant \(\log p(\mathbf{Y}|\mathbf{X})\) does not depend on \(\theta\) and can be ignored during optimisation.

If we assume a non-informative prior \(p(\mathbf{\theta})=1\) and drop the constant, the error function simplifies to

\[E(\theta, \mathcal{D}) = -\sum_{n=1}^N \log \sum_{k=1}^M \pi_k(\mathbf{x}_n) \mathcal{N}(\mathbf{y}_n; \mu_k(\mathbf{x}_n), \sigma_k^2(\mathbf{x}_n))\]
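In practice the inner sum over components is evaluated in log-space using the log-sum-exp trick, which avoids numerical underflow when the component densities are very small. The following is a minimal sketch assuming per-dimension standard deviations (the function name and tensor shapes are illustrative; the PyTorch-distributions version in the next section computes the same quantity):

    import torch

    def mdn_nll(pi, mu, sigma, y):
        # pi: (batch, M); mu, sigma: (batch, M, c); y: (batch, c)
        comp = torch.distributions.Normal(mu, sigma)
        log_comp = comp.log_prob(y.unsqueeze(1)).sum(-1)             # log N(y_n | mu_k, sigma_k), (batch, M)
        log_mix = torch.logsumexp(torch.log(pi) + log_comp, dim=-1)  # log sum_k pi_k N(...), (batch,)
        return -log_mix.mean()                                       # average negative log-likelihood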

Implementation details

The PyTorch distributions package provides an elegant way to implement an MDN. The MixtureSameFamily distribution, which implements a (batch of) mixture distribution in which all components come from different parameterizations of the same distribution type, can be used to model the MDN. It is parameterized by a Categorical distribution (over the \(M\) components) and a component distribution.

    import torch.distributions as dist

    # network heads produce the GMM parameters
    pi = torch.softmax(self._pi(x), -1)                                    # (batch, n_gauss)
    mu = self._mu(x).reshape(-1, self.n_gauss, self.out_dim)               # (batch, n_gauss, out_dim)
    sigma = clip_log_sigmoid(self._sigma(x).reshape(-1, self.n_gauss, self.out_dim))

    mix = dist.Categorical(pi)                          # distribution over mixture components
    comp = dist.Independent(dist.Normal(mu, sigma), 1)  # diagonal Gaussian per component
    gmm = dist.MixtureSameFamily(mix, comp)             # full mixture p(y|x)

Once the MDN is defined, we can easily compute the negative log-likelihood as follows

    def log_nlloss(self, y, gmm):
        # average negative log-likelihood of the targets under the mixture
        logprobs = gmm.log_prob(y)
        loss = -torch.mean(logprobs)
        return loss
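
For completeness, here is a sketch of how the loss might be wired into a training step and how predictions can be drawn afterwards. The names model, loader and x_new are hypothetical placeholders for a module whose forward pass returns the MixtureSameFamily object, a DataLoader, and a batch of test inputs.

    import torch

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # `model` is a hypothetical MDN module

    for x_batch, y_batch in loader:           # `loader` is an assumed DataLoader
        gmm = model(x_batch)                  # forward pass builds p(y|x)
        loss = -gmm.log_prob(y_batch).mean()  # negative log-likelihood
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # at prediction time, draw samples from the learned conditional density
    with torch.no_grad():
        y_samples = model(x_new).sample()     # `x_new` is an assumed batch of test inputs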