AxClient.get_best_trial() produces the wrong best trial index with use_model_predictions=True. #1629

Open
leandrobbraga opened this issue May 24, 2023 · 8 comments
Labels
bug Something isn't working

Comments

@leandrobbraga
Contributor

leandrobbraga commented May 24, 2023

I was working with the Service API on a constrained single-objective problem, and get_best_trial() returns the parameters of a trial, but paired with the wrong trial index.

This is the dataframe containing all the trials:
[Image: dataframe of all trials]

The output of get_best_trial()
[Image: output of get_best_trial()]

The second image says that these parameters belong to the 29th trial, which is not true; they actually come from the 4th trial.

When I pass use_model_predictions=False, it returns the correct index (10).

@mpolson64
Contributor

Thanks for catching this; this is indeed a bug on our end. @saitcakmak could you take a look at fixing this?

@leandrobbraga
Contributor Author

If you guys are ok with it, I could try to solve this issue.

@leandrobbraga
Contributor Author

I did some work today and managed to write a failing test for it:

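    from unittest.mock import patch

    from ax.modelbridge.cross_validation import AssessModelFitResult, assess_model_fit
    from ax.service.ax_client import AxClient
    from ax.service.utils.best_point import (
        get_best_parameters_from_model_predictions_with_trial_index,
    )
    from ax.service.utils.instantiation import ObjectiveProperties
    from ax.utils.common.testutils import TestCase


    class TestGetBestTrialIndex(TestCase):
        # Patch assess_model_fit in the module where best_point looks it up,
        # reporting a good fit for the objective "y" so that get_best_trial()
        # takes the model-prediction code path.
        @patch(
            f"{get_best_parameters_from_model_predictions_with_trial_index.__module__}"
            + ".assess_model_fit",
            wraps=assess_model_fit,
            return_value=AssessModelFitResult(
                good_fit_metrics_to_fisher_score={"y": 1},
                bad_fit_metrics_to_fisher_score={},
            ),
        )
        def test_get_best_point_with_model_prediction(
            self,
            mock_assess_model_fit,
        ) -> None:
            ax_client = AxClient()
            ax_client.create_experiment(
                name="test_experiment",
                parameters=[
                    {
                        "name": "x",
                        "type": "range",
                        "bounds": [1.0, 10.0],
                    },
                ],
                objectives={"y": ObjectiveProperties(minimize=True)},
                is_test=True,
                choose_generation_strategy_kwargs={"num_initialization_trials": 4},
            )

            # Trial 0 gets the lowest objective value, so it should be the best trial.
            params, idx = ax_client.get_next_trial()
            ax_client.complete_trial(idx, raw_data={"y": 0.0})

            # Trials 1-4 get strictly worse values (y = 1, ..., 4). With 4
            # initialization trials, trial 4 is the first one generated by the model.
            for i in range(1, 5):
                _, trial_index = ax_client.get_next_trial()
                ax_client.complete_trial(trial_index, raw_data={"y": float(i)})

            best_index, best_params, _ = ax_client.get_best_trial()
            self.assertEqual(best_index, idx)
            self.assertEqual(best_params, params)
            mock_assess_model_fit.assert_called()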

I know that the issue is in the get_best_parameters_from_model_predictions_with_trial_index function in ax/service/utils/best_point.py. If I understand correctly it's using the RunGenerator index, not the actual best param index.

@saitcakmak
Contributor

If I understand correctly it's using the RunGenerator index, not the actual best param index.

Yes, that's exactly the issue. I had discussed this with Miles, but we forgot to update it here. It returns the index of the last GeneratorRun that has a model that can be used to evaluate the arms and find the best-performing one. That index has nothing to do with the actual predicted best arm.
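Since the parameters returned with use_model_predictions=True belong to the predicted best arm and only the index is wrong, the index can in principle be recovered by looking up which trial evaluated an arm with exactly those parameters. Below is a minimal sketch of that lookup; it is not the fix Ax adopted. It assumes trials are given as a mapping from trial index to an object with an arms list whose elements expose a parameters dict, as Ax trials and arms do. The helper name trial_index_for_parameters is hypothetical.

```python
from types import SimpleNamespace
from typing import Any, Dict, Mapping, Optional


def trial_index_for_parameters(
    trials: Mapping[int, Any], parameters: Dict[str, Any]
) -> Optional[int]:
    """Return the lowest index of a trial containing an arm with exactly
    these parameters, or None if no trial matches.

    Hypothetical helper: assumes each trial exposes `.arms` and each arm
    exposes `.parameters`.
    """
    for trial_index in sorted(trials):
        for arm in trials[trial_index].arms:
            if arm.parameters == parameters:
                return trial_index
    return None


# Usage with stand-in objects (SimpleNamespace mimics the trial/arm attributes).
fake_trials = {
    i: SimpleNamespace(arms=[SimpleNamespace(parameters={"x": x})])
    for i, x in enumerate([1.0, 5.5, 3.2, 8.0, 2.4])
}
print(trial_index_for_parameters(fake_trials, {"x": 3.2}))  # prints 2

# With a real AxClient (not executed here), the idea would be:
# _, best_params, _ = ax_client.get_best_trial()
# trial_index = trial_index_for_parameters(ax_client.experiment.trials, best_params)
```

If several trials evaluated the same arm, this returns the lowest matching trial index; whether that is the right choice depends on the use case.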

@noppelmax

Is this issue still "in progress"? Is there a workaround to get the best trial index?

@saitcakmak
Contributor

Hi @noppelmax. I don't think anyone has been working on this. The issue is with the model-prediction-based best trial index. If you call AxClient.get_best_trial(use_model_predictions=False), it will use the raw observations to find the best point and return the correct trial index.
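For intuition on why the raw-observation path gets the index right: it compares observed values directly rather than querying a model, so the index it reports is simply that of the winning trial. The sketch below illustrates that idea for a single objective with upper-bound outcome constraints, matching the constrained single-objective setting in the original report. It is a simplified stand-in, not Ax's implementation, which is more involved and works on Ax's own data and optimization-config objects. The function best_feasible_trial, the constraint metric c, and the data are hypothetical.

```python
from typing import Dict, Optional


def best_feasible_trial(
    observations: Dict[int, Dict[str, float]],
    objective: str,
    minimize: bool,
    upper_bounds: Dict[str, float],
) -> Optional[int]:
    """Return the index of the trial with the best observed objective value
    among trials satisfying every `metric <= bound` constraint.

    Assumes every trial has an observed value for the objective and for each
    constrained metric; returns None if no trial is feasible.
    """
    feasible = [
        trial_index
        for trial_index, obs in observations.items()
        if all(obs[metric] <= bound for metric, bound in upper_bounds.items())
    ]
    if not feasible:
        return None
    # Flip the sign when maximizing so that min() always picks the best trial.
    sign = 1.0 if minimize else -1.0
    return min(feasible, key=lambda idx: sign * observations[idx][objective])


observations = {
    0: {"y": 1.0, "c": 0.2},
    1: {"y": 0.5, "c": 1.5},  # best objective, but violates c <= 1.0
    2: {"y": 0.8, "c": 0.9},
}
print(best_feasible_trial(observations, "y", minimize=True, upper_bounds={"c": 1.0}))  # prints 2
```

Ties are broken by the order in which trials appear in observations.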

@saitcakmak saitcakmak removed their assignment Jul 10, 2024
@lena-kashtelyan
Contributor

Looks like @saitcakmak provided a workaround here, and this has been inactive for quite some time. @leandrobbraga, please reopen if you'd like to continue the discussion! We likely won't see further activity on a closed issue.

@saitcakmak
Contributor

@lena-kashtelyan This is still a bug with use_model_predictions=True, which is the default. Let's keep it open for tracking. If we get around to doing a clean up / rewrite of best point utils, this should be fixed in the process as well.

@saitcakmak saitcakmak reopened this Jul 31, 2024
@saitcakmak saitcakmak changed the title AxClient.get_best_trial() with wrong index AxClient.get_best_trial() produces the wrong best trial index with use_model_predictions=True. Jul 31, 2024

5 participants