Skip to content

Commit

Permalink
MNT compatability with sklearn 1.3.0 (close #31)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiningLiu1998 committed Jul 22, 2023
1 parent 9c2e612 commit 982d694
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions imbens/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from sklearn import pipeline
from sklearn.base import clone
from sklearn.utils import _print_elapsed_time
from sklearn.utils.metaestimators import if_delegate_has_method
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import check_memory

__all__ = ["Pipeline", "make_pipeline"]
Expand Down Expand Up @@ -183,7 +183,6 @@ def _iter(self, with_final=True, filter_passthrough=True, filter_resample=True):
# Estimator interface

def _fit(self, X, y=None, sample_weight=None, **fit_params_steps):

self.steps = list(self.steps)
self._validate_steps()
# Setup the memory
Expand All @@ -192,7 +191,7 @@ def _fit(self, X, y=None, sample_weight=None, **fit_params_steps):
fit_transform_one_cached = memory.cache(pipeline._fit_transform_one)
fit_resample_one_cached = memory.cache(_fit_resample_one)

for (step_idx, name, transformer) in self._iter(
for step_idx, name, transformer in self._iter(
with_final=False, filter_passthrough=False, filter_resample=False
):
if transformer is None or transformer == "passthrough":
Expand Down Expand Up @@ -373,7 +372,7 @@ def fit_resample(self, X, y=None, sample_weight=None, **fit_params):
if hasattr(last_step, "fit_resample"):
return last_step.fit_resample(Xt, yt, **fit_params_last_step)

@if_delegate_has_method(delegate="_final_estimator")
@available_if(pipeline._final_estimator_has("fit_predict"))
def fit_predict(self, X, y=None, **fit_params):
"""Apply `fit_predict` of last step in pipeline after transforms.
Expand Down Expand Up @@ -414,7 +413,6 @@ def _fit_resample_one(
sampler, X, y, sample_weight=None, message_clsname="", message=None, **fit_params
):
with _print_elapsed_time(message_clsname, message):

out = sampler.fit_resample(X, y, sample_weight=sample_weight, **fit_params)

if sample_weight is None:
Expand Down

0 comments on commit 982d694

Please sign in to comment.