Naive Bayes
Naive Bayes models are flexible and interpretable. In scikit-lego we've added support for a Gaussian Mixture variant of the algorithm.
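To make the idea concrete, here is a minimal sketch of what a mixture-based Naive Bayes could look like: one 1-D Gaussian mixture per class and per column, combined under the usual independence assumption. It is built on scikit-learn's GaussianMixture. The class name TinyMixtureNB and the toy data are made up for illustration, and this is not scikit-lego's actual implementation.

```python
import numpy as np
from sklearn.mixture import GaussianMixture


class TinyMixtureNB:
    """Hypothetical illustration; not the scikit-lego implementation."""

    def __init__(self, n_components=2, random_state=0):
        self.n_components = n_components
        self.random_state = random_state

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        # Log class priors, estimated from label frequencies.
        self.log_priors_ = np.log([np.mean(y == c) for c in self.classes_])
        # One 1-D mixture per (class, column) pair: the "naive" assumption.
        self.mixtures_ = []
        for c in self.classes_:
            Xc = X[y == c]
            self.mixtures_.append([
                GaussianMixture(
                    n_components=self.n_components,
                    random_state=self.random_state,
                ).fit(Xc[:, [j]])
                for j in range(X.shape[1])
            ])
        return self

    def _log_joint(self, X):
        # log P(class) + sum over columns of log p(column | class).
        cols = []
        for log_prior, per_column in zip(self.log_priors_, self.mixtures_):
            log_lik = sum(
                gmm.score_samples(X[:, [j]]) for j, gmm in enumerate(per_column)
            )
            cols.append(log_prior + log_lik)
        return np.column_stack(cols)

    def predict_proba(self, X):
        joint = self._log_joint(X)
        # Subtract the row-wise max before exponentiating for stability.
        joint = joint - joint.max(axis=1, keepdims=True)
        proba = np.exp(joint)
        return proba / proba.sum(axis=1, keepdims=True)

    def predict(self, X):
        return self.classes_[np.argmax(self._log_joint(X), axis=1)]


# Toy data: class 0 has two peaks in its first column, class 1 has one.
rng = np.random.default_rng(0)
X_a = np.column_stack([
    np.concatenate([rng.normal(-4, 1, 500), rng.normal(4, 1, 500)]),
    rng.normal(0, 1, 1000),
])
X_b = rng.normal(0, 1, (1000, 2))
X_toy = np.concatenate([X_a, X_b])
y_toy = np.concatenate([np.zeros(1000), np.ones(1000)])

model = TinyMixtureNB(n_components=2).fit(X_toy, y_toy)
accuracy = np.mean(model.predict(X_toy) == y_toy)
print(f"training accuracy: {accuracy:.3f}")
```

With n_components=1 each mixture collapses to a single Gaussian, so the sketch reduces to an ordinary Gaussian Naive Bayes.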
An example of how to use the algorithm can be found below.
Example
Let's first import the dependencies and create some data. This code will create a plot of the dataset we'll try to predict.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
n = 10000
def make_arr(mu1, mu2, std1=1, std2=1, p=0.5):
    # Draw each value from N(mu1, std1) or N(mu2, std2), chosen at random.
    res = np.where(np.random.uniform(0, 1, n) > p,
                   np.random.normal(mu1, std1, n),
                   np.random.normal(mu2, std2, n))
    return np.expand_dims(res, 1)
np.random.seed(42)
X1 = np.concatenate([make_arr(0, 4), make_arr(0, 4)], axis=1)
X2 = np.concatenate([make_arr(-3, 7), make_arr(2, 2)], axis=1)
plt.figure(figsize=(4,4))
plt.scatter(X1[:, 0], X1[:, 1], alpha=0.5)
plt.scatter(X2[:, 0], X2[:, 1], alpha=0.5)
plt.title("simulated dataset");
Note that this dataset would be hard to classify with a standard Gaussian Naive Bayes algorithm, since the orange class is multimodal: it is spread over two clusters.
To demonstrate this we'll run our GaussianMixtureNB algorithm, allowing the mixture to find either one or two Gaussians.
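To see why a single Gaussian per column falls short, the sketch below fits scikit-learn's GaussianNB on data generated along the same lines as above. The helper two_peaks, the variable names, and the random generator are made up for this illustration, so the numbers will not match the plot exactly.

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB

rng = np.random.default_rng(42)
n = 10_000


def two_peaks(mu1, mu2):
    # Hypothetical helper: each value comes from N(mu1, 1) or N(mu2, 1)
    # with equal odds.
    return np.where(rng.uniform(0, 1, n) > 0.5,
                    rng.normal(mu1, 1, n),
                    rng.normal(mu2, 1, n))


X_blue = np.column_stack([two_peaks(0, 4), two_peaks(0, 4)])
X_orange = np.column_stack([two_peaks(-3, 7), two_peaks(2, 2)])
X = np.concatenate([X_blue, X_orange])
y = np.concatenate([np.zeros(n), np.ones(n)])

gnb = GaussianNB().fit(X, y)

# Per class and column, GaussianNB stores exactly one mean in theta_.
orange_mean = gnb.theta_[1, 0]
# Share of orange points within one unit of that fitted mean.
share_near_mean = np.mean(np.abs(X_orange[:, 0] - orange_mean) < 1)

print(f"fitted mean of orange class, column 1: {orange_mean:.2f}")
print(f"share of orange points near that mean: {share_near_mean:.4f}")
print(f"GaussianNB training accuracy: {gnb.score(X, y):.3f}")
```

The fitted mean lands between the two peaks, in a region where the orange class has almost no points. That mismatch is what a mixture with more than one component can repair.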
from sklego.naive_bayes import GaussianMixtureNB
cmap = sns.color_palette("flare", as_cmap=True)
X = np.concatenate([X1, X2])
y = np.concatenate([np.zeros(n), np.ones(n)])
plt.figure(figsize=(8, 8))
for i, k in enumerate([1, 2]):
    mod = GaussianMixtureNB(n_components=k).fit(X, y)
    plt.subplot(220 + i * 2 + 1)
    pred = mod.predict_proba(X)[:, 0]
    plt.scatter(X[:, 0], X[:, 1], c=pred, cmap=cmap)
    plt.title(f"predict_proba k={k}")
    plt.subplot(220 + i * 2 + 2)
    pred = mod.predict(X)
    plt.scatter(X[:, 0], X[:, 1], c=pred, cmap=cmap)
    plt.title(f"predict k={k}");
Note that the second model (k=2, bottom row of the plot) fits the original data much better.
We can even zoom in on this second model by having it sample from what it believes is the distribution of each column.
Model density
gmm1 = mod.gmms_[0.0]
gmm2 = mod.gmms_[1.0]
plt.figure(figsize=(8, 8))
plt.subplot(221)
plt.hist(gmm1[0].sample(n)[0], 30)
plt.title("model 1 - column 1 density")
plt.subplot(222)
plt.hist(gmm1[1].sample(n)[0], 30)
plt.title("model 1 - column 2 density")
plt.subplot(223)
plt.hist(gmm2[0].sample(n)[0], 30)
plt.title("model 2 - column 1 density")
plt.subplot(224)
plt.hist(gmm2[1].sample(n)[0], 30)
plt.title("model 2 - column 2 density");
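Sampling is one way to inspect the fitted densities; reading the fitted parameters is another. The sketch below assumes the per-column models are scikit-learn GaussianMixture objects, as the .sample call above suggests. Rather than reusing mod, it fits a mixture directly on a made-up column with peaks at -3 and 7.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
# Made-up column mimicking the first column of the second class.
column = np.concatenate([rng.normal(-3, 1, 5000), rng.normal(7, 1, 5000)])

gmm = GaussianMixture(n_components=2, random_state=0).fit(column.reshape(-1, 1))

# Sort components by mean so the printed order is stable.
order = np.argsort(gmm.means_.ravel())
means = gmm.means_.ravel()[order]
weights = gmm.weights_[order]
stds = np.sqrt(gmm.covariances_.ravel()[order])

for m, w, s in zip(means, weights, stds):
    print(f"component: mean={m:.2f}, std={s:.2f}, weight={w:.2f}")
```

If that assumption holds for the fitted classifier, the same means_, weights_ and covariances_ attributes should be available on each entry of gmm1 and gmm2.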