From f189875c424fb08880962914908d84cb9e9ea8c4 Mon Sep 17 00:00:00 2001 From: ZhiningLiu1998 Date: Fri, 21 Jul 2023 22:47:10 -0700 Subject: [PATCH] CI fix pytest errors --- imbens/ensemble/_under_sampling/balanced_random_forest.py | 7 +++---- imbens/utils/estimator_checks.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/imbens/ensemble/_under_sampling/balanced_random_forest.py b/imbens/ensemble/_under_sampling/balanced_random_forest.py index 65993fb..53df214 100644 --- a/imbens/ensemble/_under_sampling/balanced_random_forest.py +++ b/imbens/ensemble/_under_sampling/balanced_random_forest.py @@ -348,7 +348,6 @@ def __init__( ccp_alpha=0.0, max_samples=None, ): - super().__init__( criterion=criterion, max_depth=max_depth, @@ -392,9 +391,9 @@ def _validate_estimator(self, default=DecisionTreeClassifier()): ) if self.estimator is not None: - self._estimator = clone(self.estimator) + self.estimator_ = clone(self.estimator) else: - self._estimator = clone(default) + self.estimator_ = clone(default) self.sampler_ = RandomUnderSampler( sampling_strategy=self._sampling_strategy, @@ -406,7 +405,7 @@ def _make_sampler_estimator(self, random_state=None): Warning: This method should be used to properly instantiate new sub-estimators. """ - estimator = clone(self._estimator) + estimator = clone(self.estimator_) estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params}) sampler = clone(self.sampler_) diff --git a/imbens/utils/estimator_checks.py b/imbens/utils/estimator_checks.py index a6e718f..5ee890f 100644 --- a/imbens/utils/estimator_checks.py +++ b/imbens/utils/estimator_checks.py @@ -131,7 +131,7 @@ def check_target_type(name, estimator_orig): # should raise warning if the target is continuous (we cannot raise error) X = np.random.random((20, 2)) y = np.linspace(0, 1, 20) - msg = "Unknown label type: 'continuous'" + msg = "Unknown label type: continuous" assert_raises_regex( ValueError, msg,