Skip to content

Commit

Permalink
Subscription types (#1281)
Browse files Browse the repository at this point in the history
* subscription types

Signed-off-by: Michael Carlstrom <[email protected]>

* Fix docstrings

Signed-off-by: Michael Carlstrom <[email protected]>

* Flake8 imports

Signed-off-by: Michael Carlstrom <[email protected]>

* Fix Node Import

Signed-off-by: Michael Carlstrom <[email protected]>

* Update subscription.py

Signed-off-by: Michael Carlstrom <[email protected]>

---------

Signed-off-by: Michael Carlstrom <[email protected]>
Signed-off-by: Michael Carlstrom <[email protected]>
Signed-off-by: Shane Loretz <[email protected]>
Co-authored-by: Shane Loretz <[email protected]>
  • Loading branch information
InvincibleRMC and sloretz committed Aug 7, 2024
1 parent 63145fe commit 35d494c
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 15 deletions.
9 changes: 7 additions & 2 deletions rclpy/rclpy/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Callable
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
import warnings

import rclpy
Expand All @@ -27,6 +28,10 @@
from rclpy.waitable import Waitable


if TYPE_CHECKING:
from rclpy.subscription import SubscriptionHandle


QoSPublisherEventType = _rclpy.rcl_publisher_event_type_t
QoSSubscriptionEventType = _rclpy.rcl_subscription_event_type_t

Expand Down Expand Up @@ -189,8 +194,8 @@ def __init__(
self.use_default_callbacks = use_default_callbacks

def create_event_handlers(
self, callback_group: CallbackGroup, subscription: _rclpy.Subscription, topic_name: str,
) -> List[EventHandler]:
self, callback_group: CallbackGroup, subscription: 'SubscriptionHandle',
topic_name: str) -> List[EventHandler]:
with subscription:
logger = get_logger(subscription.get_logger_name())

Expand Down
6 changes: 4 additions & 2 deletions rclpy/rclpy/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@
from rclpy.qos_overriding_options import _declare_qos_parameters
from rclpy.qos_overriding_options import QoSOverridingOptions
from rclpy.service import Service
from rclpy.subscription import MessageInfo
from rclpy.subscription import Subscription
from rclpy.subscription import SubscriptionHandle
from rclpy.time_source import TimeSource
from rclpy.timer import Rate
from rclpy.timer import Timer, TimerInfo
Expand Down Expand Up @@ -1601,7 +1603,7 @@ def create_subscription(
self,
msg_type: Type[MsgT],
topic: str,
callback: Callable[[MsgT], None],
callback: Union[Callable[[MsgT], None], Callable[[MsgT, MessageInfo], None]],
qos_profile: Union[QoSProfile, int],
*,
callback_group: Optional[CallbackGroup] = None,
Expand Down Expand Up @@ -1651,7 +1653,7 @@ def create_subscription(
check_is_valid_msg_type(msg_type)
try:
with self.handle:
subscription_object = _rclpy.Subscription(
subscription_object: SubscriptionHandle[MsgT] = _rclpy.Subscription(
self.handle, msg_type, topic, qos_profile.get_c_qos_profile())
except ValueError:
failed = True
Expand Down
54 changes: 43 additions & 11 deletions rclpy/rclpy/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,46 @@
from enum import Enum
import inspect
from types import TracebackType
from typing import Callable, Generic, List, Optional, Type, TypeVar
from typing import Callable, Generic, Optional, Protocol, Tuple, Type, TypedDict, TypeVar, Union

from rclpy.callback_groups import CallbackGroup
from rclpy.event_handler import EventHandler, SubscriptionEventCallbacks
from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy
from rclpy.destroyable import DestroyableType
from rclpy.event_handler import SubscriptionEventCallbacks
from rclpy.qos import QoSProfile
from rclpy.type_support import MsgT


class MessageInfo(TypedDict):
source_timestamp: int
received_timestamp: int
publication_sequence_number: Optional[int]
reception_sequence_number: Optional[int]


class SubscriptionHandle(DestroyableType, Protocol[MsgT]):

@property
def pointer(self) -> int:
"""Get the address of the entity as an integer."""
...

def take_message(self, pymsg_type: Type[MsgT], raw: bool) -> Tuple[MsgT, MessageInfo]:
"""Take a message and its metadata from a subscription."""
...

def get_logger_name(self) -> str:
"""Get the name of the logger associated with the node of the subscription."""
...

def get_topic_name(self) -> str:
"""Return the resolved topic name of a subscription."""
...

def get_publisher_count(self) -> int:
"""Count the publishers from a subscription."""
...


# Left to support Legacy TypeVars.
MsgType = TypeVar('MsgType')

Expand All @@ -37,10 +68,10 @@ class CallbackType(Enum):

def __init__(
self,
subscription_impl: _rclpy.Subscription,
subscription_impl: SubscriptionHandle[MsgT],
msg_type: Type[MsgT],
topic: str,
callback: Callable[[MsgT], None],
callback: Union[Callable[[MsgT], None], Callable[[MsgT, MessageInfo], None]],
callback_group: CallbackGroup,
qos_profile: QoSProfile,
raw: bool,
Expand Down Expand Up @@ -74,7 +105,7 @@ def __init__(
self.qos_profile = qos_profile
self.raw = raw

self.event_handlers: List[EventHandler] = event_callbacks.create_event_handlers(
self.event_handlers = event_callbacks.create_event_handlers(
callback_group, subscription_impl, topic)

def get_publisher_count(self) -> int:
Expand All @@ -83,10 +114,10 @@ def get_publisher_count(self) -> int:
return self.__subscription.get_publisher_count()

@property
def handle(self):
def handle(self) -> SubscriptionHandle[MsgT]:
return self.__subscription

def destroy(self):
def destroy(self) -> None:
"""
Destroy a container for a ROS subscription.
Expand All @@ -98,16 +129,17 @@ def destroy(self):
self.handle.destroy_when_not_in_use()

@property
def topic_name(self):
def topic_name(self) -> str:
with self.handle:
return self.__subscription.get_topic_name()

@property
def callback(self) -> Callable[[MsgT], None]:
def callback(self) -> Union[Callable[[MsgT], None], Callable[[MsgT, MessageInfo], None]]:
return self._callback

@callback.setter
def callback(self, value: Callable[[MsgT], None]) -> None:
def callback(self, value: Union[Callable[[MsgT], None],
Callable[[MsgT, MessageInfo], None]]) -> None:
self._callback = value
self._callback_type = Subscription.CallbackType.MessageOnly
try:
Expand Down

0 comments on commit 35d494c

Please sign in to comment.