Skip to content

Commit

Permalink
CI fix pytest errors
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiningLiu1998 committed Jul 22, 2023
1 parent 5b7ada4 commit f189875
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
7 changes: 3 additions & 4 deletions imbens/ensemble/_under_sampling/balanced_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,6 @@ def __init__(
ccp_alpha=0.0,
max_samples=None,
):

super().__init__(
criterion=criterion,
max_depth=max_depth,
Expand Down Expand Up @@ -392,9 +391,9 @@ def _validate_estimator(self, default=DecisionTreeClassifier()):
)

if self.estimator is not None:
self._estimator = clone(self.estimator)
self.estimator_ = clone(self.estimator)
else:
self._estimator = clone(default)
self.estimator_ = clone(default)

self.sampler_ = RandomUnderSampler(
sampling_strategy=self._sampling_strategy,
Expand All @@ -406,7 +405,7 @@ def _make_sampler_estimator(self, random_state=None):
Warning: This method should be used to properly instantiate new
sub-estimators.
"""
estimator = clone(self._estimator)
estimator = clone(self.estimator_)
estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params})
sampler = clone(self.sampler_)

Expand Down
2 changes: 1 addition & 1 deletion imbens/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def check_target_type(name, estimator_orig):
# should raise warning if the target is continuous (we cannot raise error)
X = np.random.random((20, 2))
y = np.linspace(0, 1, 20)
msg = "Unknown label type: 'continuous'"
msg = "Unknown label type: continuous"
assert_raises_regex(
ValueError,
msg,
Expand Down

0 comments on commit f189875

Please sign in to comment.