Skip to content

Commit

Permalink
[Community] handle auto refresh jwt token
Browse files Browse the repository at this point in the history
  • Loading branch information
GuillaumeDSM committed Sep 3, 2024
1 parent 88a9602 commit 7b88690
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 13 deletions.
1 change: 1 addition & 0 deletions additional_tests/supabase_backend_tests/.env.template
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ SUPABASE_BACKEND_KEY=
SUPABASE_BACKEND_CLIENT_1_EMAIL=
SUPABASE_BACKEND_CLIENT_1_PASSWORD=
SUPABASE_BACKEND_CLIENT_1_AUTH_KEY=
SUPABASE_BACKEND_CLIENT_1_EXPIRED_JWT_TOKEN=

SUPABASE_BACKEND_CLIENT_2_EMAIL=
SUPABASE_BACKEND_CLIENT_2_PASSWORD=
Expand Down
4 changes: 4 additions & 0 deletions additional_tests/supabase_backend_tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ def get_backend_client_auth_key(identifier):
return os.getenv(f"SUPABASE_BACKEND_CLIENT_{identifier}_AUTH_KEY")


def get_backend_client_expired_jwt_token(identifier):
return os.getenv(f"SUPABASE_BACKEND_CLIENT_{identifier}_EXPIRED_JWT_TOKEN")


def _get_backend_service_key():
return os.getenv(f"SUPABASE_BACKEND_SERVICE_KEY")

Expand Down
38 changes: 37 additions & 1 deletion additional_tests/supabase_backend_tests/test_user_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,20 @@
# You should have received a copy of the GNU General Public
# License along with OctoBot. If not, see <https://www.gnu.org/licenses/>.
import time

import mock
import postgrest
import pytest

import octobot_commons.configuration as commons_configuration
import octobot_commons.authentication as authentication
import octobot.community as community
import octobot.community.supabase_backend as supabase_backend
import octobot.community.errors as community_errors
import octobot.community.supabase_backend.enums as supabase_backend_enums
from additional_tests.supabase_backend_tests import authenticated_client_1, authenticated_client_2, \
admin_client, anon_client, get_backend_api_creds, skip_if_no_service_key, get_backend_client_creds, \
get_backend_client_auth_key
get_backend_client_auth_key, get_backend_client_expired_jwt_token


# All test coroutines will be treated as marked.
Expand Down Expand Up @@ -162,3 +167,34 @@ async def test_sign_in_with_auth_token():
finally:
if supabase_client:
await supabase_client.aclose()


async def test_expired_jwt_token(authenticated_client_1):
initial_email = (await authenticated_client_1.get_user())[supabase_backend_enums.UserKeys.EMAIL.value]

# refreshing session is working
await authenticated_client_1.refresh_session()
# does not raise
bots = await authenticated_client_1.fetch_bots()
assert (await authenticated_client_1.get_user())[supabase_backend_enums.UserKeys.EMAIL.value] == initial_email

# use expired jwt token
expired_jwt_token = get_backend_client_expired_jwt_token(1)

# simulate auth using this token
session = mock.Mock(access_token=expired_jwt_token)
authenticated_client_1._listen_to_auth_events(
"SIGNED_IN", session
)

# now raising "APIError: JWT expired"
with pytest.raises(postgrest.APIError):
await authenticated_client_1.fetch_bots()

# now raising "APIError: JWT expired" which is converted into community_errors.SessionTokenExpiredError
with pytest.raises(community_errors.SessionTokenExpiredError):
with supabase_backend.error_describer():
await authenticated_client_1.fetch_bots()



18 changes: 9 additions & 9 deletions octobot/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public
# License along with OctoBot. If not, see <https://www.gnu.org/licenses/>.
import argparse
import contextlib
import os
import sys
import multiprocessing
Expand Down Expand Up @@ -197,15 +198,13 @@ async def _get_authenticated_community_if_possible(config, logger):
try:
if not community_auth.is_initialized():
if constants.IS_CLOUD_ENV:
if constants.USER_ACCOUNT_EMAIL and constants.USER_AUTH_KEY:
try:
logger.debug("Attempting auth key authentication")
await community_auth.login(
constants.USER_ACCOUNT_EMAIL, None, auth_key=constants.USER_AUTH_KEY
)
except authentication.AuthenticationError as err:
logger.info(f"Auth key auth failure ({err}). Trying other methods if available.")
if constants.USER_ACCOUNT_EMAIL and constants.USER_PASSWORD_TOKEN:
authenticated = False
try:
logger.debug("Attempting auth key authentication")
authenticated = await community_auth.auto_reauthenticate()
except authentication.AuthenticationError as err:
logger.info(f"Auth key auth failure ({err}). Trying other methods if available.")
if not authenticated and constants.USER_ACCOUNT_EMAIL and constants.USER_PASSWORD_TOKEN:
try:
logger.debug("Attempting password token authentication")
await community_auth.login(
Expand All @@ -229,6 +228,7 @@ async def _get_authenticated_community_if_possible(config, logger):


async def _async_load_community_data(community_auth, config, logger, is_first_startup):
is_first_startup = True #tmp
if constants.IS_CLOUD_ENV and is_first_startup:
if not community_auth.is_logged_in():
raise authentication.FailedAuthentication(
Expand Down
36 changes: 36 additions & 0 deletions octobot/community/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,31 @@
import octobot_trading.enums as trading_enums


def expired_session_retrier(func):
async def expired_session_retrier_wrapper(*args, **kwargs):
self = args[0]
try:
with supabase_backend.error_describer():
return await func(*args, **kwargs)
except errors.SessionTokenExpiredError:
try:
with supabase_backend.error_describer():
self.logger.info(f"Expired session, trying to refresh token.")
await self.supabase_client.refresh_session()
return await func(*args, **kwargs)
except errors.SessionTokenExpiredError as err:
if await self.auto_reauthenticate():
self.logger.error(
f"Impossible to use default refresh token, using saved auth details instead."
)
return await func(*args, **kwargs)
# can't refresh token: logout
self.logger.warning(f"Expired session, please re-authenticate. {err}")
await self.logout()
return expired_session_retrier_wrapper

def _bot_data_update(func):
@expired_session_retrier
async def bot_data_update_wrapper(*args, raise_errors=False, **kwargs):
self = args[0]
if not self.is_logged_in_and_has_selected_bot():
Expand All @@ -49,6 +73,9 @@ async def bot_data_update_wrapper(*args, raise_errors=False, **kwargs):
try:
self.logger.debug(f"bot_data_update: {func.__name__} initiated.")
return await func(*args, **kwargs)
except errors.SessionTokenExpiredError:
# requried by expired_session_retrier
raise
except Exception as err:
if raise_errors:
raise err
Expand Down Expand Up @@ -329,6 +356,15 @@ async def login(
if self.is_logged_in():
await self.on_signed_in(minimal=minimal)

async def auto_reauthenticate(self) -> bool:
if constants.IS_CLOUD_ENV and constants.USER_ACCOUNT_EMAIL and constants.USER_AUTH_KEY:
self.logger.debug("Attempting auth key authentication")
await self.login(
constants.USER_ACCOUNT_EMAIL, None, auth_key=constants.USER_AUTH_KEY
)
return self.is_logged_in()
return False

async def register(self, email, password):
if self.must_be_authenticated_through_authenticator():
raise authentication.AuthenticationError("Creating a new account is not authorized on this environment.")
Expand Down
4 changes: 4 additions & 0 deletions octobot/community/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ class StatusCodeRequestError(RequestError):
pass


class SessionTokenExpiredError(commons_authentication.AuthenticationError):
pass


class BotError(commons_authentication.UnavailableError):
pass

Expand Down
2 changes: 2 additions & 0 deletions octobot/community/supabase_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from octobot.community.supabase_backend import community_supabase_client
from octobot.community.supabase_backend.community_supabase_client import (
error_describer,
CommunitySupabaseClient,
HTTP_RETRY_COUNT,
)
Expand All @@ -33,6 +34,7 @@
"SyncConfigurationStorage",
"ASyncConfigurationStorage",
"AuthenticatedAsyncSupabaseClient",
"error_describer",
"CommunitySupabaseClient",
"HTTP_RETRY_COUNT",
]
17 changes: 14 additions & 3 deletions octobot/community/supabase_backend/community_supabase_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
import httpx
import uuid
import json

import contextlib
import aiohttp

import gotrue.errors
import gotrue.types
import postgrest
Expand Down Expand Up @@ -56,6 +57,16 @@
HTTP_RETRY_COUNT = 5


@contextlib.contextmanager
def error_describer():
try:
yield
except postgrest.APIError as err:
if "jwt expired" in str(err).lower():
raise errors.SessionTokenExpiredError(err) from err
raise


def _httpx_retrier(f):
async def httpx_retrier_wrapper(*args, **kwargs):
resp = None
Expand Down Expand Up @@ -160,9 +171,9 @@ async def restore_session(self):
if not self.is_signed_in():
raise authentication.FailedAuthentication()

async def refresh_session(self):
async def refresh_session(self, refresh_token: typing.Union[str, None] = None):
try:
await self.auth.refresh_session()
await self.auth.refresh_session(refresh_token=refresh_token)
except gotrue.errors.AuthError as err:
raise authentication.AuthenticationError(err) from err

Expand Down

0 comments on commit 7b88690

Please sign in to comment.