Skip to content

Commit

Permalink
AIChatTabHelper refine and test retry logic
Browse files Browse the repository at this point in the history
  • Loading branch information
petemill committed Sep 20, 2024
1 parent 9a62867 commit ee3279f
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "brave/browser/ui/sidebar/sidebar_controller.h"
#include "brave/browser/ui/sidebar/sidebar_model.h"
#include "brave/components/ai_chat/content/browser/ai_chat_tab_helper.h"
#include "brave/components/ai_chat/core/browser/ai_chat_service.h"
#include "brave/components/ai_chat/core/browser/engine/engine_consumer.h"
#include "brave/components/ai_chat/core/browser/engine/mock_engine_consumer.h"
#include "brave/components/ai_chat/core/browser/engine/mock_remote_completion_client.h"
Expand Down
27 changes: 3 additions & 24 deletions browser/ai_chat/ai_chat_ui_browsertest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,28 +261,6 @@ IN_PROC_BROWSER_TEST_F(AIChatUIBrowserTest, PrintPreview) {
EXPECT_FALSE(HasPendingGetContentRequest());
}

IN_PROC_BROWSER_TEST_F(AIChatUIBrowserTest, PrintPreviewRequests) {
NavigateURL(https_server_.GetURL("docs.google.com", "/long_canvas.html"),
false);
FetchPageContent(FROM_HERE, "", false);
EXPECT_TRUE(HasPendingGetContentRequest());

GetActiveWebContents()->GetController().Reload(content::ReloadType::NORMAL,
false);
content::WaitForLoadStop(GetActiveWebContents());
// The request should be cleared after reload.
EXPECT_FALSE(HasPendingGetContentRequest());

// Try page loaded scenario
NavigateURL(https_server_.GetURL("docs.google.com", "/canvas.html"));
FetchPageContent(FROM_HERE, "", false);
EXPECT_TRUE(HasPendingGetContentRequest());

NavigateURL(https_server_.GetURL("a.com", "/canvas.html"), false);
// The request should be cleared after navigation.
EXPECT_FALSE(HasPendingGetContentRequest());
}

#if BUILDFLAG(ENABLE_TEXT_RECOGNITION)
IN_PROC_BROWSER_TEST_F(AIChatUIBrowserTest, PrintPreviewPagesLimit) {
NavigateURL(
Expand Down Expand Up @@ -344,8 +322,9 @@ IN_PROC_BROWSER_TEST_F(AIChatUIBrowserTest, MAYBE_PrintPreviewFallback) {
FetchPageContent(FROM_HERE, "this is the way");

// Does not fall back when there is regular DOM content
NavigateURL(https_server_.GetURL(
"a.com", "/long_canvas_with_dom_content.html", false));
NavigateURL(
https_server_.GetURL("a.com", "/long_canvas_with_dom_content.html"),
false);
FetchPageContent(FROM_HERE, "Or maybe not.");
}
#endif // BUILDFLAG(ENABLE_PRINT_PREVIEW)
Expand Down
145 changes: 142 additions & 3 deletions browser/ui/ai_chat/ai_chat_tab_helper_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,18 @@ class AIChatTabHelperUnitTest : public content::RenderViewHostTestHarness,
return std::make_unique<TestingProfile>();
}

// TODO(petemill): param for simulating page load
void NavigateTo(GURL url, bool is_same_page = false, std::string title = "") {
void NavigateTo(GURL url,
bool keep_loading = false,
bool is_same_page = false,
std::string title = "") {
if (title.empty()) {
title = base::StrCat({url.host(), url.path()});
}
std::unique_ptr<content::NavigationSimulator> simulator =
content::NavigationSimulator::CreateRendererInitiated(url, main_rfh());

simulator->SetKeepLoading(keep_loading);

if (is_same_page) {
simulator->CommitSameDocument();
} else {
Expand All @@ -125,6 +130,10 @@ class AIChatTabHelperUnitTest : public content::RenderViewHostTestHarness,
title);
}

void SimulateLoadFinished() {
helper_->DidFinishLoad(main_rfh(), helper_->GetPageURL());
}

void GetPageContent(ConversationHandler::GetPageContentCallback callback,
std::string_view invalidation_token) {
helper_->GetPageContent(std::move(callback), invalidation_token);
Expand All @@ -134,6 +143,10 @@ class AIChatTabHelperUnitTest : public content::RenderViewHostTestHarness,
helper_->TitleWasSet(entry);
}

content::TestWebContents* test_web_contents() {
return static_cast<content::TestWebContents*>(web_contents());
}

protected:
NiceMock<favicon::MockFaviconService> favicon_service_;
std::unique_ptr<NiceMock<MockAssociatedContentObserver>> observer_;
Expand Down Expand Up @@ -184,7 +197,7 @@ TEST_P(AIChatTabHelperUnitTest, OnNewPage) {
// Same-document navigation should not call OnNewPage if page title is the
// same
EXPECT_CALL(*observer_, OnAssociatedContentNavigated).Times(0);
NavigateTo(GURL("https://www.brave.com/2/3"), true, "www.brave.com/2");
NavigateTo(GURL("https://www.brave.com/2/3"), false, true, "www.brave.com/2");
testing::Mock::VerifyAndClearExpectations(&observer_);
// ...unless the page title changes before the next navigation.
EXPECT_CALL(*observer_, OnAssociatedContentNavigated)
Expand Down Expand Up @@ -333,4 +346,130 @@ TEST_P(AIChatTabHelperUnitTest,
GetPageContent(callback.Get(), "");
}

TEST_P(AIChatTabHelperUnitTest,
GetPageContent_PrintPreviewTriggeringUrlWaitForLoad) {
// A url that does by itself trigger print preview extraction.
NavigateTo(GURL("https://docs.google.com"), /*keep_loading=*/true);
base::MockCallback<ConversationHandler::GetPageContentCallback> callback;
// Not epecting callback to be run until page load.
EXPECT_CALL(callback, Run).Times(0);
if (is_print_preview_supported_) {
// Nothing should be called until page load
EXPECT_CALL(*page_content_fetcher_, FetchPageContent).Times(0);
EXPECT_CALL(*print_preview_extractor_, Extract).Times(0);
GetPageContent(callback.Get(), "");
testing::Mock::VerifyAndClearExpectations(&page_content_fetcher_);
testing::Mock::VerifyAndClearExpectations(&print_preview_extractor_);
testing::Mock::VerifyAndClearExpectations(&callback);

// Simulate page load should trigger check again and, even with
// empty content, callback should run.
EXPECT_CALL(callback, Run("", false, ""));
EXPECT_CALL(*page_content_fetcher_, FetchPageContent).Times(0);
EXPECT_CALL(*print_preview_extractor_, Extract)
.WillOnce(base::test::RunOnceCallback<1>(""));
SimulateLoadFinished();

testing::Mock::VerifyAndClearExpectations(&page_content_fetcher_);
testing::Mock::VerifyAndClearExpectations(&print_preview_extractor_);
testing::Mock::VerifyAndClearExpectations(&callback);
} else {
// FetchPageContent will not wait for page load. Let's test that the
// re-try will wait for page load.
EXPECT_CALL(*page_content_fetcher_, FetchPageContent)
.WillRepeatedly(
base::test::RunOnceCallbackRepeatedly<1>("", false, ""));
GetPageContent(callback.Get(), "");
testing::Mock::VerifyAndClearExpectations(&callback);

// Simulate page load should trigger check again and, even with
// empty content, callback should run.
EXPECT_CALL(callback, Run("", false, ""));
SimulateLoadFinished();

testing::Mock::VerifyAndClearExpectations(&page_content_fetcher_);
testing::Mock::VerifyAndClearExpectations(&callback);
}
}

TEST_P(AIChatTabHelperUnitTest, GetPageContent_RetryAfterLoad) {
// A url that does not by itself trigger print preview extraction.
NavigateTo(GURL("https://www.example.com"), /*keep_loading=*/true);
base::MockCallback<ConversationHandler::GetPageContentCallback> callback;

// FetchPageContent will not wait for page load. Let's test that the
// re-try will wait for page load.
EXPECT_CALL(*page_content_fetcher_, FetchPageContent)
.WillOnce(base::test::RunOnceCallback<1>("", false, ""));
if (is_print_preview_supported_) {
// Doesn't initialy ask for print preview extraction
EXPECT_CALL(*print_preview_extractor_, Extract).Times(0);
}
EXPECT_CALL(callback, Run).Times(0);
GetPageContent(callback.Get(), "");
testing::Mock::VerifyAndClearExpectations(&page_content_fetcher_);
if (is_print_preview_supported_) {
testing::Mock::VerifyAndClearExpectations(&print_preview_extractor_);
}
testing::Mock::VerifyAndClearExpectations(&callback);

// Simulate page load should trigger check again and, even with
// empty content, callback should run.
const std::string expected_content = "retried content";
if (is_print_preview_supported_) {
// First it will try to see if after page load there is real content
EXPECT_CALL(*page_content_fetcher_, FetchPageContent)
.WillOnce(base::test::RunOnceCallback<1>("", false, ""));
// And if it has no content, it will finally try print preview extraction
EXPECT_CALL(*print_preview_extractor_, Extract)
.WillOnce(base::test::RunOnceCallback<1>(expected_content));
} else {
EXPECT_CALL(*page_content_fetcher_, FetchPageContent)
.WillOnce(base::test::RunOnceCallback<1>(expected_content, false, ""));
}
EXPECT_CALL(callback, Run(expected_content, false, ""));
SimulateLoadFinished();

testing::Mock::VerifyAndClearExpectations(&page_content_fetcher_);
if (is_print_preview_supported_) {
testing::Mock::VerifyAndClearExpectations(&print_preview_extractor_);
}
testing::Mock::VerifyAndClearExpectations(&callback);
}

TEST_P(AIChatTabHelperUnitTest,
GetPageContent_ClearPendingCallbackOnNavigation) {
const GURL initial_url =
GURL(is_print_preview_supported_ ? "https://docs.google.com"
: "https://www.example.com");
for (const bool is_same_document : {false, true}) {
SCOPED_TRACE(testing::Message() << "Same document: " << is_same_document);
NavigateTo(initial_url,
/*keep_loading=*/true);
base::MockCallback<ConversationHandler::GetPageContentCallback> callback;
EXPECT_CALL(callback, Run).Times(0);
if (!is_print_preview_supported_) {
EXPECT_CALL(*page_content_fetcher_, FetchPageContent)
.WillOnce(base::test::RunOnceCallback<1>("", false, ""));
}
GetPageContent(callback.Get(), "");
testing::Mock::VerifyAndClearExpectations(&callback);

// Navigatng should result in our pending callback being run with no content
// and no content extraction initiated.
EXPECT_CALL(*page_content_fetcher_, FetchPageContent).Times(0);
if (is_print_preview_supported_) {
EXPECT_CALL(*print_preview_extractor_, Extract).Times(0);
}
EXPECT_CALL(callback, Run("", false, ""));
NavigateTo(initial_url.Resolve("/2"), /*keep_loading=*/true,
is_same_document);
testing::Mock::VerifyAndClearExpectations(&callback);
testing::Mock::VerifyAndClearExpectations(&page_content_fetcher_);
if (is_print_preview_supported_) {
testing::Mock::VerifyAndClearExpectations(&print_preview_extractor_);
}
}
}

} // namespace ai_chat
112 changes: 57 additions & 55 deletions components/ai_chat/content/browser/ai_chat_tab_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ void AIChatTabHelper::InnerWebContentsAttached(

void AIChatTabHelper::DidFinishLoad(content::RenderFrameHost* render_frame_host,
const GURL& validated_url) {
DVLOG(4) << __func__ << ": " << validated_url.spec();
if (validated_url == GetPageURL()) {
is_page_loaded_ = true;
if (pending_get_page_content_callback_) {
Expand Down Expand Up @@ -243,20 +244,11 @@ GURL AIChatTabHelper::GetPageURL() const {
return web_contents()->GetLastCommittedURL();
}

void AIChatTabHelper::GetPageContent(
ConversationHandler::GetPageContentCallback callback,
std::string_view invalidation_token) {
void AIChatTabHelper::GetPageContent(GetPageContentCallback callback,
std::string_view invalidation_token) {
bool is_pdf = IsPdf(web_contents());
if (is_pdf && !is_pdf_a11y_info_loaded_) {
if (pending_get_page_content_callback_) {
// TODO(petemill): Queue the callback in a OneShotEvent, instead of only
// allowing a single pending callback. At the moment, this doesn't matter
// since only higher level usage (|AssociatedContentDriver|) does
// queue calls to |GetPageContent| via a OneShotEvent.
std::move(pending_get_page_content_callback_).Run("", false, "");
}
// invalidation_token doesn't matter for PDF extraction.
pending_get_page_content_callback_ = std::move(callback);
SetPendingGetContentCallback(std::move(callback));
// PdfAccessibilityTree::AccessibilityModeChanged handles kPDFOcr changes
// with |always_load_or_reload_accessibility| is true
if (inner_web_contents_) {
Expand All @@ -280,65 +272,76 @@ void AIChatTabHelper::GetPageContent(
CheckPDFA11yTree();
return;
}
if (base::Contains(kPrintPreviewRetrievalHosts, GetPageURL().host_piece()) &&
print_preview_extraction_delegate_ != nullptr) {
if (is_page_loaded_) {
// Get content using a printing / OCR mechanism, instead of
// directly from the source.
print_preview_extraction_delegate_->Extract(
is_pdf,
base::BindOnce(&AIChatTabHelper::OnExtractPrintPreviewContentComplete,
weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
} else {
// Run print preview extraction after load
if (pending_get_page_content_callback_) {
std::move(pending_get_page_content_callback_).Run("", false, "");
}
pending_get_page_content_callback_ = std::move(callback);
if (base::Contains(kPrintPreviewRetrievalHosts, GetPageURL().host_piece())) {
// Get content using a printing / OCR mechanism, instead of
// directly from the source, if available.
DVLOG(1) << __func__ << " print preview url";
if (MaybePrintPreviewExtract(callback)) {
return;
}
} else {
page_content_fetcher_delegate_->FetchPageContent(
invalidation_token,
base::BindOnce(&AIChatTabHelper::OnFetchPageContentComplete,
weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
}
page_content_fetcher_delegate_->FetchPageContent(
invalidation_token,
base::BindOnce(&AIChatTabHelper::OnFetchPageContentComplete,
weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
}

void AIChatTabHelper::OnFetchPageContentComplete(
ConversationHandler::GetPageContentCallback callback,
GetPageContentCallback callback,
std::string content,
bool is_video,
std::string invalidation_token) {
base::TrimWhitespaceASCII(content, base::TRIM_ALL, &content);
if (!did_retry_get_page_content_after_page_load_ && content.empty() &&
!is_video) {
// Only try this once
did_retry_get_page_content_after_page_load_ = true;
DVLOG(1) << __func__ << "empty content, will retry once, is_page_loaded_="
<< is_page_loaded_;
if (!is_page_loaded_) {
// Retry after page is loaded, including possible print preview fallback
if (pending_get_page_content_callback_) {
std::move(pending_get_page_content_callback_).Run("", false, "");
}
pending_get_page_content_callback_ = std::move(callback);
// If content is empty, and page was not loaded yet, wait for page load.
// Once page load is complete, try again. If it's still empty, fallback
// to print preview extraction.
if (content.empty() && !is_video) {
// When page isn't loaded yet, wait until DidFinishLoad
DVLOG(1) << __func__ << " empty content, will attempt fallback";
if (MaybePrintPreviewExtract(callback)) {
return;
} else if (!is_page_loaded_) {
DVLOG(1) << "page was not loaded yet, will try again after load";
SetPendingGetContentCallback(std::move(callback));
return;
}
if (is_page_loaded_) {
// Fallback to print preview extraction
print_preview_extraction_delegate_->Extract(
IsPdf(web_contents()),
base::BindOnce(&AIChatTabHelper::OnExtractPrintPreviewContentComplete,
weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
}
return;
// When print preview extraction isn't available, return empty content
DVLOG(1) << "no fallback available";
}
std::move(callback).Run(std::move(content), is_video,
std::move(invalidation_token));
}

void AIChatTabHelper::SetPendingGetContentCallback(
GetPageContentCallback callback) {
if (pending_get_page_content_callback_) {
std::move(pending_get_page_content_callback_).Run("", false, "");
}
pending_get_page_content_callback_ = std::move(callback);
}

bool AIChatTabHelper::MaybePrintPreviewExtract(
GetPageContentCallback& callback) {
if (print_preview_extraction_delegate_ == nullptr) {
DVLOG(1) << "print preview extraction not supported";
return false;
}
if (!is_page_loaded_) {
DVLOG(1) << "will extract print preview content when page is loaded";
SetPendingGetContentCallback(std::move(callback));
} else {
// When page is already loaded, fallback to print preview extraction
DVLOG(1) << "extracting print preview content now";
print_preview_extraction_delegate_->Extract(
IsPdf(web_contents()),
base::BindOnce(&AIChatTabHelper::OnExtractPrintPreviewContentComplete,
weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
}
return true;
}

void AIChatTabHelper::OnExtractPrintPreviewContentComplete(
ConversationHandler::GetPageContentCallback callback,
GetPageContentCallback callback,
std::string content) {
// Invalidation token not applicable for print preview OCR
std::move(callback).Run(std::move(content), false, "");
Expand All @@ -351,7 +354,6 @@ std::u16string AIChatTabHelper::GetPageTitle() const {
void AIChatTabHelper::OnNewPage(int64_t navigation_id) {
DVLOG(3) << __func__ << " id: " << navigation_id;
AssociatedContentDriver::OnNewPage(navigation_id);
did_retry_get_page_content_after_page_load_ = false;
if (pending_get_page_content_callback_) {
std::move(pending_get_page_content_callback_).Run("", false, "");
}
Expand Down
Loading

0 comments on commit ee3279f

Please sign in to comment.