scikit-mdn

Mixture density networks, from PyTorch, for scikit-learn

Usage

To get this tool working locally you will first need to install it:

python -m pip install scikit-mdn

Then you can use it in your code. Here is a small demo:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from skmdn import MixtureDensityEstimator

# Generate dataset
n_samples = 1000
X_full, _ = make_moons(n_samples=n_samples, noise=0.1)
X = X_full[:, 0].reshape(-1, 1)  # Use only the first column as input
Y = X_full[:, 1].reshape(-1, 1)  # Predict the second column

# Add some noise to Y to make the problem more suitable for MDN
Y += 0.1 * np.random.randn(n_samples, 1)

# Fit the model
mdn = MixtureDensityEstimator()
mdn.fit(X, Y)

# Predict some quantiles on the train set 
means, quantiles = mdn.predict(X, quantiles=[0.01, 0.1, 0.9, 0.99], resolution=100000)
plt.scatter(X, Y)
plt.scatter(X, quantiles[:, 0], color='orange')
plt.scatter(X, quantiles[:, 1], color='green')
plt.scatter(X, quantiles[:, 2], color='green')
plt.scatter(X, quantiles[:, 3], color='orange')
plt.scatter(X, means, color='red')

This is what the chart looks like:

[Example chart: scatter of the data with the predicted 1%/99% quantiles in orange, 10%/90% quantiles in green, and the central prediction in red]

You can see how it is able to predict the quantiles of the distribution as well as a central estimate.
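
Beyond quantiles, the full conditional density can be visualised with the pdf method. A minimal sketch that continues the demo above (it assumes mdn, X, np and plt are still in scope):

# Evaluate the learned density p(y|x) on a grid of x values
X_grid = np.linspace(X.min(), X.max(), 200).reshape(-1, 1)
pdf, ys = mdn.pdf(X_grid, resolution=200)

# Each row of `pdf` is a density over `ys`; transpose so y runs vertically
plt.imshow(pdf.T, origin='lower', aspect='auto',
           extent=(X_grid.min(), X_grid.max(), ys.min(), ys.max()))
plt.xlabel('x')
plt.ylabel('y')
plt.title('Conditional density learned by the MDN')
plt.show()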

API

This is the main object that you'll interact with.

MixtureDensityEstimator(hidden_dim=10, n_gaussians=5, epochs=1000, lr=0.01, weight_decay=0.0)

Bases: BaseEstimator

A scikit-learn compatible Mixture Density Estimator.

Parameters:

Name          Description                               Default
hidden_dim    hidden layer dimension                    10
n_gaussians   number of gaussians in the mixture model  5
epochs        number of epochs                          1000
lr            learning rate                             0.01
weight_decay  weight decay for regularisation           0.0
Source code in skmdn/__init__.py
class MixtureDensityEstimator(BaseEstimator):
    '''
    A scikit-learn compatible Mixture Density Estimator.

    Args:
        hidden_dim: hidden layer dimension
        n_gaussians: number of gaussians in the mixture model
        epochs: number of epochs
        lr: learning rate
        weight_decay: weight decay for regularisation
    '''
    def __init__(self, hidden_dim=10, n_gaussians=5, epochs=1000, lr=0.01, weight_decay=0.0):
        self.hidden_dim = hidden_dim
        self.n_gaussians = n_gaussians
        self.epochs = epochs
        self.lr = lr
        self.weight_decay = weight_decay

    def _cast_torch(self, X, y):
        if not hasattr(self, 'X_width_'):
            self.X_width_ = X.shape[1]
        if not hasattr(self, 'X_min_'):
            self.X_min_ = X.min(axis=0)
        if not hasattr(self, 'X_max_'):
            self.X_max_ = X.max(axis=0)
        if not hasattr(self, 'y_min_'):
            self.y_min_ = y.min()
        if not hasattr(self, 'y_max_'):
            self.y_max_ = y.max()

        assert X.shape[1] == self.X_width_, "Input dimension mismatch"

        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

    def fit(self, X, y):
        """
        Fit the model to the data.

        Args:
            X: (n_samples, n_features)
            y: (n_samples, 1)
        """
        X, y = self._cast_torch(X, y)

        self.model_ = MixtureDensityNetwork(X.shape[1], self.hidden_dim, y.shape[1], self.n_gaussians)
        self.optimizer_ = torch.optim.Adam(self.model_.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        for epoch in range(self.epochs):
            self.optimizer_.zero_grad()
            pi, mu, sigma = self.model_(X)
            loss = mdn_loss(pi, mu, sigma, y)
            loss.backward()
            self.optimizer_.step()

        return self

    def partial_fit(self, X, y, n_epochs=1):
        """
        Fit the model to the data for a set number of epochs. Can be used to continue training on new data.

        Args:
            X: (n_samples, n_features)
            y: (n_samples, 1)
            n_epochs: number of epochs
        """
        X, y = self._cast_torch(X, y)

        if not hasattr(self, 'optimizer_'):
            self.optimizer_ = torch.optim.Adam(self.model_.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        for epoch in range(n_epochs):
            self.optimizer_.zero_grad()
            pi, mu, sigma = self.model_(X)
            loss = mdn_loss(pi, mu, sigma, y)
            loss.backward()
            self.optimizer_.step()

        return self

    def forward(self, X):
        """
        Calculate the $\pi$, $\mu$ and $\sigma$ outputs for each sample in X.

        Args:
            X: (n_samples, n_features)

        Returns:
            pi: (n_samples, n_gaussians)
            mu: (n_samples, n_gaussians)
            sigma: (n_samples, n_gaussians)
        """
        X = torch.FloatTensor(X)
        with torch.no_grad():
            pi, mu, sigma = self.model_(X)
        pi, mu, sigma = pi.detach().numpy(), mu.detach().numpy(), sigma.detach().numpy()
        return pi, mu[:, :, 0], sigma[:, :, 0]

    def pdf(self, X, resolution=100, y_min=None, y_max=None):
        '''
        Compute the probability density function of the model.

        This function computes the pdf for each sample in X.
        It also returns the y values for which the pdf is computed to help with plotting.

        Args:
            X: (n_samples, n_features)
            resolution: number of points on the y-grid to evaluate the pdf over
            y_min: lower bound of the y-grid (defaults to the minimum seen during fit)
            y_max: upper bound of the y-grid (defaults to the maximum seen during fit)

        Returns:
            pdf: (n_samples, resolution)
            ys: (resolution,)
        '''
        pi, mu, sigma = self.forward(X)
        ys = np.linspace(
            y_min if y_min is not None else self.y_min_,
            y_max if y_max is not None else self.y_max_,
            resolution
        )
        pdf = np.zeros((pi.shape[0], resolution))
        for i in range(pi.shape[0]):
            for j in range(pi.shape[1]):
                pdf[i] += norm(mu[i, j], sigma[i, j]).pdf(ys) * pi[i, j]
        return pdf, ys

    def cdf(self, X, resolution=100):
        '''
        Compute the cumulative probability density function of the model.

        This function computes the cdf for each sample in X.
        It also returns the y values for which the cdf is computed to help with plotting.

        Args:
            X: (n_samples, n_features)
            resolution: number of points on the y-grid to evaluate the cdf over

        Returns:
            cdf: (n_samples, resolution)
            ys: (resolution,)
        '''
        pdf, ys = self.pdf(X, resolution=resolution)
        cdf = pdf.cumsum(axis=1)
        cdf /= cdf[:, -1].reshape(-1, 1)
        return cdf, ys

    def predict(self, X, quantiles=None, resolution=100):
        '''
        Predict a central estimate, and optionally quantiles, of the distribution for each datapoint in X.

        Args:
            X: (n_samples, n_features)
            quantiles: list of quantile values to predict (optional)
            resolution: number of points on the y-grid to compute the quantiles over

        Returns:
            pred: (n_samples,)
            quantiles: (n_samples, n_quantiles), only returned when quantiles are requested
        '''
        pdf, ys = self.pdf(X, resolution=resolution)
        cdf = pdf.cumsum(axis=1)
        cdf /= cdf[:, -1].reshape(-1, 1)

        # point prediction: the y value where the cdf first crosses 0.5 (the median of the mixture)
        mean_pred = ys[np.argmax(cdf > 0.5, axis=1)]

        if not quantiles:
            return mean_pred

        quantile_out = np.zeros((X.shape[0], len(quantiles)))
        for j, q in enumerate(quantiles):
            quantile_out[:, j] = ys[np.argmax(cdf > q, axis=1)]
        return mean_pred, quantile_out
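
Because the class subclasses BaseEstimator and stores its hyperparameters verbatim in __init__, the standard scikit-learn parameter plumbing should work. A small sketch:

from sklearn.base import clone
from skmdn import MixtureDensityEstimator

mdn = MixtureDensityEstimator(hidden_dim=32, n_gaussians=3, epochs=500)
print(mdn.get_params())  # {'epochs': 500, 'hidden_dim': 32, 'lr': 0.01, ...}
fresh = clone(mdn)       # unfitted copy with identical hyperparameters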

cdf(X, resolution=100)

Compute the cumulative probability density function of the model.

This function computes the cdf for each sample in X. It also returns the y values for which the cdf is computed to help with plotting.

Parameters:

Name        Description                                      Default
X           (n_samples, n_features)                          required
resolution  number of points on the y-grid to evaluate over  100

Returns:

Name  Description
cdf   (n_samples, resolution)
ys    (resolution,)

Source code in skmdn/__init__.py
def cdf(self, X, resolution=100):
    '''
    Compute the cumulative probability density function of the model.

    This function computes the cdf for each sample in X.
    It also returns the y values for which the cdf is computed to help with plotting.

    Args:
        X: (n_samples, n_features)
        resolution: number of points on the y-grid to evaluate the cdf over

    Returns:
        cdf: (n_samples, resolution)
        ys: (resolution,)
    '''
    pdf, ys = self.pdf(X, resolution=resolution)
    cdf = pdf.cumsum(axis=1)
    cdf /= cdf[:, -1].reshape(-1, 1)
    return cdf, ys
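
Since the cdf is evaluated on a shared y-grid, tail probabilities can be read off directly. A sketch, assuming mdn was fitted as in the Usage section:

# Estimate P(y <= 0) for each sample from the grid cdf
cdf, ys = mdn.cdf(X, resolution=500)
idx = min(np.searchsorted(ys, 0.0), len(ys) - 1)  # first grid point at or above y = 0
p_below_zero = cdf[:, idx]                        # one probability per sample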

fit(X, y)

Fit the model to the data.

Parameters:

Name  Description               Default
X     (n_samples, n_features)   required
y     (n_samples, 1)            required
Source code in skmdn/__init__.py
def fit(self, X, y):
    """
    Fit the model to the data.

    Args:
        X: (n_samples, n_features)
        y: (n_samples, 1)
    """
    X, y = self._cast_torch(X, y)

    self.model_ = MixtureDensityNetwork(X.shape[1], self.hidden_dim, y.shape[1], self.n_gaussians)
    self.optimizer_ = torch.optim.Adam(self.model_.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    for epoch in range(self.epochs):
        self.optimizer_.zero_grad()
        pi, mu, sigma = self.model_(X)
        loss = mdn_loss(pi, mu, sigma, y)
        loss.backward()
        self.optimizer_.step()

    return self
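
Note that fit returns self, so construction, fitting and prediction can be chained in one expression. For example, reusing X and Y from the Usage demo:

preds = MixtureDensityEstimator(epochs=200).fit(X, Y).predict(X)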

forward(X)

Calculate the \(\pi\), \(\mu\) and \(\sigma\) outputs for each sample in X.

Parameters:

Name  Description               Default
X     (n_samples, n_features)   required

Returns:

Name   Description
pi     (n_samples, n_gaussians)
mu     (n_samples, n_gaussians)
sigma  (n_samples, n_gaussians)

Source code in skmdn/__init__.py
def forward(self, X):
    """
    Calculate the $\pi$, $\mu$ and $\sigma$ outputs for each sample in X.

    Args:
        X: (n_samples, n_features)

    Returns:
        pi: (n_samples, n_gaussians)
        mu: (n_samples, n_gaussians)
        sigma: (n_samples, n_gaussians)
    """
    X = torch.FloatTensor(X)
    with torch.no_grad():
        pi, mu, sigma = self.model_(X)
    pi, mu, sigma = pi.detach().numpy(), mu.detach().numpy(), sigma.detach().numpy()
    return pi, mu[:, :, 0], sigma[:, :, 0]
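
forward exposes the raw mixture parameters, which is useful for inspecting what the network has learned. A sketch, assuming a fitted mdn:

# Mixture parameters for the first five samples; each array is (5, n_gaussians)
pi, mu, sigma = mdn.forward(X[:5])
print(pi[0])     # mixture weights for sample 0
print(mu[0])     # component means for sample 0
print(sigma[0])  # component standard deviations for sample 0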

partial_fit(X, y, n_epochs=1)

Fit the model to the data for a set number of epochs. Can be used to continue training on new data.

Parameters:

Name      Description               Default
X         (n_samples, n_features)   required
y         (n_samples, 1)            required
n_epochs  number of epochs          1
Source code in skmdn/__init__.py
def partial_fit(self, X, y, n_epochs=1):
    """
    Fit the model to the data for a set number of epochs. Can be used to continue training on new data.

    Args:
        X: (n_samples, n_features)
        y: (n_samples, 1)
        n_epochs: number of epochs
    """
    X, y = self._cast_torch(X, y)

    if not hasattr(self, 'optimizer_'):
        self.optimizer_ = torch.optim.Adam(self.model_.parameters(), lr=self.lr, weight_decay=self.weight_decay)
    for epoch in range(n_epochs):
        self.optimizer_.zero_grad()
        pi, mu, sigma = self.model_(X)
        loss = mdn_loss(pi, mu, sigma, y)
        loss.backward()
        self.optimizer_.step()

    return self
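
Because partial_fit keeps the existing model and optimizer, it can be used for incremental training, for instance over mini-batches. A sketch that assumes an initial fit has already created the model:

mdn = MixtureDensityEstimator(epochs=100).fit(X, Y)

# Continue training in small increments on successive slices of the data
for X_batch, Y_batch in zip(np.array_split(X, 10), np.array_split(Y, 10)):
    mdn.partial_fit(X_batch, Y_batch, n_epochs=5)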

pdf(X, resolution=100, y_min=None, y_max=None)

Compute the probability density function of the model.

This function computes the pdf for each sample in X. It also returns the y values for which the pdf is computed to help with plotting.

Parameters:

Name        Description                                                       Default
X           (n_samples, n_features)                                           required
resolution  number of points on the y-grid to evaluate over                   100
y_min       lower bound of the y-grid (defaults to the minimum seen in fit)   None
y_max       upper bound of the y-grid (defaults to the maximum seen in fit)   None

Returns:

Name  Description
pdf   (n_samples, resolution)
ys    (resolution,)

Source code in skmdn/__init__.py
def pdf(self, X, resolution=100, y_min=None, y_max=None):
    '''
    Compute the probability density function of the model.

    This function computes the pdf for each sample in X.
    It also returns the y values for which the pdf is computed to help with plotting.

    Args:
        X: (n_samples, n_features)
        resolution: number of points on the y-grid to evaluate the pdf over
        y_min: lower bound of the y-grid (defaults to the minimum seen during fit)
        y_max: upper bound of the y-grid (defaults to the maximum seen during fit)

    Returns:
        pdf: (n_samples, resolution)
        ys: (resolution,)
    '''
    pi, mu, sigma = self.forward(X)
    ys = np.linspace(
        y_min if y_min is not None else self.y_min_,
        y_max if y_max is not None else self.y_max_,
        resolution
    )
    pdf = np.zeros((pi.shape[0], resolution))
    for i in range(pi.shape[0]):
        for j in range(pi.shape[1]):
            pdf[i] += norm(mu[i, j], sigma[i, j]).pdf(ys) * pi[i, j]
    return pdf, ys
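
For a single input, the returned row can be plotted directly against the y-grid. A minimal sketch, again assuming a fitted mdn:

import matplotlib.pyplot as plt

pdf, ys = mdn.pdf(X[:1], resolution=500)  # density for one sample
plt.plot(ys, pdf[0])
plt.xlabel('y')
plt.ylabel('density')
plt.show()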

predict(X, quantiles=None, resolution=100)

Predict a central estimate, and optionally quantiles, of the distribution for each datapoint in X.

Parameters:

Name        Description                                          Default
X           (n_samples, n_features)                              required
quantiles   list of quantile values to predict (optional)        None
resolution  number of points on the y-grid to compute over       100

Returns:

Name       Description
pred       (n_samples,)
quantiles  (n_samples, n_quantiles), only returned when quantiles are requested

Source code in skmdn/__init__.py
def predict(self, X, quantiles=None, resolution=100):
    '''
    Predict a central estimate, and optionally quantiles, of the distribution for each datapoint in X.

    Args:
        X: (n_samples, n_features)
        quantiles: list of quantile values to predict (optional)
        resolution: number of points on the y-grid to compute the quantiles over

    Returns:
        pred: (n_samples,)
        quantiles: (n_samples, n_quantiles), only returned when quantiles are requested
    '''
    pdf, ys = self.pdf(X, resolution=resolution)
    cdf = pdf.cumsum(axis=1)
    cdf /= cdf[:, -1].reshape(-1, 1)

    # point prediction: the y value where the cdf first crosses 0.5 (the median of the mixture)
    mean_pred = ys[np.argmax(cdf > 0.5, axis=1)]

    if not quantiles:
        return mean_pred

    quantile_out = np.zeros((X.shape[0], len(quantiles)))
    for j, q in enumerate(quantiles):
        quantile_out[:, j] = ys[np.argmax(cdf > q, axis=1)]
    return mean_pred, quantile_out
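
The return type depends on whether quantiles were requested; a short sketch, assuming a fitted mdn:

# Without quantiles, predict returns just the point predictions
point = mdn.predict(X)                # shape: (n_samples,)

# With quantiles, it returns a (point, quantiles) pair
point, bands = mdn.predict(X, quantiles=[0.05, 0.95], resolution=1000)
print(bands.shape)                    # (n_samples, 2)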