Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incompatible types in assignment #344

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

daikikatsuragawa
Copy link
Contributor

Signed-off-by: Daiki Katsuragawa [email protected]

This pull request is one of the steps of #331 . Most of the "Incompatible types in assignment" errors detected by mypy are fixed. Basically, the policy is to modify the process without major changes. Therefore, there are also temporary responses (e.g., defining type hints Any). These are excluded from the modifications in this pull request and will be revised later.

In conclusion, this pull request will have the following effects.

  • Deletion of "Incompatible types in assignment" processes
  • Prevention of unexpected exceptions
  • Improved readability with type hints

Signed-off-by: Daiki Katsuragawa <[email protected]>
@@ -343,12 +344,16 @@ def _predict_fn_custom(self, input_instance, desired_class):

def compute_yloss(self, cfs, desired_range, desired_class):
"""Computes the first part (y-loss) of the loss function."""
yloss = 0.0
yloss: Any = 0.0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, multiple types of values are stored.

if opt_method == "adam":
self.optimizer = torch.optim.Adam(self.cfs, lr=learning_rate)
elif opt_method == "rmsprop":
self.optimizer = torch.optim.RMSprop(self.cfs, lr=learning_rate)

def compute_yloss(self):
"""Computes the first part (y-loss) of the loss function."""
yloss = 0.0
yloss: Any = 0.0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, multiple types of values are stored.

@@ -307,7 +310,7 @@ def compute_diversity_loss(self):
def compute_regularization_loss(self):
"""Adds a linear equality constraints to the loss functions -
to ensure all levels of a categorical variable sums to one"""
regularization_loss = 0.0
regularization_loss: Any = 0.0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, multiple types of values are stored.

@@ -425,7 +428,7 @@ def find_counterfactuals(self, query_instance, desired_class, optimizer, learnin
test_pred = self.predict_fn(torch.tensor(query_instance).float())[0]
if desired_class == "opposite":
desired_class = 1.0 - np.round(test_pred)
self.target_cf_class = torch.tensor(desired_class).float()
self.target_cf_class: Any = torch.tensor(desired_class).float()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, multiple types of values are stored.

@@ -341,8 +341,7 @@ def initialize_CFs(self, query_instance, init_near_query_instance=False):
one_init.append(np.random.uniform(self.minx[0][i], self.maxx[0][i]))
else:
one_init.append(query_instance[0][i])
one_init = np.array([one_init], dtype=np.float32)
self.cfs[n].assign(one_init)
self.cfs[n].assign(np.array([one_init], dtype=np.float32))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Substitute directly. This prevents one_init from being of more than one type.

Signed-off-by: Daiki Katsuragawa <[email protected]>
train_dataset = torch.tensor(self.vae_train_feat).float()
train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
train_dataset = torch.utils.data.DataLoader(
torch.tensor(self.vae_train_feat).float(), # type: ignore
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Set type: ignore because the error depends on the implementation of torch.utils.data.DataLoader.

train_dataset = torch.tensor(self.vae_train_feat).float()
train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
train_dataset = torch.utils.data.DataLoader(
torch.tensor(self.vae_train_feat).float(), # type: ignore
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Set type: ignore because the error depends on the implementation of torch.utils.data.DataLoader.

@codecov-commenter
Copy link

codecov-commenter commented Nov 16, 2022

Codecov Report

Base: 72.04% // Head: 72.12% // Increases project coverage by +0.07% 🎉

Coverage data is based on head (61b7d79) compared to base (8f17cfd).
Patch coverage: 78.12% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #344      +/-   ##
==========================================
+ Coverage   72.04%   72.12%   +0.07%     
==========================================
  Files          26       26              
  Lines        3595     3598       +3     
==========================================
+ Hits         2590     2595       +5     
+ Misses       1005     1003       -2     
Flag Coverage Δ
unittests 72.12% <78.12%> (+0.07%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
dice_ml/explainer_interfaces/feasible_base_vae.py 0.00% <0.00%> (ø)
...e_ml/explainer_interfaces/feasible_model_approx.py 0.00% <0.00%> (ø)
dice_ml/explainer_interfaces/explainer_base.py 89.46% <75.00%> (-0.21%) ⬇️
dice_ml/explainer_interfaces/dice_genetic.py 94.89% <87.50%> (-0.25%) ⬇️
dice_ml/explainer_interfaces/dice_KD.py 91.72% <100.00%> (ø)
dice_ml/explainer_interfaces/dice_pytorch.py 85.84% <100.00%> (+0.12%) ⬆️
dice_ml/explainer_interfaces/dice_random.py 93.84% <100.00%> (+0.66%) ⬆️
dice_ml/explainer_interfaces/dice_tensorflow2.py 86.23% <100.00%> (-0.05%) ⬇️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

Signed-off-by: Daiki Katsuragawa <[email protected]>
@amit-sharma
Copy link
Collaborator

@daikikatsuragawa thanks for this contribution. Can you share a few lines on the motivation for this PR (e.g., why change cfs from [] to pd.DataFrame()?) and are there any bugs that this PR fixes?

@daikikatsuragawa
Copy link
Contributor Author

@amit-sharma

Comment on each. After completion, send a message with a Mention.

@daikikatsuragawa
Copy link
Contributor Author

Operational error for closed (and reopened). My apologies.

Copy link
Collaborator

@gaugup gaugup left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we enable mypy in dice-ml linting build and fix all the mypy errors?

@daikikatsuragawa
Copy link
Contributor Author

@amit-sharma

Should we enable mypy in dice-ml linting build and fix all the mypy errors?

Yes. However, some mypy errors remain at present. It would be good to add them to Lint after they have been resolved.

@daikikatsuragawa
Copy link
Contributor Author

@amit-sharma

@daikikatsuragawa thanks for this contribution. Can you share a few lines on the motivation for this PR (e.g., why change cfs from [] to pd.DataFrame()?)

During the function, the process immediately following was checked and it was determined that the type of cfs is pd.DataFrame. If this is not appropriate, please let me know.

cfs = self.dataset_with_predictions.iloc[indices].copy()
cfs['distance'] = distances
cfs = self.do_sparsity_check(cfs, query_instance, sparsity_weight)
cfs = cfs.drop(self.data_interface.outcome_name, axis=1)

and are there any bugs that this PR fixes?

Probably this case (cfs = pd.DataFrame()) only. The others are fixes related to type hinting and refactoring and are not expected to affect behaviour. Unit tests also passed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Development

Successfully merging this pull request may close these issues.

4 participants