Skip to content

Commit

Permalink
Add async support with memory and redis storage backends
Browse files Browse the repository at this point in the history
This addresses issue #60, but is still WIP. Support for more
backends is still needed.
  • Loading branch information
laurentS authored and alisaifee committed Nov 28, 2021
1 parent 0105492 commit be544ce
Show file tree
Hide file tree
Showing 11 changed files with 1,103 additions and 0 deletions.
Empty file added limits/_async/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions limits/_async/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import urllib

from limits.errors import ConfigurationError
from limits.storage.registry import SCHEMES

from .base import AsyncStorage
from .memory import AsyncMemoryStorage
from .redis import AsyncRedisStorage

def async_storage_from_string(storage_string: str, **options) -> AsyncStorage:
"""
factory function to get an instance of the async storage class based
on the uri of the storage
:param storage_string: a string of the form method://host:port
:return: an instance of :class:`limits._async.storage.AsyncStorage`
"""
scheme = urllib.parse.urlparse(storage_string).scheme
if scheme not in SCHEMES:
raise ConfigurationError(
"unknown storage scheme : %s" % storage_string
)
return SCHEMES[scheme](storage_string, **options)


__all__ = [
"storage_from_string",
"AsyncStorage",
"AsyncMemoryStorage",
"AsyncRedisStorage",
"AsyncRedisClusterStorage",
"AsyncRedisSentinelStorage",
]
64 changes: 64 additions & 0 deletions limits/_async/storage/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import threading
from abc import abstractmethod
from typing import Dict, Optional

from limits.storage.registry import StorageRegistry


class AsyncStorage(object, metaclass=StorageRegistry):
"""
Base class to extend when implementing an async storage backend.
"""

def __init__(self, uri: Optional[str] = None, **options: Dict) -> None:
self.lock = threading.RLock()

@abstractmethod
async def incr(
self, key: str, expiry: int, elastic_expiry: bool = False
) -> int:
"""
increments the counter for a given rate limit key
:param str key: the key to increment
:param int expiry: amount in seconds for the key to expire in
:param bool elastic_expiry: whether to keep extending the rate limit
window every hit.
"""
raise NotImplementedError

@abstractmethod
async def get(self, key: str) -> int:
"""
:param str key: the key to get the counter value for
"""
raise NotImplementedError

@abstractmethod
async def get_expiry(self, key: str) -> int:
"""
:param str key: the key to get the expiry for
"""
raise NotImplementedError

@abstractmethod
async def check(self) -> bool:
"""
check if storage is healthy
"""
raise NotImplementedError

@abstractmethod
async def reset(self) -> None:
"""
reset storage to clear limits
"""
raise NotImplementedError

@abstractmethod
async def clear(self, key: str) -> int:
"""
resets the rate limit key
:param str key: the key to clear rate limits for
"""
raise NotImplementedError
166 changes: 166 additions & 0 deletions limits/_async/storage/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import threading
import time
from typing import Dict, Tuple, List, Optional
from collections import Counter

from limits.storage.base import Storage


class LockableEntry(threading._RLock):
__slots__ = ["atime", "expiry"]

def __init__(self, expiry: int) -> None:
self.atime = time.time()
self.expiry = self.atime + expiry
super(LockableEntry, self).__init__()


class AsyncMemoryStorage(Storage):
"""
rate limit storage using :class:`collections.Counter`
as an in memory storage for fixed and elastic window strategies,
and a simple list to implement moving window strategy.
"""

STORAGE_SCHEME = ["amemory"]

def __init__(self, uri: Optional[str] = None, **_: Dict) -> None:
self.storage: Counter = Counter()
self.expirations: Dict = {}
self.events: Dict[str, List[LockableEntry]] = {}
self.timer = threading.Timer(0.01, self.__expire_events)
self.timer.start()
super(AsyncMemoryStorage, self).__init__(uri) # type: ignore

def __expire_events(self) -> None:
# this remains a sync function so we can pass it to
# threading.Timer
# TODO: can we replace threading.Timer with asyncio.sleep?
for key in self.events.keys():
for event in list(self.events[key]):
with event:
if (
event.expiry <= time.time()
and event in self.events[key]
):
self.events[key].remove(event)
for key in list(self.expirations.keys()):
if self.expirations[key] <= time.time():
self.storage.pop(key, None)
self.expirations.pop(key, None)

async def __schedule_expiry(self) -> None:
if not self.timer.is_alive():
self.timer = threading.Timer(0.01, self.__expire_events)
self.timer.start()

async def incr(
self, key: str, expiry: int, elastic_expiry: bool = False
) -> int:
"""
increments the counter for a given rate limit key
:param str key: the key to increment
:param int expiry: amount in seconds for the key to expire in
:param bool elastic_expiry: whether to keep extending the rate limit
window every hit.
"""
await self.get(key)
await self.__schedule_expiry()
self.storage[key] += 1
if elastic_expiry or self.storage[key] == 1:
self.expirations[key] = time.time() + expiry
return self.storage.get(key, 0)

async def get(self, key: str) -> int:
"""
:param str key: the key to get the counter value for
"""
if self.expirations.get(key, 0) <= time.time():
self.storage.pop(key, None)
self.expirations.pop(key, None)
return self.storage.get(key, 0)

async def clear(self, key: str) -> None:
"""
:param str key: the key to clear rate limits for
"""
self.storage.pop(key, None)
self.expirations.pop(key, None)
self.events.pop(key, None)

async def acquire_entry(
self, key: str, limit: int, expiry: int, no_add: bool = False
) -> bool:
"""
:param str key: rate limit key to acquire an entry in
:param int limit: amount of entries allowed
:param int expiry: expiry of the entry
:param bool no_add: if False an entry is not actually acquired
but instead serves as a 'check'
:rtype: bool
"""
self.events.setdefault(key, [])
await self.__schedule_expiry()
timestamp = time.time()
try:
entry: Optional[LockableEntry] = self.events[key][limit - 1]
except IndexError:
entry = None
if entry and entry.atime >= timestamp - expiry:
return False
else:
if not no_add:
self.events[key].insert(0, LockableEntry(expiry))
return True

async def get_expiry(self, key: str) -> int:
"""
:param str key: the key to get the expiry for
"""
return int(self.expirations.get(key, -1))

async def get_num_acquired(self, key: str, expiry: int) -> int:
"""
returns the number of entries already acquired
:param str key: rate limit key to acquire an entry in
:param int expiry: expiry of the entry
"""
timestamp = time.time()
return (
len([k for k in self.events[key] if k.atime >= timestamp - expiry])
if self.events.get(key)
else 0
)

# FIXME: arg limit is not used
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> Tuple[int, int]:
"""
returns the starting point and the number of entries in the moving
window
:param str key: rate limit key
:param int expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
timestamp = time.time()
acquired = await self.get_num_acquired(key, expiry)
for item in self.events.get(key, []):
if item.atime >= timestamp - expiry:
return int(item.atime), acquired
return int(timestamp), acquired

async def check(self) -> bool:
"""
check if storage is healthy
"""
return True

async def reset(self) -> None:
self.storage.clear()
self.expirations.clear()
self.events.clear()
Loading

0 comments on commit be544ce

Please sign in to comment.