Skip to content

Commit

Permalink
Allow one reusable proxy URL per ESPHome device (#125845)
Browse files Browse the repository at this point in the history
* Allow one reusable URL per device

* Move process to convert info

* Stop previous process

* Change to 404

* Better error handling
  • Loading branch information
synesthesiam committed Sep 19, 2024
1 parent f8274cd commit d1a4838
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 48 deletions.
102 changes: 57 additions & 45 deletions homeassistant/components/esphome/ffmpeg_proxy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""HTTP view that converts audio from a URL to a preferred format."""

import asyncio
from collections import defaultdict
from dataclasses import dataclass, field
from http import HTTPStatus
import logging
Expand All @@ -28,7 +27,7 @@ def async_create_proxy_url(
channels: int | None = None,
width: int | None = None,
) -> str:
"""Create a one-time use proxy URL that automatically converts the media."""
"""Create a proxy URL that automatically converts the media."""
data: FFmpegProxyData = hass.data[DATA_FFMPEG_PROXY]
return data.async_create_proxy_url(
device_id, media_url, media_format, rate, channels, width
Expand All @@ -39,7 +38,10 @@ def async_create_proxy_url(
class FFmpegConversionInfo:
"""Information for ffmpeg conversion."""

url: str
convert_id: str
"""Unique id for media conversion."""

media_url: str
"""Source URL of media to convert."""

media_format: str
Expand All @@ -54,18 +56,16 @@ class FFmpegConversionInfo:
width: int | None
"""Target sample width in bytes (None to keep source width)."""

proc: asyncio.subprocess.Process | None = None
"""Subprocess doing ffmpeg conversion."""


@dataclass
class FFmpegProxyData:
"""Data for ffmpeg proxy conversion."""

# device_id -> convert_id -> info
conversions: dict[str, dict[str, FFmpegConversionInfo]] = field(
default_factory=lambda: defaultdict(dict)
)

# device_id -> process
processes: dict[str, asyncio.subprocess.Process] = field(default_factory=dict)
# device_id -> info
conversions: dict[str, FFmpegConversionInfo] = field(default_factory=dict)

def async_create_proxy_url(
self,
Expand All @@ -77,9 +77,19 @@ def async_create_proxy_url(
width: int | None,
) -> str:
"""Create a reusable proxy URL (one per device) that automatically converts the media."""
if (convert_info := self.conversions.pop(device_id, None)) is not None:
# Stop existing conversion before overwriting info
if (convert_info.proc is not None) and (
convert_info.proc.returncode is None
):
_LOGGER.debug(
"Stopping existing ffmpeg process for device: %s", device_id
)
convert_info.proc.kill()

convert_id = secrets.token_urlsafe(16)
self.conversions[device_id][convert_id] = FFmpegConversionInfo(
media_url, media_format, rate, channels, width
self.conversions[device_id] = FFmpegConversionInfo(
convert_id, media_url, media_format, rate, channels, width
)
_LOGGER.debug("Media URL allowed by proxy: %s", media_url)

Expand Down Expand Up @@ -128,7 +138,7 @@ async def prepare(self, request: BaseRequest) -> AbstractStreamWriter | None:

command_args = [
"-i",
self.convert_info.url,
self.convert_info.media_url,
"-f",
self.convert_info.media_format,
]
Expand Down Expand Up @@ -156,12 +166,12 @@ async def prepare(self, request: BaseRequest) -> AbstractStreamWriter | None:
stderr=asyncio.subprocess.PIPE,
)

# Only one conversion process per device is allowed
self.convert_info.proc = proc

assert proc.stdout is not None
assert proc.stderr is not None

# Only one conversion process per device is allowed
self.proxy_data.processes[self.device_id] = proc

try:
# Pull audio chunks from ffmpeg and pass them to the HTTP client
while (
Expand All @@ -173,21 +183,25 @@ async def prepare(self, request: BaseRequest) -> AbstractStreamWriter | None:
):
await writer.write(chunk)
await writer.drain()
except asyncio.CancelledError:
raise # don't log error
except:
_LOGGER.exception("Unexpected error during ffmpeg conversion")

# Process did not exit successfully
stderr_text = ""
while line := await proc.stderr.readline():
stderr_text += line.decode()
_LOGGER.error("FFmpeg output: %s", stderr_text)

raise
finally:
# Close connection
await writer.write_eof()

# Terminate hangs, so kill is used
proc.kill()
if proc.returncode is None:
proc.kill()

if proc.returncode != 0:
# Process did not exit successfully
stderr_text = ""
while line := await proc.stderr.readline():
stderr_text += line.decode()
_LOGGER.error("Error shutting down ffmpeg: %s", stderr_text)
else:
_LOGGER.debug("Conversion completed: %s", self.convert_info)
# Close connection
await writer.write_eof()

return writer

Expand All @@ -208,27 +222,25 @@ async def get(
self, request: web.Request, device_id: str, filename: str
) -> web.StreamResponse:
"""Start a get request."""

# {id}.mp3 -> id
convert_id = filename.rsplit(".")[0]

try:
convert_info = self.proxy_data.conversions[device_id].pop(convert_id)
except KeyError:
_LOGGER.error(
"Unrecognized convert id %s for device: %s", convert_id, device_id
)
if (convert_info := self.proxy_data.conversions.get(device_id)) is None:
return web.Response(
body="Convert id not recognized", status=HTTPStatus.BAD_REQUEST
body="No proxy URL for device", status=HTTPStatus.NOT_FOUND
)

# Stop any existing process
proc = self.proxy_data.processes.pop(device_id, None)
if (proc is not None) and (proc.returncode is None):
_LOGGER.debug("Stopping existing ffmpeg process for device: %s", device_id)
# {id}.mp3 -> id, mp3
convert_id, media_format = filename.rsplit(".")

# Terminate hangs, so kill is used
proc.kill()
if (convert_info.convert_id != convert_id) or (
convert_info.media_format != media_format
):
return web.Response(body="Invalid proxy URL", status=HTTPStatus.BAD_REQUEST)

# Stop previous process if the URL is being reused.
# We could continue from where the previous connection left off, but
# there would be no media header.
if (convert_info.proc is not None) and (convert_info.proc.returncode is None):
convert_info.proc.kill()
convert_info.proc = None

# Stream converted audio back to client
return FFmpegConvertResponse(
Expand Down
129 changes: 126 additions & 3 deletions tests/components/esphome/test_ffmpeg_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def test_proxy_view(

# Should fail because we haven't allowed the URL yet
req = await client.get(url)
assert req.status == HTTPStatus.BAD_REQUEST
assert req.status == HTTPStatus.NOT_FOUND

# Allow the URL
with patch(
Expand All @@ -75,6 +75,12 @@ async def test_proxy_view(
== url
)

# Requesting the wrong media format should fail
wrong_url = f"/api/esphome/ffmpeg_proxy/{device_id}/{convert_id}.flac"
req = await client.get(wrong_url)
assert req.status == HTTPStatus.BAD_REQUEST

# Correct URL
req = await client.get(url)
assert req.status == HTTPStatus.OK

Expand All @@ -90,11 +96,11 @@ async def test_proxy_view(
assert round(mp3_file.info.length, 0) == 1


async def test_ffmpeg_error(
async def test_ffmpeg_file_doesnt_exist(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
) -> None:
"""Test proxy HTTP view with an ffmpeg error."""
"""Test ffmpeg conversion with a file that doesn't exist."""
device_id = "1234"

await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}})
Expand All @@ -109,3 +115,120 @@ async def test_ffmpeg_error(
assert req.status == HTTPStatus.OK
mp3_data = await req.content.read()
assert not mp3_data


async def test_lingering_process(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
) -> None:
    """Check that a second proxy URL for a device yields a complete conversion.

    The first request is only partially read before a second proxy URL is
    created for the same device; the second request must still return the
    whole converted audio.
    """
    device_id = "1234"

    # Source audio: 16 kHz, 16-bit, mono
    source_rate = 16000
    source_width = 2

    # Requested output: 22.05 kHz, 16-bit, stereo
    target = {"media_format": "wav", "rate": 22050, "channels": 2, "width": 2}
    bytes_per_frame = 2 * 2  # 2 channels, 16-bit samples

    await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}})
    client = await hass_client()

    with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as source_file:
        with wave.open(source_file.name, "wb") as wav_writer:
            wav_writer.setnchannels(1)
            wav_writer.setsampwidth(source_width)
            wav_writer.setframerate(source_rate)
            wav_writer.writeframes(bytes(source_rate * source_width))  # 1s

        source_file.seek(0)
        wav_url = pathname2url(source_file.name)

        # The first request starts ffmpeg
        first_url = async_create_proxy_url(hass, device_id, wav_url, **target)
        first_resp = await client.get(first_url)
        assert first_resp.status == HTTPStatus.OK

        # Leave the first response mostly unread
        await first_resp.content.readexactly(100)

        # Allow another URL for the same device
        second_url = async_create_proxy_url(hass, device_id, wav_url, **target)
        second_resp = await client.get(second_url)
        assert second_resp.status == HTTPStatus.OK

        converted = await second_resp.content.read()

        # A new ffmpeg process is expected, so the full audio should be present
        with io.BytesIO(converted) as wav_io, wave.open(wav_io, "rb") as wav_reader:
            # getnframes() is not usable here: the WAV header will be incorrect.
            # Encoders normally rewrite the header once all frames are written,
            # which ffmpeg cannot do while streaming. Count frames by reading
            # until the data runs out instead.
            frame_total = sum(
                len(chunk) // bytes_per_frame
                for chunk in iter(lambda: wav_reader.readframes(1024), b"")
            )

            assert frame_total == 22050  # 1s


async def test_request_same_url_multiple_times(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
) -> None:
    """Check that requesting the same proxy URL twice yields a complete conversion.

    The first response is only partially read; a second request to the very
    same URL must still return the whole converted audio.
    """
    device_id = "1234"

    # Source audio: 16 kHz, 16-bit, mono, 10 seconds long
    source_rate = 16000
    source_width = 2
    seconds = 10

    bytes_per_frame = 2 * 2  # 2 channels, 16-bit samples

    await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}})
    client = await hass_client()

    with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as source_file:
        with wave.open(source_file.name, "wb") as wav_writer:
            wav_writer.setnchannels(1)
            wav_writer.setsampwidth(source_width)
            wav_writer.setframerate(source_rate)
            wav_writer.writeframes(bytes(source_rate * source_width * seconds))  # 10s

        source_file.seek(0)
        wav_url = pathname2url(source_file.name)
        url = async_create_proxy_url(
            hass,
            device_id,
            wav_url,
            media_format="wav",
            rate=22050,
            channels=2,
            width=2,
        )

        # The first request starts ffmpeg
        first_resp = await client.get(url)
        assert first_resp.status == HTTPStatus.OK

        # Leave the first response mostly unread
        await first_resp.content.readexactly(100)

        # Requesting the same URL again should restart ffmpeg
        second_resp = await client.get(url)
        assert second_resp.status == HTTPStatus.OK

        converted = await second_resp.content.read()

        # A new ffmpeg process is expected, so the full audio should be present
        with io.BytesIO(converted) as wav_io, wave.open(wav_io, "rb") as wav_reader:
            frame_total = sum(
                len(chunk) // bytes_per_frame
                for chunk in iter(lambda: wav_reader.readframes(1024), b"")
            )

            assert frame_total == 22050 * 10  # 10s

0 comments on commit d1a4838

Please sign in to comment.