Skip to content

Commit

Permalink
OpenID Connect: Add hook to be able to customize role creation
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunner committed Sep 20, 2024
1 parent 758fc00 commit 72226de
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 71 deletions.
35 changes: 8 additions & 27 deletions geoportal/c2cgeoportal_geoportal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,22 @@
import sqlalchemy.orm
import zope.event.classhandler
from c2cgeoform import translator
from c2cwsgiutils.broadcast import decorator
from c2cwsgiutils.health_check import HealthCheck
from c2cwsgiutils.prometheus import MemoryMapCollector
from deform import Form
from dogpile.cache import register_backend # type: ignore[attr-defined]
from papyrus.renderers import GeoJSON
from prometheus_client.core import REGISTRY
from pyramid.config import Configurator
from pyramid.httpexceptions import HTTPBadRequest, HTTPException
from pyramid.httpexceptions import HTTPException
from pyramid.path import AssetResolver
from pyramid_mako import add_mako_renderer
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.orm import joinedload

import c2cgeoportal_commons.models
import c2cgeoportal_geoportal.views
from c2cgeoportal_commons.models import InvalidateCacheEvent
from c2cgeoportal_geoportal.lib import C2CPregenerator, caching, check_collector, checker
from c2cgeoportal_geoportal.lib import C2CPregenerator, caching, check_collector, checker, oidc
from c2cgeoportal_geoportal.lib.cacheversion import version_cache_buster
from c2cgeoportal_geoportal.lib.common_headers import Cache, set_common_headers
from c2cgeoportal_geoportal.lib.i18n import available_locale_names
Expand Down Expand Up @@ -317,7 +316,6 @@ def get_user_from_request(
"""
from c2cgeoportal_commons.models import DBSession # pylint: disable=import-outside-toplevel
from c2cgeoportal_commons.models.static import User # pylint: disable=import-outside-toplevel
from c2cgeoportal_geoportal.lib import oidc # pylint: disable=import-outside-toplevel

assert DBSession is not None

Expand Down Expand Up @@ -347,28 +345,10 @@ def get_user_from_request(
)
user_info = oidc.OidcRemember(request).remember(token_response)

if openid_connect_config.get("provide_roles", False) is True:
from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel
Role,
)

request.user_ = oidc.DynamicUser(
username=user_info["username"],
email=user_info["email"],
settings_role=(
DBSession.query(Role).filter_by(name=user_info["settings_role"]).first()
if user_info.get("settings_role") is not None
else None
),
roles=[
DBSession.query(Role).filter_by(name=role).one()
for role in user_info.get("roles", [])
],
)
else:
request.user_ = DBSession.query(User).filter_by(email=user_info["email"]).first()
for user in DBSession.query(User).all():
_LOG.error(user.username)
request.user_ = request.get_user_from_reminder(
user_info,
request.registry.settings.get("authentication", {}).get("openid_connect", {}),
)
else:
# We know we will need the role object of the
# user so we use joined loading
Expand Down Expand Up @@ -517,6 +497,7 @@ def includeme(config: pyramid.config.Configurator) -> None:

config.include("pyramid_mako")
config.include("c2cwsgiutils.pyramid.includeme")
config.include(oidc.includeme)
health_check = HealthCheck(config)
config.registry["health_check"] = health_check

Expand Down
123 changes: 100 additions & 23 deletions geoportal/c2cgeoportal_geoportal/lib/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import datetime
import json
import logging
from typing import NamedTuple, TypedDict
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, TypedDict, Union

import pyramid.request
import pyramid.response
Expand All @@ -37,9 +37,11 @@
from pyramid.httpexceptions import HTTPBadRequest, HTTPInternalServerError, HTTPUnauthorized
from pyramid.security import remember

from c2cgeoportal_commons.models import main
from c2cgeoportal_geoportal.lib.caching import get_region

if TYPE_CHECKING:
from c2cgeoportal_commons.models import main, static

_LOG = logging.getLogger(__name__)
_CACHE_REGION_OBJ = get_region("obj")

Expand All @@ -52,8 +54,8 @@ class DynamicUser(NamedTuple):

username: str
email: str
settings_role: main.Role | None
roles: list[main.Role]
settings_role: Optional["main.Role"]
roles: list["main.Role"]


@_CACHE_REGION_OBJ.cache_on_arguments()
Expand Down Expand Up @@ -92,6 +94,91 @@ class OidcRememberObject(TypedDict):
roles: list[str]


def get_remember_from_user_info(
request: pyramid.request.Request, user_info: dict[str, Any], remember_object: OidcRememberObject
) -> None:
"""
Fill the remember object from the user info.
The remember object will be stored in a cookie to remember the user.
:param user_info: The user info from the ID token or from the user info view according to the `query_user_info` configuration.
:param remember_object: The object to fill, by default with the `username`, `email`, `settings_role` and `roles`,
the corresponding field from `user_info` can be configured in `user_info_fields`.
:param settings: The OpenID Connect configuration.
"""
settings_fields = (
request.registry.settings.get("authentication", {})
.get("openid_connect", {})
.get("user_info_fields", {})
)

for field_, default_field in (
("username", "name"),
("email", "email"),
("settings_role", None),
("roles", None),
):
user_info_field = settings_fields.get(field_, default_field)
if user_info_field is not None:
if user_info_field not in user_info:
_LOG.error(
"Field '%s' not found in user info, available: %s.",
user_info_field,
", ".join(user_info.keys()),
)
raise HTTPInternalServerError(f"Field '{user_info_field}' not found in user info.")
remember_object[field_] = user_info[user_info_field] # type: ignore[literal-required]


def get_user_from_remember(
request: pyramid.request.Request, remember_object: OidcRememberObject, create_user: bool = False
) -> Union["static.User", DynamicUser] | None:
"""
Create a user from the remember object filled from `get_remember_from_user_info`.
:param remember_object: The object to fill, by default with the `username`, `email`, `settings_role` and `roles`.
:param settings: The OpenID Connect configuration.
:param create_user: If the user should be created if it does not exist.
"""
from c2cgeoportal_commons import models # pylint: disable=import-outside-toplevel
from c2cgeoportal_commons.models import main, static # pylint: disable=import-outside-toplevel

assert models.DBSession is not None

user: static.User | DynamicUser | None
username = remember_object["username"]
assert username is not None
email = remember_object["email"]
assert email is not None

provide_roles = (
request.registry.settings.get("authentication", {})
.get("openid_connect", {})
.get("provide_roles", False)
)
if provide_roles is False:
user = models.DBSession.query(static.User).filter_by(email=email).one_or_none()
if user is None and create_user is True:
user = static.User(username=username, email=email)
models.DBSession.add(user)
else:
user = DynamicUser(
username=username,
email=email,
settings_role=(
models.DBSession.query(main.Role).filter_by(name=remember_object["settings_role"]).first()
if remember_object.get("settings_role") is not None
else None
),
roles=[
models.DBSession.query(main.Role).filter_by(name=role).one()
for role in remember_object.get("roles", [])
],
)
return user


class OidcRemember:
"""
Build the abject that we want to remember in the cookie.
Expand Down Expand Up @@ -142,7 +229,6 @@ def remember(
"settings_role": None,
"roles": [],
}
settings_fields = openid_connect.get("user_info_fields", {})
client = get_oidc_client(self.request)

if openid_connect.get("query_user_info", False) is True:
Expand All @@ -166,24 +252,15 @@ def remember(
),
)

for field_, default_field in (
("username", "name"),
("email", "email"),
("settings_role", None),
("roles", None),
):
user_info_field = settings_fields.get(field_, default_field)
if user_info_field is not None:
user_info_dict = user_info.dict()
if user_info_field not in user_info_dict:
_LOG.error(
"Field '%s' not found in user info, available: %s.",
user_info_field,
", ".join(user_info_dict.keys()),
)
raise HTTPInternalServerError(f"Field '{user_info_field}' not found in user info.")
remember_object[field_] = user_info_dict[user_info_field] # type: ignore[literal-required]

self.request.get_remember_from_user_info(user_info.dict(), remember_object)
self.request.response.headers.extend(remember(self.request, json.dumps(remember_object)))

return remember_object


def includeme(config: pyramid.config.Configurator) -> None:
"""
Pyramid includeme function.
"""
config.add_request_method(get_remember_from_user_info, name="get_remember_from_user_info")
config.add_request_method(get_user_from_remember, name="get_user_from_remember")
21 changes: 1 addition & 20 deletions geoportal/c2cgeoportal_geoportal/views/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,26 +644,7 @@ def oidc_callback(self) -> pyramid.response.Response:

remember_object = oidc.OidcRemember(self.request).remember(token_response)

user: static.User | oidc.DynamicUser | None
if self.authentication_settings.get("openid_connect", {}).get("provide_roles", False) is False:
user = models.DBSession.query(static.User).filter_by(email=remember_object["email"]).one_or_none()
if user is None:
user = static.User(username=remember_object["username"], email=remember_object["email"])
models.DBSession.add(user)
else:
user = oidc.DynamicUser(
username=remember_object["username"],
email=remember_object["email"],
settings_role=(
models.DBSession.query(main.Role).filter_by(name=remember_object["settings_role"]).first()
if remember_object.get("settings_role") is not None
else None
),
roles=[
models.DBSession.query(main.Role).filter_by(name=role).one()
for role in remember_object.get("roles", [])
],
)
user: static.User | oidc.DynamicUser | None = self.request.get_user_from_remember(remember_object)
assert user is not None
self.request.user_ = user

Expand Down
11 changes: 10 additions & 1 deletion geoportal/tests/functional/test_oidc.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import base64
import re
import types
import urllib.parse
from http.client import responses
from unittest import TestCase

import jwt
import responses
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from pyramid import testing
from tests.functional import cleanup_db, create_dummy_request
from tests.functional import setup_common as setup_module # noqa, pylint: disable=unused-import
from tests.functional import setup_db
from tests.functional import teardown_common as teardown_module # noqa, pylint: disable=unused-import

from c2cgeoportal_geoportal.lib import oidc

_OIDC_CONFIGURATION = {
"issuer": "https://sso.example.com",
"authorization_endpoint": "https://sso.example.com/authorize",
Expand Down Expand Up @@ -41,6 +43,11 @@
}


def includeme(request):
request.get_remember_from_user_info = types.MethodType(oidc.get_remember_from_user_info, request)
request.get_user_from_remember = types.MethodType(oidc.get_user_from_remember, request)


class TestLogin(TestCase):
def setUp(self):
setup_db()
Expand All @@ -66,6 +73,7 @@ def test_login(self):
},
params={"came_from": "/came_from"},
)
includeme(request)
responses.get("https://sso.example.com/.well-known/openid-configuration", json=_OIDC_CONFIGURATION)
responses.get("https://sso.example.com/jwks", json=_OIDC_KEYS)

Expand Down Expand Up @@ -118,6 +126,7 @@ def test_callback(self):
"code_challenge": "code_challenge",
},
)
includeme(request)
responses.get("https://sso.example.com/.well-known/openid-configuration", json=_OIDC_CONFIGURATION)
responses.get("https://sso.example.com/jwks", json=_OIDC_KEYS)
responses.post(
Expand Down

0 comments on commit 72226de

Please sign in to comment.