Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add collection search extension #136

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ coverage.xml
*.log
.git
.envrc
*egg-info

venv
venv
env
2 changes: 1 addition & 1 deletion .github/workflows/cicd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
runs-on: ubuntu-latest
services:
pgstac:
image: ghcr.io/stac-utils/pgstac:v0.7.10
image: ghcr.io/stac-utils/pgstac:v0.8.6
env:
POSTGRES_USER: username
POSTGRES_PASSWORD: password
Expand Down
3 changes: 2 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

## [Unreleased]

- Fix Docker compose file, so example data can be loaded into database (author @zstatmanweil, https://github.com/stac-utils/stac-fastapi-pgstac/pull/142)
- Fix Docker compose file, so example data can be loaded into database (author @zstatmanweil, <https://github.com/stac-utils/stac-fastapi-pgstac/pull/142>)
- Add collection search extension

## [3.0.0] - 2024-08-02

Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
"orjson",
"pydantic",
"stac_pydantic==3.1.*",
"stac-fastapi.api~=3.0",
"stac-fastapi.extensions~=3.0",
"stac-fastapi.types~=3.0",
"stac-fastapi.api~=3.0.2",
"stac-fastapi.extensions~=3.0.2",
"stac-fastapi.types~=3.0.2",
"asyncpg",
"buildpg",
"brotli_asgi",
Expand Down
21 changes: 19 additions & 2 deletions stac_fastapi/pgstac/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fastapi.responses import ORJSONResponse
from stac_fastapi.api.app import StacApi
from stac_fastapi.api.models import (
EmptyRequest,
ItemCollectionUri,
create_get_request_model,
create_post_request_model,
Expand All @@ -22,6 +23,7 @@
TokenPaginationExtension,
TransactionExtension,
)
from stac_fastapi.extensions.core.collection_search import CollectionSearchExtension
from stac_fastapi.extensions.third_party import BulkTransactionExtension

from stac_fastapi.pgstac.config import Settings
Expand All @@ -48,12 +50,17 @@
}

if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"):
_enabled_extensions = enabled_extensions.split(",")
extensions = [
extensions_map[extension_name] for extension_name in enabled_extensions.split(",")
extension
for key, extension in extensions_map.items()
if key in _enabled_extensions
]
else:
_enabled_extensions = list(extensions_map.keys()) + ["collection_search"]
extensions = list(extensions_map.values())


if any(isinstance(ext, TokenPaginationExtension) for ext in extensions):
items_get_request_model = create_request_model(
model_name="ItemCollectionUri",
Expand All @@ -64,17 +71,27 @@
else:
items_get_request_model = ItemCollectionUri

if "collection_search" in _enabled_extensions:
collection_search_extension = CollectionSearchExtension.from_extensions(
extensions=extensions
)
collections_get_request_model = collection_search_extension.GET
else:
collection_search_extension = None
collections_get_request_model = EmptyRequest

post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
get_request_model = create_get_request_model(extensions)

# Build the API application.
#
# ``collection_search_extension`` is ``None`` when "collection_search" is not
# listed in ENABLED_EXTENSIONS, so it is only added to the extension list when
# it was actually created; otherwise a ``None`` entry would be handed to
# StacApi alongside the real extension objects.
api = StacApi(
    settings=settings,
    extensions=(
        extensions + [collection_search_extension]
        if collection_search_extension is not None
        else extensions
    ),
    client=CoreCrudClient(post_request_model=post_request_model),  # type: ignore
    response_class=ORJSONResponse,
    items_get_request_model=items_get_request_model,
    search_get_request_model=get_request_model,
    search_post_request_model=post_request_model,
    collections_get_request_model=collections_get_request_model,
)
app = api.app

Expand Down
242 changes: 169 additions & 73 deletions stac_fastapi/pgstac/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
from pygeofilter.parsers.cql2_text import parse as parse_cql2_text
from pypgstac.hydration import hydrate
from stac_fastapi.api.models import JSONResponse
from stac_fastapi.types.core import AsyncBaseCoreClient
from stac_fastapi.types.core import AsyncBaseCoreClient, Relations
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
from stac_fastapi.types.requests import get_base_url
from stac_fastapi.types.rfc3339 import DateTimeType
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
from stac_pydantic.links import Relations
from stac_pydantic.shared import BBox, MimeTypes

from stac_fastapi.pgstac.config import Settings
Expand All @@ -39,17 +38,100 @@
class CoreCrudClient(AsyncBaseCoreClient):
"""Client for core endpoints defined by stac."""

async def all_collections(self, request: Request, **kwargs) -> Collections:
"""Read all collections from the database."""
async def all_collections( # noqa: C901
self,
request: Request,
# Extensions
bbox: Optional[BBox] = None,
datetime: Optional[DateTimeType] = None,
limit: Optional[int] = None,
Comment on lines +45 to +47
Copy link
Contributor

@alukach alukach Aug 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thoughts from discussion: consider adding id to the endpoint, despite it being omitted from the Collection Search Extension.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't that the same as just calling /collections/:id ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use-case for including an ids parameter would be to limit the scope of a search in the context of scoped authentication for a STAC API, but we discussed some more and it probably makes more sense to use the filter extension for injecting scope limits in a search request.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, ids would be a better parameter name than id to evoke the ability to provide multiple as a filter. And for our specific needs today, we will make use of filter as @hrodmn said. I believe it was @bitner who suggested adding the ability to filter by ID parameter, I'll let him weigh in on whether we should go without.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense just to be parallel to the items spec (and yes, "ids" plural which is how it works in items)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you add it, I'd recommend to add a separate conformance class for it so that clients actually know whether it's supported or not.

PS: ids is not included in collection search as we just inherit from OGC API - Records, which doesn't have it. It's orthogonal to how ids is not part of OGC API - Features for items. ids is a STAC-specific thing.

query: Optional[str] = None,
token: Optional[str] = None,
fields: Optional[List[str]] = None,
sortby: Optional[str] = None,
filter: Optional[str] = None,
filter_lang: Optional[str] = None,
**kwargs,
) -> Collections:
"""Cross catalog search (GET).
Called with `GET /collections`.
Returns:
Collections which match the search criteria, returns all
collections by default.
"""

# Parse request parameters
base_args = {
"bbox": bbox,
"limit": limit,
"token": token,
"query": orjson.loads(unquote_plus(query)) if query else query,
}

clean = clean_search_args(
base_args=base_args,
datetime=datetime,
fields=fields,
sortby=sortby,
filter=filter,
filter_lang=filter_lang,
)

# Do the request
try:
search_request = self.post_request_model(**clean)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 😬 The issue is that we don't have access to the model used here; we should add a collection_get_request_model attribute to the client or find a way to avoid this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand the problem correctly, we are just trying to dump the args into a json that pgstac will accept, but we are using a model for item search that is defined in stac_pydantic. Does this mean we want to add a new attribute to the CoreCrudClient that is equivalent to CollectionSearchExtension().GET?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's kinda weird because we will pass the Client to the CollectionSearch extension, which kinda results in a circular dependency, but that's how it's done for the search post model: https://github.com/stac-utils/stac-fastapi/blob/main/stac_fastapi/types/stac_fastapi/types/core.py#L346

Does this mean we want to add a new attribute to the CoreCrudClient that is equivalent to CollectionSearchExtension().GET?

But yes I think we will have to do that!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried adding a collections_get_request_model attribute based on the BaseCollectionSearchGetRequest class but it doesn't work because that class is not a pydantic model. It looks like the post_request_model is a BaseSearchPostRequest which inherits from stac_pydantic.api.Search (which is based on a BaseModel), but BaseSearchGetRequest is based on the APIRequest class which does not inherit from a pydantic model and thus has no model_dump_json() method.

Here is what I tried:

@attr.s
class CoreCrudClient(AsyncBaseCoreClient):
    """Client for core endpoints defined by stac."""

    collections_get_request_model: BaseCollectionSearchGetRequest = attr.ib(
        default=BaseCollectionSearchGetRequest
    )

    async def all_collections(  # noqa: C901
        self,
        request: Request,
        # Extensions
        bbox: Optional[BBox] = None,
        datetime: Optional[DateTimeType] = None,
        limit: Optional[int] = None,
        query: Optional[str] = None,
        token: Optional[str] = None,
        fields: Optional[List[str]] = None,
        sortby: Optional[str] = None,
        filter: Optional[str] = None,
        filter_lang: Optional[str] = None,
        **kwargs,
    ) -> Collections:
        """Cross catalog search (GET).

        Called with `GET /collections`.

        Returns:
            Collections which match the search criteria, returns all
            collections by default.
        """

        # Parse request parameters
        base_args = {
            "bbox": bbox,
            "limit": limit,
            "token": token,
            "query": orjson.loads(unquote_plus(query)) if query else query,
        }

        clean = clean_search_args(
            base_args=base_args,
            datetime=datetime,
            fields=fields,
            sortby=sortby,
            filter=filter,
            filter_lang=filter_lang,
        )

        # Do the request
        try:
            search_request = self.collections_get_request_model(**clean)
        except ValidationError as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid parameters provided {e}"
            ) from e

        return await self._collection_search_base(search_request, request=request)

So I think I will need to get the json for the pgstac query using a different approach.

Copy link
Author

@hrodmn hrodmn Aug 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got this working by using BaseCollectionSearchPostRequest (instead of BaseCollectionSearchGetRequest), which follows the pattern used in the item search. Does that approach make sense to you, @vincentsarago?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I think I see the path. I probably need to use create_request_model to make sure the appropriate args get added for the extensions. However, without creating a new pydantic model, I would still need to use the POST model because the search functionality relies on the pydantic model. That's probably something that belongs in stac-fastapi, right? Is there a particular reason the GET request models are a generic class rather than a pydantic model in stac-fastapi?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a particular reason the GET request models are a generic class rather than a pydantic model in stac-fastapi?

Yes, because pydantic models are interpreted as body parameters for endpoints.

Copy link
Author

@hrodmn hrodmn Aug 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the PgstacSearch model is compatible with collection search, and it follows the pattern that we use in the get_search method which also uses self.post_request_model.

try:
search_request = self.post_request_model(**clean)
except ValidationError as e:
raise HTTPException(
status_code=400, detail=f"Invalid parameters provided {e}"
) from e

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your patience while I figure this out. I think I need to double back and make sure I understand how all of these models are getting used. Upon closer examination, it looks like stac-fastapi-pgstac is just using the AsyncBaseCoreClient.post_request_model for all of the POST requests:
e.g.

search_request = self.post_request_model(**clean)

We provide a search_post_request_model to StacAPI:

post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
get_request_model = create_get_request_model(extensions)
api = StacApi(
settings=settings,
extensions=extensions,
client=CoreCrudClient(post_request_model=post_request_model), # type: ignore
response_class=ORJSONResponse,
items_get_request_model=items_get_request_model,
search_get_request_model=get_request_model,
search_post_request_model=post_request_model,
)

But I can't see how search_post_request_model is getting used by anything right now.

Copy link
Author

@hrodmn hrodmn Aug 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess search_get_request_model and search_post_request_model are a dead end. We pass the post_request_model that is based on PgstacSearch to the CoreCrudClient and that is what gets used to generate inputs for all of the pgstac search functions.

I am pretty sure we can use it for the /collections GET method like we do in the /collections/{collection_id}/items GET method, that is what my last commit implements.

except ValidationError as e:
raise HTTPException(
status_code=400, detail=f"Invalid parameters provided {e}"
) from e

return await self._collection_search_base(search_request, request=request)

async def _collection_search_base( # noqa: C901
self,
search_request: PgstacSearch,
request: Request,
) -> Collections:
"""Cross catalog search (GET).
Called with `GET /search`.
Args:
search_request: search request parameters.
Returns:
All collections which match the search criteria.
"""
base_url = get_base_url(request)
search_request_json = search_request.model_dump_json(
exclude_none=True, by_alias=True
)

try:
async with request.app.state.get_connection(request, "r") as conn:
q, p = render(
"""
SELECT * FROM collection_search(:req::text::jsonb);
""",
req=search_request_json,
)
collections_result: Collections = await conn.fetchval(q, *p)
except InvalidDatetimeFormatError as e:
raise InvalidQueryParameter(
f"Datetime parameter {search_request.datetime} is invalid."
) from e

next: Optional[str] = None
prev: Optional[str] = None

if links := collections_result.get("links"):
next = collections_result["links"].pop("next")
prev = collections_result["links"].pop("prev")

async with request.app.state.get_connection(request, "r") as conn:
collections = await conn.fetchval(
"""
SELECT * FROM all_collections();
"""
)
linked_collections: List[Collection] = []
collections = collections_result["collections"]
if collections is not None and len(collections) > 0:
for c in collections:
coll = Collection(**c)
Expand All @@ -71,25 +153,16 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:

linked_collections.append(coll)

links = [
{
"rel": Relations.root.value,
"type": MimeTypes.json,
"href": base_url,
},
{
"rel": Relations.parent.value,
"type": MimeTypes.json,
"href": base_url,
},
{
"rel": Relations.self.value,
"type": MimeTypes.json,
"href": urljoin(base_url, "collections"),
},
]
collection_list = Collections(collections=linked_collections or [], links=links)
return collection_list
links = await PagingLinks(
request=request,
next=next,
prev=prev,
).get_links()

return Collections(
collections=linked_collections or [],
links=links,
)

async def get_collection(
self, collection_id: str, request: Request, **kwargs
Expand Down Expand Up @@ -383,7 +456,7 @@ async def post_search(

return ItemCollection(**item_collection)

async def get_search( # noqa: C901
async def get_search(
self,
request: Request,
collections: Optional[List[str]] = None,
Expand Down Expand Up @@ -418,49 +491,15 @@ async def get_search( # noqa: C901
"query": orjson.loads(unquote_plus(query)) if query else query,
}

if filter:
if filter_lang == "cql2-text":
ast = parse_cql2_text(filter)
base_args["filter"] = orjson.loads(to_cql2(ast))
base_args["filter-lang"] = "cql2-json"

if datetime:
base_args["datetime"] = format_datetime_range(datetime)

if intersects:
base_args["intersects"] = orjson.loads(unquote_plus(intersects))

if sortby:
# https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
sort_param = []
for sort in sortby:
sortparts = re.match(r"^([+-]?)(.*)$", sort)
if sortparts:
sort_param.append(
{
"field": sortparts.group(2).strip(),
"direction": "desc" if sortparts.group(1) == "-" else "asc",
}
)
base_args["sortby"] = sort_param

if fields:
includes = set()
excludes = set()
for field in fields:
if field[0] == "-":
excludes.add(field[1:])
elif field[0] == "+":
includes.add(field[1:])
else:
includes.add(field)
base_args["fields"] = {"include": includes, "exclude": excludes}

# Remove None values from dict
clean = {}
for k, v in base_args.items():
if v is not None and v != []:
clean[k] = v
clean = clean_search_args(
base_args=base_args,
intersects=intersects,
datetime=datetime,
fields=fields,
sortby=sortby,
filter=filter,
filter_lang=filter_lang,
)

# Do the request
try:
Expand All @@ -471,3 +510,60 @@ async def get_search( # noqa: C901
) from e

return await self.post_search(search_request, request=request)


def clean_search_args(  # noqa: C901
    base_args: Dict[str, Any],
    intersects: Optional[str] = None,
    datetime: Optional[DateTimeType] = None,
    fields: Optional[List[str]] = None,
    sortby: Optional[str] = None,
    filter: Optional[str] = None,
    filter_lang: Optional[str] = None,
) -> Dict[str, Any]:
    """Clean up search arguments to match the format expected by pgstac.

    Args:
        base_args: Arguments that need no conversion (e.g. bbox, limit, token).
            The dict is copied; the caller's object is never modified.
        intersects: JSON-encoded (optionally URL-quoted) geometry.
        datetime: Datetime or datetime range, formatted with
            ``format_datetime_range``.
        fields: Field names; a leading ``-`` excludes the field, a leading
            ``+`` (or no prefix) includes it. Empty names are ignored.
        sortby: Sort expressions, either as a list of strings or as a single
            comma-separated string. A leading ``-`` sorts descending, a
            leading ``+`` (or no prefix) ascending. Empty entries are ignored.
        filter: CQL2 filter expression, as text or as a JSON string.
        filter_lang: Language of ``filter``: ``"cql2-text"`` or
            ``"cql2-json"``. A filter in any other language is not forwarded.

    Returns:
        A new dict with the converted arguments, without entries whose value
        is ``None`` or an empty list.
    """
    # Work on a copy so the caller's dict is left untouched.
    args = dict(base_args)

    if filter:
        if filter_lang == "cql2-text":
            # Convert the text form so only cql2-json is passed on.
            ast = parse_cql2_text(filter)
            args["filter"] = orjson.loads(to_cql2(ast))
            args["filter-lang"] = "cql2-json"
        elif filter_lang == "cql2-json":
            # Already JSON: decode it instead of silently dropping the filter.
            args["filter"] = orjson.loads(filter)
            args["filter-lang"] = filter_lang

    if datetime:
        args["datetime"] = format_datetime_range(datetime)

    if intersects:
        args["intersects"] = orjson.loads(unquote_plus(intersects))

    if sortby:
        # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
        # A bare string would otherwise be iterated character by character.
        sort_items = sortby.split(",") if isinstance(sortby, str) else sortby
        sort_param = []
        for sort in sort_items:
            sortparts = re.match(r"^([+-]?)(.*)$", sort)
            if not sortparts:
                continue
            field_name = sortparts.group(2).strip()
            if not field_name:
                continue
            sort_param.append(
                {
                    "field": field_name,
                    "direction": "desc" if sortparts.group(1) == "-" else "asc",
                }
            )
        args["sortby"] = sort_param

    if fields:
        includes = set()
        excludes = set()
        for field in fields:
            if not field:
                # Guard against empty names (indexing them would raise).
                continue
            if field.startswith("-"):
                excludes.add(field[1:])
            elif field.startswith("+"):
                includes.add(field[1:])
            else:
                includes.add(field)
        args["fields"] = {"include": includes, "exclude": excludes}

    # Remove None values and empty lists from the dict
    return {k: v for k, v in args.items() if v is not None and v != []}
Loading
Loading