-
-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add async support with memory and redis storage backends
This addresses issue #60, but is still WIP. Support for more backends is still needed.
- Loading branch information
Showing
11 changed files
with
1,103 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import urllib | ||
|
||
from limits.errors import ConfigurationError | ||
from limits.storage.registry import SCHEMES | ||
|
||
from .base import AsyncStorage | ||
from .memory import AsyncMemoryStorage | ||
from .redis import AsyncRedisStorage | ||
|
||
def async_storage_from_string(storage_string: str, **options) -> AsyncStorage: | ||
""" | ||
factory function to get an instance of the async storage class based | ||
on the uri of the storage | ||
:param storage_string: a string of the form method://host:port | ||
:return: an instance of :class:`limits._async.storage.AsyncStorage` | ||
""" | ||
scheme = urllib.parse.urlparse(storage_string).scheme | ||
if scheme not in SCHEMES: | ||
raise ConfigurationError( | ||
"unknown storage scheme : %s" % storage_string | ||
) | ||
return SCHEMES[scheme](storage_string, **options) | ||
|
||
|
||
__all__ = [ | ||
"storage_from_string", | ||
"AsyncStorage", | ||
"AsyncMemoryStorage", | ||
"AsyncRedisStorage", | ||
"AsyncRedisClusterStorage", | ||
"AsyncRedisSentinelStorage", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import threading | ||
from abc import abstractmethod | ||
from typing import Dict, Optional | ||
|
||
from limits.storage.registry import StorageRegistry | ||
|
||
|
||
class AsyncStorage(object, metaclass=StorageRegistry): | ||
""" | ||
Base class to extend when implementing an async storage backend. | ||
""" | ||
|
||
def __init__(self, uri: Optional[str] = None, **options: Dict) -> None: | ||
self.lock = threading.RLock() | ||
|
||
@abstractmethod | ||
async def incr( | ||
self, key: str, expiry: int, elastic_expiry: bool = False | ||
) -> int: | ||
""" | ||
increments the counter for a given rate limit key | ||
:param str key: the key to increment | ||
:param int expiry: amount in seconds for the key to expire in | ||
:param bool elastic_expiry: whether to keep extending the rate limit | ||
window every hit. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
async def get(self, key: str) -> int: | ||
""" | ||
:param str key: the key to get the counter value for | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
async def get_expiry(self, key: str) -> int: | ||
""" | ||
:param str key: the key to get the expiry for | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
async def check(self) -> bool: | ||
""" | ||
check if storage is healthy | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
async def reset(self) -> None: | ||
""" | ||
reset storage to clear limits | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
async def clear(self, key: str) -> int: | ||
""" | ||
resets the rate limit key | ||
:param str key: the key to clear rate limits for | ||
""" | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
import threading | ||
import time | ||
from typing import Dict, Tuple, List, Optional | ||
from collections import Counter | ||
|
||
from limits.storage.base import Storage | ||
|
||
|
||
class LockableEntry(threading._RLock): | ||
__slots__ = ["atime", "expiry"] | ||
|
||
def __init__(self, expiry: int) -> None: | ||
self.atime = time.time() | ||
self.expiry = self.atime + expiry | ||
super(LockableEntry, self).__init__() | ||
|
||
|
||
class AsyncMemoryStorage(Storage): | ||
""" | ||
rate limit storage using :class:`collections.Counter` | ||
as an in memory storage for fixed and elastic window strategies, | ||
and a simple list to implement moving window strategy. | ||
""" | ||
|
||
STORAGE_SCHEME = ["amemory"] | ||
|
||
def __init__(self, uri: Optional[str] = None, **_: Dict) -> None: | ||
self.storage: Counter = Counter() | ||
self.expirations: Dict = {} | ||
self.events: Dict[str, List[LockableEntry]] = {} | ||
self.timer = threading.Timer(0.01, self.__expire_events) | ||
self.timer.start() | ||
super(AsyncMemoryStorage, self).__init__(uri) # type: ignore | ||
|
||
def __expire_events(self) -> None: | ||
# this remains a sync function so we can pass it to | ||
# threading.Timer | ||
# TODO: can we replace threading.Timer with asyncio.sleep? | ||
for key in self.events.keys(): | ||
for event in list(self.events[key]): | ||
with event: | ||
if ( | ||
event.expiry <= time.time() | ||
and event in self.events[key] | ||
): | ||
self.events[key].remove(event) | ||
for key in list(self.expirations.keys()): | ||
if self.expirations[key] <= time.time(): | ||
self.storage.pop(key, None) | ||
self.expirations.pop(key, None) | ||
|
||
async def __schedule_expiry(self) -> None: | ||
if not self.timer.is_alive(): | ||
self.timer = threading.Timer(0.01, self.__expire_events) | ||
self.timer.start() | ||
|
||
async def incr( | ||
self, key: str, expiry: int, elastic_expiry: bool = False | ||
) -> int: | ||
""" | ||
increments the counter for a given rate limit key | ||
:param str key: the key to increment | ||
:param int expiry: amount in seconds for the key to expire in | ||
:param bool elastic_expiry: whether to keep extending the rate limit | ||
window every hit. | ||
""" | ||
await self.get(key) | ||
await self.__schedule_expiry() | ||
self.storage[key] += 1 | ||
if elastic_expiry or self.storage[key] == 1: | ||
self.expirations[key] = time.time() + expiry | ||
return self.storage.get(key, 0) | ||
|
||
async def get(self, key: str) -> int: | ||
""" | ||
:param str key: the key to get the counter value for | ||
""" | ||
if self.expirations.get(key, 0) <= time.time(): | ||
self.storage.pop(key, None) | ||
self.expirations.pop(key, None) | ||
return self.storage.get(key, 0) | ||
|
||
async def clear(self, key: str) -> None: | ||
""" | ||
:param str key: the key to clear rate limits for | ||
""" | ||
self.storage.pop(key, None) | ||
self.expirations.pop(key, None) | ||
self.events.pop(key, None) | ||
|
||
async def acquire_entry( | ||
self, key: str, limit: int, expiry: int, no_add: bool = False | ||
) -> bool: | ||
""" | ||
:param str key: rate limit key to acquire an entry in | ||
:param int limit: amount of entries allowed | ||
:param int expiry: expiry of the entry | ||
:param bool no_add: if False an entry is not actually acquired | ||
but instead serves as a 'check' | ||
:rtype: bool | ||
""" | ||
self.events.setdefault(key, []) | ||
await self.__schedule_expiry() | ||
timestamp = time.time() | ||
try: | ||
entry: Optional[LockableEntry] = self.events[key][limit - 1] | ||
except IndexError: | ||
entry = None | ||
if entry and entry.atime >= timestamp - expiry: | ||
return False | ||
else: | ||
if not no_add: | ||
self.events[key].insert(0, LockableEntry(expiry)) | ||
return True | ||
|
||
async def get_expiry(self, key: str) -> int: | ||
""" | ||
:param str key: the key to get the expiry for | ||
""" | ||
return int(self.expirations.get(key, -1)) | ||
|
||
async def get_num_acquired(self, key: str, expiry: int) -> int: | ||
""" | ||
returns the number of entries already acquired | ||
:param str key: rate limit key to acquire an entry in | ||
:param int expiry: expiry of the entry | ||
""" | ||
timestamp = time.time() | ||
return ( | ||
len([k for k in self.events[key] if k.atime >= timestamp - expiry]) | ||
if self.events.get(key) | ||
else 0 | ||
) | ||
|
||
# FIXME: arg limit is not used | ||
async def get_moving_window( | ||
self, key: str, limit: int, expiry: int | ||
) -> Tuple[int, int]: | ||
""" | ||
returns the starting point and the number of entries in the moving | ||
window | ||
:param str key: rate limit key | ||
:param int expiry: expiry of entry | ||
:return: (start of window, number of acquired entries) | ||
""" | ||
timestamp = time.time() | ||
acquired = await self.get_num_acquired(key, expiry) | ||
for item in self.events.get(key, []): | ||
if item.atime >= timestamp - expiry: | ||
return int(item.atime), acquired | ||
return int(timestamp), acquired | ||
|
||
async def check(self) -> bool: | ||
""" | ||
check if storage is healthy | ||
""" | ||
return True | ||
|
||
async def reset(self) -> None: | ||
self.storage.clear() | ||
self.expirations.clear() | ||
self.events.clear() |
Oops, something went wrong.