Skip to content

Commit

Permalink
Merge pull request #299 from whisperfish/handle-pni-uuid
Browse files Browse the repository at this point in the history
Better handle PNI UUID
  • Loading branch information
rubdos committed Jul 12, 2024
2 parents 5a953e4 + 8b967e9 commit 8cd92cd
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 67 deletions.
16 changes: 6 additions & 10 deletions libsignal-service/src/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,13 @@ fn debug_envelope(envelope: &Envelope) -> String {
} else {
format!(
"Envelope {{ \
source_address: {}, \
source_address: {:?}, \
source_device: {:?}, \
server_guid: {:?}, \
timestamp: {:?}, \
content: {} bytes, \
}}",
if envelope.source_service_id.is_some() {
format!("{:?}", envelope.source_address())
} else {
"unknown".to_string()
},
envelope.source_service_id,
envelope.source_device(),
envelope.server_guid(),
envelope.timestamp(),
Expand Down Expand Up @@ -278,13 +274,13 @@ where
)
.await?;

let sender = ServiceAddress {
uuid: Uuid::parse_str(&sender_uuid).map_err(|_| {
let sender = ServiceAddress::try_from(sender_uuid.as_str())
.map_err(|e| {
tracing::error!("{:?}", e);
SignalProtocolError::InvalidSealedSenderMessage(
"invalid sender UUID".to_string(),
)
})?,
};
})?;

let needs_receipt = if envelope.source_service_id.is_some() {
tracing::warn!(?envelope, "Received an unidentified delivery over an identified channel. Marking needs_receipt=false");
Expand Down
15 changes: 5 additions & 10 deletions libsignal-service/src/envelope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::convert::{TryFrom, TryInto};
use aes::cipher::block_padding::Pkcs7;
use aes::cipher::{BlockDecryptMut, KeyIvInit};
use prost::Message;
use uuid::Uuid;

use crate::{
configuration::SignalingKey, push_service::ServiceError,
Expand Down Expand Up @@ -135,15 +134,11 @@ impl Envelope {
}

pub fn source_address(&self) -> ServiceAddress {
let uuid = self
.source_service_id
.as_deref()
.map(Uuid::parse_str)
.transpose()
.expect("valid uuid checked in constructor")
.expect("source_service_id is set");

ServiceAddress { uuid }
match self.source_service_id.as_deref() {
Some(service_id) => ServiceAddress::try_from(service_id)
.expect("invalid ProtocolAddress UUID or prefix"),
None => panic!("source_service_id is set"),
}
}
}

Expand Down
7 changes: 4 additions & 3 deletions libsignal-service/src/profile_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ impl ProfileService {
) -> Result<SignalServiceProfile, ServiceError> {
let endpoint = match profile_key {
Some(key) => {
let version = bincode::serialize(
&key.get_profile_key_version(address.aci()),
)?;
let version =
bincode::serialize(&key.get_profile_key_version(
address.aci().expect("profile by ACI ProtocolAddress"),
))?;
let version = std::str::from_utf8(&version)
.expect("hex encoded profile key version");
format!("/v1/profile/{}/{}", address.uuid, version)
Expand Down
10 changes: 5 additions & 5 deletions libsignal-service/src/push_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub const STICKER_PATH: &str = "stickers/%s/full/%d";
pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55);
pub const DEFAULT_DEVICE_ID: u32 = 1;

#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum ServiceIdType {
/// Account Identity (ACI)
///
Expand Down Expand Up @@ -811,7 +811,7 @@ pub trait PushService: MaybeSend {
&mut self,
messages: OutgoingPushMessages,
) -> Result<SendMessageResponse, ServiceError> {
let path = format!("/v1/messages/{}", messages.recipient.uuid);
let path = format!("/v1/messages/{}", messages.destination);
self.put_json(
Endpoint::Service,
&path,
Expand Down Expand Up @@ -902,9 +902,9 @@ pub trait PushService: MaybeSend {
profile_key: Option<zkgroup::profiles::ProfileKey>,
) -> Result<SignalServiceProfile, ServiceError> {
let endpoint = if let Some(key) = profile_key {
let version = bincode::serialize(
&key.get_profile_key_version(address.aci()),
)?;
let version = bincode::serialize(&key.get_profile_key_version(
address.aci().expect("profile by ACI ProtocolAddress"),
))?;
let version = std::str::from_utf8(&version)
.expect("hex encoded profile key version");
format!("/v1/profile/{}/{}", address.uuid, version)
Expand Down
32 changes: 22 additions & 10 deletions libsignal-service/src/sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use libsignal_protocol::{
use rand::{CryptoRng, Rng};
use tracing::{info, trace};
use tracing_futures::Instrument;
use uuid::Uuid;

use crate::{
cipher::{get_preferred_protocol_address, ServiceCipher},
Expand Down Expand Up @@ -38,7 +37,7 @@ pub struct OutgoingPushMessage {

#[derive(serde::Serialize, Debug)]
pub struct OutgoingPushMessages {
pub recipient: ServiceAddress,
pub destination: uuid::Uuid,
pub timestamp: u64,
pub messages: Vec<OutgoingPushMessage>,
pub online: bool,
Expand Down Expand Up @@ -120,8 +119,8 @@ pub enum MessageSenderError {
#[error("Proof of type {options:?} required using token {token}")]
ProofRequired { token: String, options: Vec<String> },

#[error("Recipient not found: {uuid}")]
NotFound { uuid: Uuid },
#[error("Recipient not found: {addr:?}")]
NotFound { addr: ServiceAddress },
}

impl<Service, S, R> MessageSender<Service, S, R>
Expand Down Expand Up @@ -500,7 +499,7 @@ where
.await?;

let messages = OutgoingPushMessages {
recipient,
destination: recipient.uuid,
timestamp,
messages,
online,
Expand Down Expand Up @@ -601,7 +600,7 @@ where
Err(ServiceError::NotFoundError) => {
tracing::debug!("Not found when sending a message");
return Err(MessageSenderError::NotFound {
uuid: recipient.uuid,
addr: recipient,
});
},
Err(e) => {
Expand Down Expand Up @@ -722,9 +721,22 @@ where
devices.insert(DEFAULT_DEVICE_ID.into());

// never try to send messages to the sender device
if recipient.aci() == self.local_aci.aci() {
devices.remove(&self.device_id);
}
match recipient.identity {
ServiceIdType::AccountIdentity => {
if recipient.aci().is_some()
&& recipient.aci() == self.local_aci.aci()
{
devices.remove(&self.device_id);
}
},
ServiceIdType::PhoneNumberIdentity => {
if recipient.pni().is_some()
&& recipient.pni() == self.local_aci.pni()
{
devices.remove(&self.device_id);
}
},
};

for device_id in devices {
trace!("sending message to device {}", device_id);
Expand Down Expand Up @@ -836,7 +848,7 @@ where
},
Err(ServiceError::NotFoundError) => {
return Err(MessageSenderError::NotFound {
uuid: recipient.uuid,
addr: *recipient,
});
},
Err(e) => Err(e)?,
Expand Down
108 changes: 81 additions & 27 deletions libsignal-service/src/service_address.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::convert::TryFrom;

use libsignal_protocol::{DeviceId, ProtocolAddress};
use serde::{Deserialize, Serialize};
use uuid::Uuid;

pub use crate::push_service::ServiceIdType;

#[derive(thiserror::Error, Debug, Clone)]
pub enum ParseServiceAddressError {
#[error("Supplied UUID could not be parsed")]
Expand All @@ -13,64 +14,117 @@ pub enum ParseServiceAddressError {
NoUuid,
}

#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct ServiceAddress {
pub uuid: Uuid,
pub identity: ServiceIdType,
}

impl ServiceAddress {
pub fn to_protocol_address(
&self,
device_id: impl Into<DeviceId>,
) -> ProtocolAddress {
ProtocolAddress::new(self.uuid.to_string(), device_id.into())
match self.identity {
ServiceIdType::AccountIdentity => {
ProtocolAddress::new(self.uuid.to_string(), device_id.into())
},
ServiceIdType::PhoneNumberIdentity => ProtocolAddress::new(
format!("PNI:{}", self.uuid),
device_id.into(),
),
}
}

pub fn aci(&self) -> libsignal_protocol::Aci {
libsignal_protocol::Aci::from_uuid_bytes(self.uuid.into_bytes())
pub fn new_aci(uuid: Uuid) -> Self {
Self {
uuid,
identity: ServiceIdType::AccountIdentity,
}
}

pub fn pni(&self) -> libsignal_protocol::Pni {
libsignal_protocol::Pni::from_uuid_bytes(self.uuid.into_bytes())
pub fn new_pni(uuid: Uuid) -> Self {
Self {
uuid,
identity: ServiceIdType::PhoneNumberIdentity,
}
}
}

impl From<Uuid> for ServiceAddress {
fn from(uuid: Uuid) -> Self {
Self { uuid }
pub fn aci(&self) -> Option<libsignal_protocol::Aci> {
use libsignal_protocol::Aci;
match self.identity {
ServiceIdType::AccountIdentity => {
Some(Aci::from_uuid_bytes(self.uuid.into_bytes()))
},
ServiceIdType::PhoneNumberIdentity => None,
}
}

pub fn pni(&self) -> Option<libsignal_protocol::Pni> {
use libsignal_protocol::Pni;
match self.identity {
ServiceIdType::AccountIdentity => None,
ServiceIdType::PhoneNumberIdentity => {
Some(Pni::from_uuid_bytes(self.uuid.into_bytes()))
},
}
}

pub fn to_service_id(&self) -> String {
match self.identity {
ServiceIdType::AccountIdentity => self.uuid.to_string(),
ServiceIdType::PhoneNumberIdentity => {
format!("PNI:{}", self.uuid)
},
}
}
}

impl TryFrom<&str> for ServiceAddress {
impl TryFrom<&ProtocolAddress> for ServiceAddress {
type Error = ParseServiceAddressError;

fn try_from(value: &str) -> Result<Self, Self::Error> {
Ok(ServiceAddress {
uuid: Uuid::parse_str(value)?,
fn try_from(addr: &ProtocolAddress) -> Result<Self, Self::Error> {
let value = addr.name();
if let Some(pni) = value.strip_prefix("PNI:") {
Ok(ServiceAddress::new_pni(Uuid::parse_str(pni)?))
} else {
Ok(ServiceAddress::new_aci(Uuid::parse_str(value)?))
}
.map_err(|e| {
tracing::error!("Parsing ServiceAddress from {:?}", addr);
ParseServiceAddressError::InvalidUuid(e)
})
}
}

impl TryFrom<Option<&str>> for ServiceAddress {
impl TryFrom<&str> for ServiceAddress {
type Error = ParseServiceAddressError;

fn try_from(value: Option<&str>) -> Result<Self, Self::Error> {
match value.map(Uuid::parse_str) {
Some(Ok(uuid)) => Ok(ServiceAddress { uuid }),
Some(Err(e)) => Err(ParseServiceAddressError::InvalidUuid(e)),
None => Err(ParseServiceAddressError::NoUuid),
fn try_from(value: &str) -> Result<Self, Self::Error> {
if let Some(pni) = value.strip_prefix("PNI:") {
Ok(ServiceAddress::new_pni(Uuid::parse_str(pni)?))
} else {
Ok(ServiceAddress::new_aci(Uuid::parse_str(value)?))
}
.map_err(|e| {
tracing::error!("Parsing ServiceAddress from '{}'", value);
ParseServiceAddressError::InvalidUuid(e)
})
}
}

impl TryFrom<Option<&[u8]>> for ServiceAddress {
impl TryFrom<&[u8]> for ServiceAddress {
type Error = ParseServiceAddressError;

fn try_from(value: Option<&[u8]>) -> Result<Self, Self::Error> {
match value.map(Uuid::from_slice) {
Some(Ok(uuid)) => Ok(ServiceAddress { uuid }),
Some(Err(e)) => Err(ParseServiceAddressError::InvalidUuid(e)),
None => Err(ParseServiceAddressError::NoUuid),
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
if let Some(pni) = value.strip_prefix(b"PNI:") {
Ok(ServiceAddress::new_pni(Uuid::from_slice(pni)?))
} else {
Ok(ServiceAddress::new_aci(Uuid::from_slice(value)?))
}
.map_err(|e| {
tracing::error!("Parsing ServiceAddress from {:?}", value);
ParseServiceAddressError::InvalidUuid(e)
})
}
}
4 changes: 2 additions & 2 deletions libsignal-service/src/websocket/sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ impl SignalWebSocket {
&mut self,
messages: OutgoingPushMessages,
) -> Result<SendMessageResponse, ServiceError> {
let path = format!("/v1/messages/{}", messages.recipient.uuid);
let path = format!("/v1/messages/{}", messages.destination);
self.put_json(&path, messages).await
}

Expand All @@ -21,7 +21,7 @@ impl SignalWebSocket {
messages: OutgoingPushMessages,
access: &UnidentifiedAccess,
) -> Result<SendMessageResponse, ServiceError> {
let path = format!("/v1/messages/{}", messages.recipient.uuid);
let path = format!("/v1/messages/{}", messages.destination);
let header = format!(
"Unidentified-Access-Key:{}",
BASE64_RELAXED.encode(&access.key)
Expand Down

0 comments on commit 8cd92cd

Please sign in to comment.