-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from killian31/feat_ui
feat: streamlit app
- Loading branch information
Showing
4 changed files
with
310 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import os | ||
|
||
import streamlit as st | ||
from PIL import Image, ImageDraw | ||
|
||
import redirect as rd | ||
from main import segment_video | ||
from video_to_images import ImageCreator | ||
|
||
|
||
def load_image(image_path): | ||
return Image.open(image_path) | ||
|
||
|
||
st.title("Video Background Removal") | ||
|
||
video_file = st.file_uploader("Upload a video", type=["mp4", "avi", "mov"]) | ||
|
||
if video_file is not None: | ||
# Temporary save uploaded file to process | ||
with open("temp_video.mp4", "wb") as f: | ||
f.write(video_file.getbuffer()) | ||
|
||
if not os.path.exists("temp_images"): | ||
vid_to_im = ImageCreator( | ||
"temp_video.mp4", "temp_images", image_start=0, image_end=0 | ||
) | ||
vid_to_im.get_images() | ||
# get initial frame filename (can vary depending on the video) | ||
frame_path = os.path.join("temp_images", sorted(os.listdir("temp_images"))[0]) | ||
else: | ||
frame_path = os.path.join("temp_images", sorted(os.listdir("temp_images"))[0]) | ||
|
||
# Display sliders for bounding box coordinates | ||
col1, col2 = st.columns(2) | ||
# Get the initial frame dimensions | ||
initial_frame = load_image(frame_path) | ||
original_width, original_height = initial_frame.width, initial_frame.height | ||
with col1: | ||
xmin = st.slider("xmin", 0, original_width, original_width // 4) | ||
ymin = st.slider("ymin", 0, original_height, original_height // 4) | ||
with col2: | ||
xmax = st.slider("xmax", 0, original_width, original_width // 2) | ||
ymax = st.slider("ymax", 0, original_height, original_height // 2) | ||
|
||
# Draw the bounding box on a copy of the image | ||
draw = ImageDraw.Draw(initial_frame) | ||
draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=3) | ||
st.image(initial_frame, caption="Bounding Box Preview", use_column_width=True) | ||
if st.button("Save Bounding Box"): | ||
with open("temp_bbox.txt", "w") as bbox_file: | ||
bbox_file.write(f"{xmin} {ymin} {xmax} {ymax}") | ||
st.write(f"Bounding box saved to {os.path.abspath('temp_bbox.txt')}") | ||
|
||
if st.button("Segment Video"): | ||
if not os.path.exists("./temp_bbox.txt"): | ||
with open("temp_bbox.txt", "w") as bbox_file: | ||
bbox_file.write(f"{xmin} {ymin} {xmax} {ymax}") | ||
|
||
st.write("Segmenting video...") | ||
so = st.empty() | ||
with rd.stdouterr(to=so): | ||
segment_video( | ||
"temp_video.mp4", | ||
"temp_images", | ||
0, | ||
0, | ||
"temp_bbox.txt", | ||
False, | ||
"./models/mobile_sam.pt", | ||
output_video="video_segmented.mp4", | ||
) | ||
# Display the segmented video | ||
st.video("video_segmented.mp4") | ||
st.write(f"Video saved to {os.path.abspath('video_segmented.mp4')}") | ||
# Download the segmented video | ||
st.markdown( | ||
f'<a href="video_segmented.mp4" download="video_segmented.mp4">Download video</a>', | ||
unsafe_allow_html=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
import streamlit as st | ||
import io | ||
import contextlib | ||
import sys | ||
import re | ||
import threading | ||
|
||
|
||
class _Redirect: | ||
class IOStuff(io.StringIO): | ||
def __init__( | ||
self, trigger, max_buffer, buffer_separator, regex, dup, need_dup, on_thread | ||
): | ||
super().__init__() | ||
self._trigger = trigger | ||
self._max_buffer = max_buffer | ||
self._buffer_separator = buffer_separator | ||
self._regex = regex and re.compile(regex) | ||
self._dup = dup | ||
self._need_dup = need_dup | ||
self._on_thread = on_thread | ||
|
||
def write(self, __s: str) -> int: | ||
res = None | ||
if self._on_thread == threading.get_ident(): | ||
if self._max_buffer: | ||
concatenated_len = super().tell() + len(__s) | ||
if concatenated_len > self._max_buffer: | ||
rest = self.get_filtered_output()[ | ||
concatenated_len - self._max_buffer : | ||
] | ||
if self._buffer_separator is not None: | ||
rest = rest.split(self._buffer_separator, 1)[-1] | ||
super().seek(0) | ||
super().write(rest) | ||
super().truncate(super().tell() + len(__s)) | ||
res = super().write(__s) | ||
self._trigger(self.get_filtered_output()) | ||
if self._on_thread != threading.get_ident() or self._need_dup: | ||
self._dup.write(__s) | ||
return res | ||
|
||
def get_filtered_output(self): | ||
if self._regex is None or self._buffer_separator is None: | ||
return self.getvalue() | ||
|
||
return self._buffer_separator.join( | ||
filter( | ||
self._regex.search, self.getvalue().split(self._buffer_separator) | ||
) | ||
) | ||
|
||
def print_at_end(self): | ||
self._trigger(self.get_filtered_output()) | ||
|
||
def __init__( | ||
self, | ||
stdout=None, | ||
stderr=False, | ||
format=None, | ||
to=None, | ||
max_buffer=None, | ||
buffer_separator="\n", | ||
regex=None, | ||
duplicate_out=False, | ||
): | ||
self.io_args = { | ||
"trigger": self._write, | ||
"max_buffer": max_buffer, | ||
"buffer_separator": buffer_separator, | ||
"regex": regex, | ||
"on_thread": threading.get_ident(), | ||
} | ||
self.redirections = [] | ||
self.st = None | ||
self.stderr = stderr is True | ||
self.stdout = stdout is True or (stdout is None and not self.stderr) | ||
self.format = format or "code" | ||
self.to = to | ||
self.fun = None | ||
self.duplicate_out = duplicate_out or None | ||
self.active_nested = None | ||
|
||
if not self.stdout and not self.stderr: | ||
raise ValueError("one of stdout or stderr must be True") | ||
|
||
if self.format not in ["text", "markdown", "latex", "code", "write"]: | ||
raise ValueError( | ||
f"format need oneof the following: {', '.join(['text', 'markdown', 'latex', 'code', 'write'])}" | ||
) | ||
|
||
if self.to and (not hasattr(self.to, "text") or not hasattr(self.to, "empty")): | ||
raise ValueError(f"'to' is not a streamlit container object") | ||
|
||
def __enter__(self): | ||
if self.st is not None: | ||
if self.to is None: | ||
if self.active_nested is None: | ||
self.active_nested = self( | ||
format=self.format, | ||
max_buffer=self.io_args["max_buffer"], | ||
buffer_separator=self.io_args["buffer_separator"], | ||
regex=self.io_args["regex"], | ||
duplicate_out=self.duplicate_out, | ||
) | ||
return self.active_nested.__enter__() | ||
else: | ||
raise Exception("Already entered") | ||
to = self.to or st | ||
|
||
to.text( | ||
f"Redirected output from " | ||
f"{'stdout and stderr' if self.stdout and self.stderr else 'stdout' if self.stdout else 'stderr'}" | ||
f"{' [' + self.io_args['regex'] + ']' if self.io_args['regex'] else ''}" | ||
f":" | ||
) | ||
self.st = to.empty() | ||
self.fun = getattr(self.st, self.format) | ||
|
||
io_obj = None | ||
|
||
def redirect(to_duplicate, context_redirect): | ||
nonlocal io_obj | ||
io_obj = _Redirect.IOStuff( | ||
need_dup=self.duplicate_out and True, dup=to_duplicate, **self.io_args | ||
) | ||
redirection = context_redirect(io_obj) | ||
self.redirections.append((redirection, io_obj)) | ||
redirection.__enter__() | ||
|
||
if self.stderr: | ||
redirect(sys.stderr, contextlib.redirect_stderr) | ||
if self.stdout: | ||
redirect(sys.stdout, contextlib.redirect_stdout) | ||
|
||
return io_obj | ||
|
||
def __call__( | ||
self, | ||
to=None, | ||
format=None, | ||
max_buffer=None, | ||
buffer_separator="\n", | ||
regex=None, | ||
duplicate_out=False, | ||
): | ||
return _Redirect( | ||
self.stdout, | ||
self.stderr, | ||
format=format, | ||
to=to, | ||
max_buffer=max_buffer, | ||
buffer_separator=buffer_separator, | ||
regex=regex, | ||
duplicate_out=duplicate_out, | ||
) | ||
|
||
def __exit__(self, *exc): | ||
if self.active_nested is not None: | ||
nested = self.active_nested | ||
if nested.active_nested is None: | ||
self.active_nested = None | ||
return nested.__exit__(*exc) | ||
|
||
res = None | ||
for redirection, io_obj in reversed(self.redirections): | ||
res = redirection.__exit__(*exc) | ||
io_obj.print_at_end() | ||
|
||
self.redirections = [] | ||
self.st = None | ||
self.fun = None | ||
return res | ||
|
||
def _write(self, data): | ||
self.fun(data) | ||
|
||
|
||
stdout = _Redirect() | ||
stderr = _Redirect(stderr=True) | ||
stdouterr = _Redirect(stdout=True, stderr=True) | ||
|
||
""" | ||
# can be used as | ||
import time | ||
import sys | ||
from random import getrandbits | ||
import streamlit.redirect as rd | ||
st.text('Suboutput:') | ||
so = st.empty() | ||
with rd.stdout, rd.stderr(format='markdown', to=st.sidebar): | ||
print("hello ") | ||
time.sleep(1) | ||
i = 5 | ||
while i > 0: | ||
print("**M**izu? ", file=sys.stdout if getrandbits(1) else sys.stderr) | ||
i -= 1 | ||
with rd.stdout(to=so): | ||
print(f" cica {i}") | ||
if i: | ||
time.sleep(1) | ||
# """ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters