How to script the forward pass? #96

Open
francescamanni1989 opened this issue Sep 17, 2021 · 6 comments
Labels
new feature Feature request to work on


@francescamanni1989

Hi everyone,

I am trying to script the ensemble; however, varargs (`*x`) cannot be used with TorchScript:

```
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File ".....\lib\site-packages\torchensemble\soft_gradient_boosting.py", line 390
        "classifier_forward",
    )
    def forward(self, *x):
    ~~ <--- HERE
        output = [estimator(*x) for estimator in self.estimators_]
        output = op.sum_with_multiplicative(output, self.shrinkage_rate)
```

Do you have any idea how to handle it?
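A minimal reproduction, as a sketch: TorchScript rejects any `forward` declared with `*x`, while the same computation with one explicitly typed tensor argument compiles. `VarArgsModel`, `FixedArgsModel` and `try_script` are hypothetical names used only for illustration.

```python
import torch
from torch import nn


class VarArgsModel(nn.Module):
    # Mirrors the failing signature: a variable number of positional inputs.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, *x):
        return self.linear(x[0])


class FixedArgsModel(nn.Module):
    # Same computation, but with one explicitly typed input.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def try_script(module: nn.Module):
    """Return the scripted module, or None if TorchScript rejects it."""
    try:
        return torch.jit.script(module)
    except Exception as err:  # torch.jit.frontend.NotSupportedError for *x
        print(f"{type(err).__name__}: {err}")
        return None


if __name__ == "__main__":
    print(try_script(VarArgsModel()) is None)  # the varargs forward is rejected
    scripted = try_script(FixedArgsModel())
    print(scripted(torch.randn(3, 4)).shape)
```

So the signature itself is the blocker: any scriptable version of the ensemble needs a fixed, typed argument list.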

@xuyxu
Member

xuyxu commented Sep 18, 2021

Thanks for reporting! @francescamanni1989

Could you provide a code snippet that reproduces the runtime error?

@francescamanni1989
Author

Hi,

The relevant code in gradient_boosting.py is the varargs (`*x`) signature of the forward pass used for boosting:
```python
def forward(self, *x):
    output = [estimator(*x) for estimator in self.estimators_]
    output = op.sum_with_multiplicative(output, self.shrinkage_rate)
    output = F.softmax(output, dim=1)
    return output
```
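One possible scriptable rewrite of this forward pass, sketched under a few assumptions: each base estimator takes a single tensor and is itself scriptable, and `op.sum_with_multiplicative(output, factor)` computes `factor * sum(output)` (an assumption about torchensemble's internals, not checked here). `ScriptableBoostingClassifier` is a hypothetical wrapper, not part of torchensemble. It keeps a fixed signature and replaces the list comprehension and the Python `sum` with an explicit `for` loop plus `torch.stack`.

```python
from typing import List

import torch
import torch.nn.functional as F
from torch import nn


class ScriptableBoostingClassifier(nn.Module):
    """Hypothetical wrapper holding already-fitted base estimators."""

    def __init__(self, estimators: List[nn.Module], shrinkage_rate: float):
        super().__init__()
        self.estimators_ = nn.ModuleList(estimators)
        self.shrinkage_rate = shrinkage_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Empty lists need an explicit element type in TorchScript.
        outputs: List[torch.Tensor] = []
        # TorchScript unrolls for loops over an nn.ModuleList.
        for estimator in self.estimators_:
            outputs.append(estimator(x))
        # Scriptable replacement for `shrinkage_rate * sum(outputs)`.
        output = torch.stack(outputs, dim=0).sum(dim=0) * self.shrinkage_rate
        return F.softmax(output, dim=1)


if __name__ == "__main__":
    estimators = [nn.Linear(4, 3) for _ in range(5)]
    wrapper = ScriptableBoostingClassifier(estimators, shrinkage_rate=0.9)
    scripted = torch.jit.script(wrapper)
    print(scripted(torch.randn(8, 4)).shape)
```

With a fitted ensemble, the wrapper could presumably be built from `list(model_ensemble.estimators_)` and `model_ensemble.shrinkage_rate`, the two attributes visible in the traceback above, and then passed to `torch.jit.script`. This is untested against torchensemble itself.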

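The error occurs when I try to script the model:

```python
import torch

model = model_ensemble
scripted_model = torch.jit.script(model)
```

where `model_ensemble` could be, for example:

```python
from torchensemble import GradientBoostingClassifier

model_ensemble = GradientBoostingClassifier(
    estimator=MLP,  # MLP: a user-defined nn.Module class
    n_estimators=10,
    cuda=False,
    shrinkage_rate=0.9,
)
```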
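If only a TorchScript artifact is needed, `torch.jit.trace` is a different technique that may sidestep the problem: it runs `forward` once on example inputs and records the tensor operations, so the `*x` signature and the Python `sum` are never parsed. The sketch below uses a hypothetical `VarArgsEnsemble` that mimics the snippet above rather than torchensemble's actual class, and assumes the forward pass has no data-dependent control flow (tracing freezes whatever path the example input takes).

```python
import torch
import torch.nn.functional as F
from torch import nn


class VarArgsEnsemble(nn.Module):
    """Hypothetical stand-in whose forward keeps the *x signature."""

    def __init__(self, n_estimators: int = 3, shrinkage_rate: float = 0.9):
        super().__init__()
        self.estimators_ = nn.ModuleList([nn.Linear(4, 3) for _ in range(n_estimators)])
        self.shrinkage_rate = shrinkage_rate

    def forward(self, *x):
        output = [estimator(*x) for estimator in self.estimators_]
        output = self.shrinkage_rate * sum(output)
        return F.softmax(output, dim=1)


model = VarArgsEnsemble().eval()
example = torch.randn(2, 4)
# Tracing records the ops executed for this example; the Python source
# (including *x and sum) is never compiled.
traced_model = torch.jit.trace(model, (example,))
print(traced_model(torch.randn(5, 4)).shape)
```

For the real model, the analogous call would be `torch.jit.trace(model_ensemble, (example_input,))` on a fitted ensemble; this is untested against torchensemble itself.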

@xuyxu
Member

xuyxu commented Sep 20, 2021

It looks like the package does not support TorchScript well for now. I will take a careful look when I get a moment, thanks!

@francescamanni1989
Author

Exactly!
Thank you

@xuyxu xuyxu added the new feature Feature request to work on label Sep 20, 2021
@francescamanni1989
Author

Also, the `sum` function is not scriptable, but this could be bypassed using `@torch.jit.ignore`.
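A sketch of that workaround, assuming the non-scriptable piece is a sum over a list of tensors: the helper is marked with `@torch.jit.ignore`, so the compiler leaves it as a Python call and only checks its type annotations. The names here (`sum_with_multiplicative`, `IgnoredSumEnsemble`) are illustrative stand-ins, not torchensemble's code.

```python
import functools
import operator
from typing import List

import torch
from torch import nn


@torch.jit.ignore
def sum_with_multiplicative(outputs: List[torch.Tensor], factor: float) -> torch.Tensor:
    # Hypothetical stand-in; executed by the Python interpreter, not compiled.
    return factor * functools.reduce(operator.add, outputs)


class IgnoredSumEnsemble(nn.Module):
    def __init__(self, n_estimators: int = 3, shrinkage_rate: float = 0.9):
        super().__init__()
        self.estimators_ = nn.ModuleList([nn.Linear(4, 3) for _ in range(n_estimators)])
        self.shrinkage_rate = shrinkage_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs: List[torch.Tensor] = []
        for estimator in self.estimators_:
            outputs.append(estimator(x))
        # The compiled graph calls back into Python at this point.
        return sum_with_multiplicative(outputs, self.shrinkage_rate)


if __name__ == "__main__":
    scripted = torch.jit.script(IgnoredSumEnsemble())
    print(scripted(torch.randn(6, 4)).shape)
```

One caveat from the PyTorch documentation: models with ignored functions cannot be exported, and `@torch.jit.unused` is suggested instead, which raises an exception if the function is actually called. For a model that must be saved with `torch.jit.save`, replacing the sum with a scriptable expression such as `torch.stack(outputs).sum(dim=0)` is probably the safer route.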

@francescamanni1989
Author

My suggestion for the indexed variable is to use a for loop instead.
