As a maintainer of scikit-lego, I often see PRs that want to move attribute validation from the fit method into the __init__. In this post we’ll explore why, in most cases, that won’t work as you expect.
Scikit-lego started because my colleague Vincent and I saw people rewrite the same transformers and estimators at clients over and over again. We set out to consolidate these into a package that gives some of the more experimental and niche use cases a place to live, while still offering high code quality and adherence to the scikit-learn API standards.
When we started, most of the pull requests came from either Vincent or me, but as scikit-lego gained some traction, we also started receiving several external pull requests. Something we noticed was that many of those pull requests contained a diff that looked something like this:
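The original diff isn’t reproduced here. As a rough illustration of the kind of change we mean, here is a minimal sketch, assuming a ColumnSelector whose transform selects columns from a pandas DataFrame and whose __init__ wraps a single column name in a list. The exact coercion is our assumption about what such a PR typically looks like, not code taken from one.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin


class ColumnSelector(TransformerMixin, BaseEstimator):
    # Hypothetical "after" state of the proposed change: coercion lives in __init__
    def __init__(self, columns=None):
        # Wrap a single column name in a list so transform always returns a DataFrame
        self.columns = columns if isinstance(columns, list) else [columns]

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        return X[self.columns]


X = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

# Constructed directly, the coercion runs and we get a one-column DataFrame
selected = ColumnSelector("a").fit(X).transform(X)
```

Used like this, the class behaves exactly as intended, which is what makes the change look so reasonable.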
And to be fair, this diff is totally reasonable when you first look at it. Moving input validation or datatype coercion to the class’s __init__ means that we can forget about it in all the other methods of our class. So what is the problem with the proposed change?
It turns out that one of the places where this causes problems is when running a grid search with GridSearchCV. Let’s say we have a pipeline containing an instance of our ColumnSelector, and we want to figure out which combination of columns gives us the highest score on our validation set. We’ll define our grid search like this:
from sklearn.model_selection import GridSearchCV
gs = GridSearchCV(
    pipeline_with_column_selector,
    param_grid={
        'columnselector__columns': [
            'a',
            'b',
            ['a', 'b'],
        ]
    }
)
gs.fit(X, y)
In scikit-learn<=0.22 this will blow up in our face with the following error:
ValueError: Expected 2D array, got 1D array instead
So what happened here? A dive into the GridSearchCV implementation reveals that calling fit on a GridSearchCV object does the following:
1. A ParameterGrid is created, which contains the cartesian product of all possible parameter values.
2. For each combination of parameters, the estimator (in our case the Pipeline) is cloned.
3. set_params is called on the cloned estimator with the new parameters.
Wait up, we didn’t implement a method called set_params on our ColumnSelector, so where did it come from? All estimators in scikit-learn inherit from a single base class called BaseEstimator. This base class implements only two user-facing methods: set_params and get_params. The set_params method takes parameters as keyword arguments and, for each one, replaces the value of the corresponding attribute on the estimator with the desired value.
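To see how a grid key such as columnselector__columns ends up as a plain attribute assignment on our ColumnSelector, here is a minimal sketch of set_params. It is a simplified assumption of how BaseEstimator.set_params behaves, not its actual source: the real method also validates parameter names, and Pipeline treats replacing a whole step specially. The names set_params_sketch and Selector are hypothetical.

```python
from sklearn.base import BaseEstimator
from sklearn.pipeline import Pipeline


def set_params_sketch(estimator, **params):
    """Simplified stand-in for BaseEstimator.set_params.

    Assumption: no validation of parameter names and no special handling
    for replacing whole pipeline steps, both of which scikit-learn does.
    """
    # deep=True also lists nested estimators, e.g. each pipeline step by name
    valid_params = estimator.get_params(deep=True)
    for key, value in params.items():
        # 'columnselector__columns' -> ('columnselector', '__', 'columns')
        name, delim, sub_key = key.partition("__")
        if delim:
            # Nested key: delegate to the sub-estimator with the remaining key
            set_params_sketch(valid_params[name], **{sub_key: value})
        else:
            # Plain key: just overwrite the attribute; __init__ never runs
            setattr(estimator, name, value)
    return estimator


class Selector(BaseEstimator):
    # Hypothetical minimal step, only here to give the pipeline a 'columns' parameter
    def __init__(self, columns=None):
        self.columns = columns

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X[self.columns]


pipe = Pipeline([("columnselector", Selector(["a", "b"]))])
set_params_sketch(pipe, columnselector__columns="a")
print(pipe.named_steps["columnselector"].columns)  # a
```

The interesting part is what the sketch never does: it never calls __init__, so any logic living there never sees the new value.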
This means that for our grid search defined above, the following basically happens for the candidate columns='a'[1]:
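from sklearn.base import clone
# One grid candidate: a fresh copy of the estimator, then the new value is set
cloned = clone(column_selector)
setattr(cloned, 'columns', 'a')  # what set_params(columns='a') boils down to
cloned.fit(X, y)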
Since __init__ runs when the clone is made, but the attribute is overwritten afterwards by the call to setattr, the snippet that casts our columns attribute to a list is completely bypassed by set_params, and our columns attribute is just a single string. For the ColumnSelector implementation, that means the selection done in transform is X['a'] instead of X[['a']], which returns a 1D Series instead of a 2D DataFrame, and that Series is what the next step in the pipeline refuses to accept.
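The failure can be reproduced without a full grid search. The sketch below assumes the same hypothetical ColumnSelector with the list coercion in __init__ (our illustration, not scikit-lego’s code) and performs the clone-then-set_params sequence by hand, using scikit-learn’s real clone and set_params:

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin, clone


class ColumnSelector(TransformerMixin, BaseEstimator):
    def __init__(self, columns=None):
        # Coercion in __init__: only runs when the object is constructed
        self.columns = columns if isinstance(columns, list) else [columns]

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        return X[self.columns]


X = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

# Roughly what GridSearchCV does for the candidate columns='a'
candidate = clone(ColumnSelector(["a", "b"]))
candidate.set_params(columns="a")  # plain setattr; __init__ is not re-run

print(candidate.columns)  # a
out = candidate.fit(X).transform(X)
print(type(out).__name__)  # Series
```

The parameter is stored as the bare string 'a', so transform returns a Series rather than a one-column DataFrame.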
This means that we can’t rely on any logic in an estimator’s __init__ having run for the current values of the parameters. The only way around this is to do all of our parameter checking and modification in the fit method, rather than in __init__.
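A minimal sketch of that approach, assuming we keep the coerced value in a separate fitted attribute named columns_ (our choice, following scikit-learn’s convention of a trailing underscore for attributes set during fit) so that the constructor parameter itself is never modified:

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin, clone


class ColumnSelector(TransformerMixin, BaseEstimator):
    def __init__(self, columns=None):
        # Store the parameter exactly as given; no logic here
        self.columns = columns

    def fit(self, X, y=None):
        # Coercion happens here, on the *current* value of the parameter,
        # and the result goes into a fitted attribute
        self.columns_ = self.columns if isinstance(self.columns, list) else [self.columns]
        return self

    def transform(self, X, y=None):
        return X[self.columns_]


X = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

# Same clone-then-set_params sequence a grid search would perform
candidate = clone(ColumnSelector(["a", "b"]))
candidate.set_params(columns="a")
out = candidate.fit(X).transform(X)
print(type(out).__name__, out.shape)  # DataFrame (3, 1)
```

Leaving self.columns untouched also matters: get_params keeps reporting exactly what the user passed in, so later clones start from the same value.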
Now that we know this, it might be interesting to check which other parts of scikit-learn use set_params. This might just save us a whole bunch of trouble somewhere in the future. It turns out that we already saw a function that relies heavily on set_params and its get_params counterpart: clone.
A bare-bones implementation of clone looks something like this[2]:
def clone(estimator):
    # Get the class of the original estimator and instantiate a new object from it
    new_estimator = estimator.__class__()
    # Copy the parameters of the original onto the new object
    new_estimator.set_params(**estimator.get_params())
    return new_estimator
This implementation means that scikit-learn must somehow be able to take an existing estimator object and figure out which parameters it accepts. To do this, scikit-learn uses Python’s inspect module. A basic version of get_params would look something like this[3]:
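import inspect
from sklearn.linear_model import LinearRegression
def get_params(estimator):
    # Parameter names come from the signature of __init__ (a bound method, so no `self`)
    init_sig = inspect.signature(estimator.__init__)
    params = [p.name for p in init_sig.parameters.values()]
    # Attributes that don't exist silently become None (pre-0.24 behavior)
    return {param: getattr(estimator, param, None) for param in params}
get_params(LinearRegression())
# scikit-learn 0.22: {'fit_intercept': True, 'normalize': False, 'copy_X': True, 'n_jobs': None}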
Looking at this implementation closely reveals an important assumption that should not be broken: the names of the arguments of our __init__ method should be exactly the same as the names of the attributes stored on our instance. Otherwise, the getattr call will return None instead of the actual attribute value.
A quick modification of our ColumnSelector shows this in action:
from sklearn.base import BaseEstimator, TransformerMixin
class ColumnSelector(TransformerMixin, BaseEstimator):
    def __init__(self, columns=None):
        # Deliberately broken: the parameter is stored under a different name
        self.cols = columns
    def fit(self, X, y=None):
        return self
    def transform(self, X, y=None):
        return X[self.columns]
from sklearn import clone
clone(ColumnSelector('a'))
# scikit-learn < 0.24: ColumnSelector(columns=None)
The cloned ColumnSelector reports None for columns: get_params looks for an attribute called columns, doesn’t find one, and falls back to None, so the clone is constructed with columns=None and the value 'a' we passed to the original is silently lost. As of scikit-learn 0.22 this raises the following warning:
FutureWarning: From version 0.24, get_params will raise an AttributeError if a parameter cannot be retrieved as an instance attribute. Previously it would return None.
In earlier versions, this fails silently and simply returns None.
As we have seen in the examples above, there should be no logic, not even input validation, in an estimator’s __init__. The logic belongs where the parameters are used, which is typically in fit. On top of that, every argument accepted by __init__ should be stored, unmodified, as an attribute with exactly the same name on the instance.
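Both rules can be checked mechanically. Below is a minimal sketch of such a check; params_stored_unmodified and the three example classes are hypothetical, not part of scikit-learn or scikit-lego. It passes a unique sentinel object for every __init__ parameter and verifies that the very same object comes back from the attribute of the same name, which is similar in spirit to the identity check scikit-learn’s clone runs on the parameters of the object it rebuilds.

```python
import inspect

_MISSING = object()


def params_stored_unmodified(cls):
    """Hypothetical helper: return the __init__ parameters that are NOT
    stored unmodified under an attribute with the same name."""
    names = [
        p.name
        for p in inspect.signature(cls.__init__).parameters.values()
        if p.name != "self"
        and p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    ]
    # Pass a unique sentinel for every parameter...
    sentinels = {name: object() for name in names}
    instance = cls(**sentinels)
    # ...and require each one to come back as the very same object
    return [name for name in names if getattr(instance, name, _MISSING) is not sentinels[name]]


class Good:
    def __init__(self, columns=None):
        self.columns = columns


class Renamed:
    def __init__(self, columns=None):
        self.cols = columns  # wrong attribute name


class Coerced:
    def __init__(self, columns=None):
        self.columns = columns if isinstance(columns, list) else [columns]  # modified


print(params_stored_unmodified(Good))     # []
print(params_stored_unmodified(Renamed))  # ['columns']
print(params_stored_unmodified(Coerced))  # ['columns']
```

Identity (is) rather than equality is used on purpose: any transformation in __init__, even one that produces an equal-looking value, is flagged.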
It turns out that the specific implementation of get_params and its companion set_params also leads to other unintuitive problems when inheriting from other estimators, but we’ll explore those in another post.
[1] Adapted from: https://github.com/scikit-learn/scikit-learn/blob/b194674c42d54b26137a456c510c5fdba1ba23e0/sklearn/model_selection/_validation.py#L394
[2] Adapted from: https://github.com/scikit-learn/scikit-learn/blob/b194674c4/sklearn/base.py#L39
[3] Adapted from: https://github.com/scikit-learn/scikit-learn/blob/b194674c42d54b26137a456c510c5fdba1ba23e0/sklearn/base.py#L147