Scikit-Learn Compatibility

Many of the language backends in this package can be used in scikit-learn pipelines. We've implemented a compatible .fit() and .transform() API, which means that you can write scikit-learn pipelines like this:

import numpy as np
from whatlies.language import SpacyLanguage
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression

pipe = Pipeline([
    ("embed", SpacyLanguage("en_core_web_md")),
    ("model", LogisticRegression())
])

X = [
    "i really like this post",
    "thanks for that comment",
    "i enjoy this friendly forum",
    "this is a bad post",
    "i dislike this article",
    "this is not well written"
]

y = np.array([1, 1, 1, 0, 0, 0])

pipe.fit(X, y)

This pipeline now uses the spaCy embeddings as features and passes them on to the logistic regression.

pipe.predict_proba(X)
# array([[0.37862409, 0.62137591],
#        [0.27858304, 0.72141696],
#        [0.21386529, 0.78613471],
#        [0.7155662 , 0.2844338 ],
#        [0.64924579, 0.35075421],
#        [0.76414156, 0.23585844]])

You could make a pipeline that generates both dense and sparse features by using a FeatureUnion.

from sklearn.pipeline import FeatureUnion
from sklearn.feature_extraction.text import CountVectorizer

preprocess = FeatureUnion([
    ("dense", SpacyLanguage("en_core_web_md")),
    ("sparse_word", CountVectorizer()),
    ("sparse_subword", CountVectorizer(analyzer="char", ngram_range=(2, 4)))
])
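
As a sketch, this union can then take the place of the single embedding step in the earlier pipeline (preprocess, X and y refer to the snippets above; the step names are only illustrative):

from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression

# Combine the dense/sparse FeatureUnion from above with a classifier.
pipe = Pipeline([
    ("embed", preprocess),
    ("model", LogisticRegression())
])

# Fit on the same toy data as before.
pipe.fit(X, y)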

Supported Models

Every language backend that this library offers can be used in a scikit-learn pipeline. This includes the following:

  • whatlies.language.SpacyLanguage
  • whatlies.language.FasttextLanguage
  • whatlies.language.CountVectorLanguage
  • whatlies.language.BytePairLanguage
  • whatlies.language.GensimLanguage
  • whatlies.language.HFTransformersLanguage
  • whatlies.language.TFHubLanguage
  • whatlies.language.UniversalSentenceLanguage
  • whatlies.language.SentenceTFMLanguage
  • whatlies.language.LaBSELanguage
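
Swapping in another backend is just a matter of changing the embedding step. Below is a sketch using BytePairLanguage; the constructor argument shown is an assumption, so check that backend's own documentation for the exact settings:

from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from whatlies.language import BytePairLanguage

# Same pipeline shape as before, only the embedding backend is swapped out.
pipe = Pipeline([
    ("embed", BytePairLanguage("en")),  # constructor argument is an assumption
    ("model", LogisticRegression())
])

pipe.fit(X, y)  # X, y as defined earlier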

Caveats

There are a few caveats to be aware of though. In general these language backends cannot be pickled directly, which means that you won't be able to save a pipeline that contains a whatlies component. It also means that you cannot use such a pipeline in a grid search. Where possible we try to test against scikit-learn's testing utilities, but for now the use-case is limited to plain Pipeline usage: assume that GridSearchCV will not work and that you cannot pickle the pipeline to disk.
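
If you do need a grid search, one workaround (a sketch, not an official feature) is to precompute the embeddings with the language backend up front and run the search on the resulting array, so that only picklable components end up inside the search:

from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
from whatlies.language import SpacyLanguage

# Precompute the embeddings once, outside of any pipeline.
lang = SpacyLanguage("en_core_web_md")
lang.fit(X)                # X, y as defined earlier
X_emb = lang.transform(X)

# Grid search only over the picklable part.
grid = GridSearchCV(LogisticRegression(), param_grid={"C": [0.1, 1.0, 10.0]}, cv=2)
grid.fit(X_emb, y)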

If you see a way to properly support this in general, let us know on GitHub by creating an issue.