diff --git a/geoportal/c2cgeoportal_geoportal/__init__.py b/geoportal/c2cgeoportal_geoportal/__init__.py index a8c43010c4..0b093738cb 100644 --- a/geoportal/c2cgeoportal_geoportal/__init__.py +++ b/geoportal/c2cgeoportal_geoportal/__init__.py @@ -49,7 +49,6 @@ import sqlalchemy.orm import zope.event.classhandler from c2cgeoform import translator -from c2cwsgiutils.broadcast import decorator from c2cwsgiutils.health_check import HealthCheck from c2cwsgiutils.prometheus import MemoryMapCollector from deform import Form @@ -57,15 +56,15 @@ from papyrus.renderers import GeoJSON from prometheus_client.core import REGISTRY from pyramid.config import Configurator -from pyramid.httpexceptions import HTTPBadRequest, HTTPException +from pyramid.httpexceptions import HTTPException from pyramid.path import AssetResolver from pyramid_mako import add_mako_renderer -from sqlalchemy.orm import Session, joinedload +from sqlalchemy.orm import joinedload import c2cgeoportal_commons.models import c2cgeoportal_geoportal.views from c2cgeoportal_commons.models import InvalidateCacheEvent -from c2cgeoportal_geoportal.lib import C2CPregenerator, caching, check_collector, checker +from c2cgeoportal_geoportal.lib import C2CPregenerator, caching, check_collector, checker, oidc from c2cgeoportal_geoportal.lib.cacheversion import version_cache_buster from c2cgeoportal_geoportal.lib.common_headers import Cache, set_common_headers from c2cgeoportal_geoportal.lib.i18n import available_locale_names @@ -317,7 +316,6 @@ def get_user_from_request( """ from c2cgeoportal_commons.models import DBSession # pylint: disable=import-outside-toplevel from c2cgeoportal_commons.models.static import User # pylint: disable=import-outside-toplevel - from c2cgeoportal_geoportal.lib import oidc # pylint: disable=import-outside-toplevel assert DBSession is not None @@ -347,28 +345,10 @@ def get_user_from_request( ) user_info = oidc.OidcRemember(request).remember(token_response) - if openid_connect_config.get("provide_roles", False) is True: - from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel - Role, - ) - - request.user_ = oidc.DynamicUser( - username=user_info["username"], - email=user_info["email"], - settings_role=( - DBSession.query(Role).filter_by(name=user_info["settings_role"]).first() - if user_info.get("settings_role") is not None - else None - ), - roles=[ - DBSession.query(Role).filter_by(name=role).one() - for role in user_info.get("roles", []) - ], - ) - else: - request.user_ = DBSession.query(User).filter_by(email=user_info["email"]).first() - for user in DBSession.query(User).all(): - _LOG.error(user.username) + request.user_ = request.get_user_from_reminder( + user_info, + request.registry.settings.get("authentication", {}).get("openid_connect", {}), + ) else: # We know we will need the role object of the # user so we use joined loading @@ -517,6 +497,7 @@ def includeme(config: pyramid.config.Configurator) -> None: config.include("pyramid_mako") config.include("c2cwsgiutils.pyramid.includeme") + config.include(oidc.includeme) health_check = HealthCheck(config) config.registry["health_check"] = health_check diff --git a/geoportal/c2cgeoportal_geoportal/lib/oidc.py b/geoportal/c2cgeoportal_geoportal/lib/oidc.py index 3061ed609f..0afdbfabaf 100644 --- a/geoportal/c2cgeoportal_geoportal/lib/oidc.py +++ b/geoportal/c2cgeoportal_geoportal/lib/oidc.py @@ -28,7 +28,7 @@ import datetime import json import logging -from typing import NamedTuple, TypedDict +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, TypedDict, Union import pyramid.request import pyramid.response @@ -37,9 +37,11 @@ from pyramid.httpexceptions import HTTPBadRequest, HTTPInternalServerError, HTTPUnauthorized from pyramid.security import remember -from c2cgeoportal_commons.models import main from c2cgeoportal_geoportal.lib.caching import get_region +if TYPE_CHECKING: + from c2cgeoportal_commons.models import main, static + _LOG = logging.getLogger(__name__) _CACHE_REGION_OBJ = get_region("obj") @@ -52,8 +54,8 @@ class DynamicUser(NamedTuple): username: str email: str - settings_role: main.Role | None - roles: list[main.Role] + settings_role: Optional["main.Role"] + roles: list["main.Role"] @_CACHE_REGION_OBJ.cache_on_arguments() @@ -92,6 +94,81 @@ class OidcRememberObject(TypedDict): roles: list[str] +def get_remember_from_user_info( + user_info: dict[str, Any], remember_object: OidcRememberObject, settings: dict[str, Any] +) -> None: + """ + Fill the remember object from the user info. + + The remember object will be stored in a cookie to remember the user. + + :param user_info: The user info from the ID token or from the user info view according to the `query_user_info` configuration. + :param remember_object: The object to fill, by default with the `username`, `email`, `settings_role` and `roles`, + the corresponding field from `user_info` can be configured in `user_info_fields`. + :param settings: The OpenID Connect configuration. + """ + settings_fields = settings.get("user_info_fields", {}) + + for field_, default_field in ( + ("username", "name"), + ("email", "email"), + ("settings_role", None), + ("roles", None), + ): + user_info_field = settings_fields.get(field_, default_field) + if user_info_field is not None: + if user_info_field not in user_info: + _LOG.error( + "Field '%s' not found in user info, available: %s.", + user_info_field, + ", ".join(user_info.keys()), + ) + raise HTTPInternalServerError(f"Field '{user_info_field}' not found in user info.") + remember_object[field_] = user_info[user_info_field] # type: ignore[literal-required] + + +def get_user_from_remember( + remember_object: OidcRememberObject, settings: dict[str, Any], create_user: bool = False +) -> Union["static.User", DynamicUser] | None: + """ + Create a user from the remember object filled from `get_remember_from_user_info`. + + :param remember_object: The object to fill, by default with the `username`, `email`, `settings_role` and `roles`. + :param settings: The OpenID Connect configuration. + :param create_user: If the user should be created if it does not exist. + """ + from c2cgeoportal_commons import models # pylint: disable=import-outside-toplevel + from c2cgeoportal_commons.models import main, static # pylint: disable=import-outside-toplevel + + assert models.DBSession is not None + + user: static.User | DynamicUser | None + username = remember_object["username"] + assert username is not None + email = remember_object["email"] + assert email is not None + if settings.get("provide_roles", False) is False: + user = models.DBSession.query(static.User).filter_by(email=email).one_or_none() + if user is None and create_user is True: + user = static.User(username=username, email=email) + models.DBSession.add(user) + else: + user = DynamicUser( + username=username, + email=email, + settings_role=( + models.DBSession.query(main.Role).filter_by(name=remember_object["settings_role"]).first() + if remember_object.get("settings_role") is not None + else None + ), + roles=[ + models.DBSession.query(main.Role).filter_by(name=role).one() + for role in remember_object.get("roles", []) + ], + ) + return user + + class OidcRemember: """ Build the abject that we want to remember in the cookie. @@ -142,7 +219,6 @@ def remember( "settings_role": None, "roles": [], } - settings_fields = openid_connect.get("user_info_fields", {}) client = get_oidc_client(self.request) if openid_connect.get("query_user_info", False) is True: @@ -166,24 +242,15 @@ def remember( ), ) - for field_, default_field in ( - ("username", "name"), - ("email", "email"), - ("settings_role", None), - ("roles", None), - ): - user_info_field = settings_fields.get(field_, default_field) - if user_info_field is not None: - user_info_dict = user_info.dict() - if user_info_field not in user_info_dict: - _LOG.error( - "Field '%s' not found in user info, available: %s.", - user_info_field, - ", ".join(user_info_dict.keys()), - ) - raise HTTPInternalServerError(f"Field '{user_info_field}' not found in user info.") - remember_object[field_] = user_info_dict[user_info_field] # type: ignore[literal-required] - + self.request.get_remember_from_user_info(user_info.dict(), remember_object, openid_connect) self.request.response.headers.extend(remember(self.request, json.dumps(remember_object))) return remember_object + + +def includeme(config: pyramid.config.Configurator) -> None: + """ + Pyramid includeme function. + """ + config.add_request_method(get_remember_from_user_info, name="get_remember_from_user_info") + config.add_request_method(get_user_from_remember, name="get_user_from_remember") diff --git a/geoportal/c2cgeoportal_geoportal/views/login.py b/geoportal/c2cgeoportal_geoportal/views/login.py index 32d750d67f..f823322624 100644 --- a/geoportal/c2cgeoportal_geoportal/views/login.py +++ b/geoportal/c2cgeoportal_geoportal/views/login.py @@ -644,26 +644,9 @@ def oidc_callback(self) -> pyramid.response.Response: remember_object = oidc.OidcRemember(self.request).remember(token_response) - user: static.User | oidc.DynamicUser | None - if self.authentication_settings.get("openid_connect", {}).get("provide_roles", False) is False: - user = models.DBSession.query(static.User).filter_by(email=remember_object["email"]).one_or_none() - if user is None: - user = static.User(username=remember_object["username"], email=remember_object["email"]) - models.DBSession.add(user) - else: - user = oidc.DynamicUser( - username=remember_object["username"], - email=remember_object["email"], - settings_role=( - models.DBSession.query(main.Role).filter_by(name=remember_object["settings_role"]).first() - if remember_object.get("settings_role") is not None - else None - ), - roles=[ - models.DBSession.query(main.Role).filter_by(name=role).one() - for role in remember_object.get("roles", []) - ], - ) + user: static.User | oidc.DynamicUser | None = self.request.get_user_from_remember( + remember_object, self.authentication_settings + ) assert user is not None self.request.user_ = user diff --git a/geoportal/tests/functional/test_oidc.py b/geoportal/tests/functional/test_oidc.py index accfb1291f..efe58f4fcc 100644 --- a/geoportal/tests/functional/test_oidc.py +++ b/geoportal/tests/functional/test_oidc.py @@ -14,6 +14,8 @@ from tests.functional import setup_db from tests.functional import teardown_common as teardown_module # noqa, pylint: disable=unused-import +from c2cgeoportal_geoportal.lib import oidc + _OIDC_CONFIGURATION = { "issuer": "https://sso.example.com", "authorization_endpoint": "https://sso.example.com/authorize", @@ -66,6 +68,8 @@ def test_login(self): }, params={"came_from": "/came_from"}, ) + request.get_remember_from_user_info = oidc.get_remember_from_user_info + request.get_user_from_remember = oidc.get_user_from_remember responses.get("https://sso.example.com/.well-known/openid-configuration", json=_OIDC_CONFIGURATION) responses.get("https://sso.example.com/jwks", json=_OIDC_KEYS) @@ -118,6 +122,8 @@ def test_callback(self): "code_challenge": "code_challenge", }, ) + request.get_remember_from_user_info = oidc.get_remember_from_user_info + request.get_user_from_remember = oidc.get_user_from_remember responses.get("https://sso.example.com/.well-known/openid-configuration", json=_OIDC_CONFIGURATION) responses.get("https://sso.example.com/jwks", json=_OIDC_KEYS) responses.post(