From 982d694b2bcc650f7846dd03ee0b642d2c9e0f26 Mon Sep 17 00:00:00 2001 From: ZhiningLiu1998 Date: Fri, 21 Jul 2023 18:08:41 -0700 Subject: [PATCH] MNT compatability with sklearn 1.3.0 (close #31) --- imbens/pipeline.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/imbens/pipeline.py b/imbens/pipeline.py index fd20aaa..c8accfb 100644 --- a/imbens/pipeline.py +++ b/imbens/pipeline.py @@ -18,7 +18,7 @@ from sklearn import pipeline from sklearn.base import clone from sklearn.utils import _print_elapsed_time -from sklearn.utils.metaestimators import if_delegate_has_method +from sklearn.utils.metaestimators import available_if from sklearn.utils.validation import check_memory __all__ = ["Pipeline", "make_pipeline"] @@ -183,7 +183,6 @@ def _iter(self, with_final=True, filter_passthrough=True, filter_resample=True): # Estimator interface def _fit(self, X, y=None, sample_weight=None, **fit_params_steps): - self.steps = list(self.steps) self._validate_steps() # Setup the memory @@ -192,7 +191,7 @@ def _fit(self, X, y=None, sample_weight=None, **fit_params_steps): fit_transform_one_cached = memory.cache(pipeline._fit_transform_one) fit_resample_one_cached = memory.cache(_fit_resample_one) - for (step_idx, name, transformer) in self._iter( + for step_idx, name, transformer in self._iter( with_final=False, filter_passthrough=False, filter_resample=False ): if transformer is None or transformer == "passthrough": @@ -373,7 +372,7 @@ def fit_resample(self, X, y=None, sample_weight=None, **fit_params): if hasattr(last_step, "fit_resample"): return last_step.fit_resample(Xt, yt, **fit_params_last_step) - @if_delegate_has_method(delegate="_final_estimator") + @available_if(pipeline._final_estimator_has("fit_predict")) def fit_predict(self, X, y=None, **fit_params): """Apply `fit_predict` of last step in pipeline after transforms. @@ -414,7 +413,6 @@ def _fit_resample_one( sampler, X, y, sample_weight=None, message_clsname="", message=None, **fit_params ): with _print_elapsed_time(message_clsname, message): - out = sampler.fit_resample(X, y, sample_weight=sample_weight, **fit_params) if sample_weight is None: