sklearn’s init

As a maintainer of scikit-lego, I often see PRs that want to move attribute validation from the fit method into the __init__. In this post we’ll explore why, in most cases, that won’t work as you expect.

Matthijs Brouns

Scikit-lego started because my colleague Vincent and I saw people rewrite the same transformers and estimators at clients over and over again. We set out to consolidate these into a package that gives some of the more experimental and niche use cases a place to live, while still offering high code quality and adherence to the scikit-learn API standards.

When we started, most of the pull requests came from Vincent and me, but as scikit-lego gained some traction, we also started receiving external pull requests. Something we noticed was that many of those pull requests contained a diff that looked something like this:


Figure 1: A common pull request
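The diff in Figure 1 is not reproduced in this text. Judging from the rest of this post, which describes a snippet in __init__ that casts the columns attribute to a list, a hypothetical reconstruction of the proposed ColumnSelector looks roughly like this (the actual pull requests may have differed in the details):

```python
from sklearn.base import BaseEstimator, TransformerMixin


class ColumnSelector(TransformerMixin, BaseEstimator):
    """Hypothetical reconstruction of the proposed change: validation in __init__."""

    def __init__(self, columns=None):
        # Coerce a single column name into a list, so that transform
        # always returns a 2D DataFrame. This is the pattern the PRs proposed.
        if isinstance(columns, str):
            columns = [columns]
        self.columns = columns

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        # X[['a']] keeps a DataFrame, whereas X['a'] would give a Series
        return X[self.columns]


selector = ColumnSelector(columns='a')
print(selector.columns)  # ['a']
```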

And to be fair, this diff is totally reasonable when you first look at it. Moving input validation or datatype coercion to the class’s __init__ means that we can forget about it in all the other methods of our class. So what is the problem with the proposed change?


It turns out that one place where this causes problems is GridSearchCV. Let’s say we have a pipeline containing an instance of our ColumnSelector, and we want to find out which combination of columns gives us the highest score on our validation set. We’ll define our grid search like this:

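from sklearn.linear_model import LinearRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline

pipe = make_pipeline(ColumnSelector(), LinearRegression())
gs = GridSearchCV(pipe, param_grid={
    'columnselector__columns': ['a', ['a', 'b']],  # one column name, or a list of names
}).fit(X, y)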

In scikit-learn<=0.22 this will blow up in our face with the following error:

ValueError: Expected 2D array, got 1D array instead.

So what happened here? A dive into the GridSearchCV implementation reveals what happens when we call fit on a GridSearchCV object:

  1. A ParameterGrid is created, containing the Cartesian product of all possible parameter values.
  2. For each of the parameter combinations in the grid:
    1. The original estimator (in this case our Pipeline) is cloned.
    2. The set_params method is called on the clone with the new parameters.
    3. The cloned estimator is fitted on the supplied data.
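The steps above can be sketched in a few lines. This is a simplified sketch, not GridSearchCV’s actual code: it skips cross-validation and scores each candidate on the training data, and the helper name naive_grid_search is hypothetical. ParameterGrid, clone and set_params are the real scikit-learn pieces.

```python
import numpy as np
from sklearn import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import ParameterGrid


def naive_grid_search(estimator, param_grid, X, y):
    """Hypothetical, simplified GridSearchCV.fit: no CV, training-set scores."""
    results = []
    # Step 1: iterate over the Cartesian product of all parameter values
    for params in ParameterGrid(param_grid):
        # Step 2.1: clone the original (unfitted) estimator
        candidate = clone(estimator)
        # Step 2.2: apply this combination of parameters via set_params
        candidate.set_params(**params)
        # Step 2.3: fit the clone and record its score
        candidate.fit(X, y)
        results.append((params, candidate.score(X, y)))
    return results


X = np.arange(10, dtype=float).reshape(-1, 1)
y = 2 * X.ravel()
base = LinearRegression()
for params, score in naive_grid_search(base, {'fit_intercept': [True, False]}, X, y):
    print(params, round(score, 3))
```

Note that the original estimator passed in is never fitted or modified; only its clones are.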

Wait up, we didn’t implement a method called set_params on our ColumnSelector, so where did it come from? All estimators in scikit-learn inherit from a single base class called BaseEstimator. This base class only implements two user-facing methods: set_params and get_params. The set_params method takes the estimator object and, for each parameter passed to it, replaces the value of the matching attribute with the desired value.
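To make the mechanism concrete, here is a simplified sketch of what set_params does, written as a standalone function. The name set_params_sketch is hypothetical, and scikit-learn’s real method does more bookkeeping (for example, it treats replacing a whole pipeline step specially). The sketch does handle the step__parameter keys that GridSearchCV passes to a pipeline, and it never calls __init__: every value ends up in a plain setattr.

```python
from collections import defaultdict

from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def set_params_sketch(estimator, **params):
    """Hypothetical, simplified version of BaseEstimator.set_params."""
    # deep=True also exposes pipeline steps by name, e.g. 'linearregression'
    valid = estimator.get_params(deep=True)
    nested = defaultdict(dict)
    for key, value in params.items():
        name, delim, sub_key = key.partition('__')
        if name not in valid:
            raise ValueError(f'Invalid parameter {name!r} for estimator {estimator!r}')
        if delim:
            # 'step__param': collect it and forward it to the nested estimator
            nested[name][sub_key] = value
        else:
            # A plain attribute assignment: __init__ is never re-run
            setattr(estimator, name, value)
    for name, sub_params in nested.items():
        set_params_sketch(valid[name], **sub_params)
    return estimator


pipe = make_pipeline(StandardScaler(), LinearRegression())
set_params_sketch(pipe, linearregression__fit_intercept=False)
print(pipe.named_steps['linearregression'].fit_intercept)  # False
```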

This means that for the grid search defined above, the following basically happens[1]:

from sklearn import clone
cloned = clone(column_selector)
setattr(cloned, 'columns', 'a')  # roughly what set_params(columns='a') boils down to
cloned.fit(X, y)

The __init__ is called when the clone is made, but the attribute is overwritten afterwards by the call to setattr. The snippet that casts our columns attribute to a list is therefore bypassed entirely by set_params, and our columns attribute stays a single string. For the ColumnSelector, that means transform selects X['a'] instead of X[['a']], returning a 1D Series instead of a 2D DataFrame, which the next step in the pipeline then rejects.
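The failure can be reproduced without GridSearchCV. The sketch below assumes the hypothetical __init__-validating ColumnSelector (re-declared so the snippet is self-contained) and a small, made-up pandas DataFrame:

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin


class ColumnSelector(TransformerMixin, BaseEstimator):
    # Hypothetical version with the coercion in __init__ (the proposed change)
    def __init__(self, columns=None):
        if isinstance(columns, str):
            columns = [columns]
        self.columns = columns

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        return X[self.columns]


df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})

# Passing 'a' to the constructor goes through the coercion...
print(type(ColumnSelector(columns='a').fit(df).transform(df)).__name__)  # DataFrame

# ...but set_params assigns the attribute directly, so the coercion is skipped
selector = ColumnSelector().set_params(columns='a')
print(selector.columns)  # a
print(type(selector.fit(df).transform(df)).__name__)  # Series
```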

This means that we can’t rely on any logic in an estimator’s __init__ having been applied to the current values of its parameters. The only way around this is to do all of our parameter checking and modification in the fit method rather than in __init__.
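A minimal sketch of that alternative, assuming the same hypothetical ColumnSelector: __init__ only stores the argument, and fit does the coercion. Storing the result in a separate columns_ attribute (a name chosen for this sketch) follows scikit-learn’s convention of a trailing underscore for attributes set during fit. It also leaves the columns parameter exactly as the user, or set_params, set it, so get_params and clone keep working.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin


class ColumnSelector(TransformerMixin, BaseEstimator):
    def __init__(self, columns=None):
        # Store the argument untouched, under the same name
        self.columns = columns

    def fit(self, X, y=None):
        # Coerce here, where the *current* value of self.columns is used,
        # and keep the result in a separate fitted attribute
        self.columns_ = [self.columns] if isinstance(self.columns, str) else self.columns
        return self

    def transform(self, X, y=None):
        return X[self.columns_]


df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
selector = ColumnSelector().set_params(columns='a')
print(type(selector.fit(df).transform(df)).__name__)  # DataFrame
```

Because fit reads self.columns every time it runs, the coercion applies whether the value arrived through __init__ or through set_params.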


Now that we know this, it is worth checking which other parts of scikit-learn use set_params and get_params. This might save us a whole lot of trouble in the future. It turns out that we have already seen a function that relies heavily on get_params, the counterpart of set_params: clone.

A bare-bones implementation of clone looks something like this[2]:

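def clone(estimator):
    # Read the constructor parameters of the original estimator (get_params) and
    # instantiate a new, unfitted object of the same class with them.
    # (The real clone also clones each parameter and sanity-checks the result.)
    new_estimator = estimator.__class__(**estimator.get_params(deep=False))
    return new_estimator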
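One consequence of building the copy from the constructor parameters alone is that a clone carries over the parameters but none of the fitted state. A quick check with scikit-learn’s real sklearn.clone (the data here is made up for illustration):

```python
import numpy as np
from sklearn import clone
from sklearn.linear_model import LinearRegression

X = np.arange(10, dtype=float).reshape(-1, 1)
y = 2 * X.ravel() + 1

fitted = LinearRegression(fit_intercept=False).fit(X, y)
cloned = clone(fitted)

# The constructor parameters are carried over...
print(cloned.get_params() == fitted.get_params())  # True
# ...but the fitted attributes are not: the clone is a fresh, unfitted estimator
print(hasattr(fitted, 'coef_'), hasattr(cloned, 'coef_'))  # True False
```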

This implementation implies that scikit-learn can take an existing estimator object and work out which parameters it accepts. To do this, scikit-learn uses Python’s inspect module. A basic version of get_params looks something like this[3]:

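import inspect

from sklearn.linear_model import LinearRegression

def get_params(estimator):
    # Inspect the signature of __init__ (on a bound method, self is already left out)
    init_sig = inspect.signature(estimator.__init__)
    params = [p.name for p in init_sig.parameters.values()]
    # Read each parameter back from the attribute with the same name;
    # like scikit-learn before 0.24, fall back to None if the attribute is missing
    return {param: getattr(estimator, param, None) for param in params}

get_params(LinearRegression())
# {'fit_intercept': True, 'normalize': False, 'copy_X': True, 'n_jobs': None}
# (output for scikit-learn 0.22; newer releases have different parameters)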

Looking at this implementation closely reveals an important assumption that must not be broken: the names of the arguments of our __init__ method must be exactly the same as the names of the attributes stored on our class. Otherwise, the getattr call will return None instead of the actual attribute value.

A quick modification of our ColumnSelector shows this in action:

from sklearn import clone
from sklearn.base import BaseEstimator, TransformerMixin

class ColumnSelector(TransformerMixin, BaseEstimator):
    def __init__(self, columns=None):
        # The argument is stored under a different name: cols instead of columns
        self.cols = columns

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        return X[self.columns]

clone(ColumnSelector(columns=['a']))
# ColumnSelector(columns=None)  (scikit-learn 0.22)

The cloned ColumnSelector reports None for its columns parameter, and its cols attribute is None as well: the value we passed in is silently lost. Starting with scikit-learn 0.22, this raises the following warning:

FutureWarning: From version 0.24, get_params will raise an AttributeError if a parameter cannot be retrieved as an instance attribute. Previously it would return None.

In earlier versions, this fails silently and returns None.


As we have seen in the examples above, there should be no logic, not even input validation, in an estimator’s __init__. That logic belongs where the parameters are used, which is typically in fit. In addition, every argument accepted by __init__ should be stored, unchanged, as an attribute with exactly the same name.
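The last rule can be checked mechanically. The helper below is hypothetical (it is not part of scikit-learn): it passes a unique sentinel object for every __init__ argument and verifies that each one ends up, unchanged, on an attribute with the same name. It assumes __init__ accepts arbitrary values, and it cannot catch coercion that only triggers for certain inputs (such as a string), which is one more reason to keep that logic in fit. For real projects, scikit-learn’s sklearn.utils.estimator_checks.check_estimator runs a broader set of related checks.

```python
import inspect


def init_stores_params_unchanged(cls):
    """Hypothetical helper: does __init__ store every argument as-is, under its own name?"""
    params = [
        p.name
        for p in inspect.signature(cls.__init__).parameters.values()
        if p.name != 'self'
        and p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    ]
    # A unique sentinel per argument, so any renaming or copying shows up
    sentinels = {name: object() for name in params}
    instance = cls(**sentinels)
    missing = object()
    return all(getattr(instance, name, missing) is value for name, value in sentinels.items())


class Good:
    def __init__(self, columns=None):
        self.columns = columns


class Renamed:
    def __init__(self, columns=None):
        self.cols = columns


print(init_stores_params_unchanged(Good))     # True
print(init_stores_params_unchanged(Renamed))  # False
```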

It turns out that the specific implementation of get_params and its companion set_params also leads to other unintuitive problems when inheriting from other estimators, but we’ll explore those in another post.

  1. Adapted from:

  2. Adapted from:

  3. Adapted from: