feat: streamlit app
killian31 committed Feb 17, 2024
1 parent 68cdf61 commit 8a0d8d9
Showing 4 changed files with 310 additions and 7 deletions.
80 changes: 80 additions & 0 deletions app.py
@@ -0,0 +1,80 @@
import os

import streamlit as st
from PIL import Image, ImageDraw

import redirect as rd
from main import segment_video
from video_to_images import ImageCreator


def load_image(image_path):
    return Image.open(image_path)


st.title("Video Background Removal")

video_file = st.file_uploader("Upload a video", type=["mp4", "avi", "mov"])

if video_file is not None:
    # Temporarily save the uploaded file so it can be processed from disk
    with open("temp_video.mp4", "wb") as f:
        f.write(video_file.getbuffer())

    if not os.path.exists("temp_images"):
        vid_to_im = ImageCreator(
            "temp_video.mp4", "temp_images", image_start=0, image_end=0
        )
        vid_to_im.get_images()
        # Get the initial frame's filename (it can vary depending on the video)
        frame_path = os.path.join("temp_images", sorted(os.listdir("temp_images"))[0])
    else:
        frame_path = os.path.join("temp_images", sorted(os.listdir("temp_images"))[0])

    # Display sliders for the bounding box coordinates
    col1, col2 = st.columns(2)
    # Get the initial frame dimensions
    initial_frame = load_image(frame_path)
    original_width, original_height = initial_frame.width, initial_frame.height
    with col1:
        xmin = st.slider("xmin", 0, original_width, original_width // 4)
        ymin = st.slider("ymin", 0, original_height, original_height // 4)
    with col2:
        xmax = st.slider("xmax", 0, original_width, original_width // 2)
        ymax = st.slider("ymax", 0, original_height, original_height // 2)

    # Draw the bounding box on the preview frame
    draw = ImageDraw.Draw(initial_frame)
    draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=3)
    st.image(initial_frame, caption="Bounding Box Preview", use_column_width=True)
    if st.button("Save Bounding Box"):
        with open("temp_bbox.txt", "w") as bbox_file:
            bbox_file.write(f"{xmin} {ymin} {xmax} {ymax}")
        st.write(f"Bounding box saved to {os.path.abspath('temp_bbox.txt')}")

    if st.button("Segment Video"):
        if not os.path.exists("./temp_bbox.txt"):
            with open("temp_bbox.txt", "w") as bbox_file:
                bbox_file.write(f"{xmin} {ymin} {xmax} {ymax}")

        st.write("Segmenting video...")
        so = st.empty()
        with rd.stdouterr(to=so):
            segment_video(
                "temp_video.mp4",
                "temp_images",
                0,
                0,
                "temp_bbox.txt",
                False,
                "./models/mobile_sam.pt",
                output_video="video_segmented.mp4",
            )
        # Display the segmented video
        st.video("video_segmented.mp4")
        st.write(f"Video saved to {os.path.abspath('video_segmented.mp4')}")
        # Offer the segmented video for download
        st.markdown(
            '<a href="video_segmented.mp4" download="video_segmented.mp4">Download video</a>',
            unsafe_allow_html=True,
        )
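
For context, app.py and segment_video exchange the selected region through a plain text file holding four space-separated integers. A minimal sketch of that round trip (the file name and coordinates are illustrative):

# Write the box the way the app's "Save Bounding Box" button does...
with open("temp_bbox.txt", "w") as bbox_file:
    bbox_file.write("100 50 400 300")  # xmin ymin xmax ymax

# ...and parse it the way segment_video does.
with open("temp_bbox.txt", "r") as f:
    xmin, ymin, xmax, ymax = [int(coord) for coord in f.read().split(" ")]
print(xmin, ymin, xmax, ymax)  # 100 50 400 300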
20 changes: 17 additions & 3 deletions main.py
@@ -125,13 +125,15 @@ def segment_video(
     tracker_name="yolov7",
     output_dir="output_frames",
     output_video="output.mp4",
+    pbar=False,
 ):
     if not skip_vid2im:
         vid_to_im = ImageCreator(
             video_filename,
             dir_frames,
             image_start=image_start,
             image_end=image_end,
+            pbar=pbar,
         )
         vid_to_im.get_images()
     # Get fps of video
@@ -142,13 +144,16 @@
     with open(bbox_file, "r") as f:
         bbox_orig = [int(coord) for coord in f.read().split(" ")]
     download_mobile_sam_weight(mobile_sam_weights)
-    frames = sorted(os.listdir(dir_frames))
+    if image_end == 0:
+        frames = sorted(os.listdir(dir_frames))[image_start:]
+    else:
+        frames = sorted(os.listdir(dir_frames))[image_start:image_end]

     model_type = "vit_t"

     device = "cuda" if torch.cuda.is_available() else "cpu"

-    sam = sam_model_registry[model_type](checkpoint=args.mobile_sam_weights)
+    sam = sam_model_registry[model_type](checkpoint=mobile_sam_weights)
     sam.to(device=device)
     sam.eval()
@@ -166,7 +171,14 @@

     output_frames = []

-    for frame in tqdm(frames):
+    if pbar:
+        pb = tqdm(frames)
+    else:
+        pb = frames
+
+    processed_frames = 0
+    for frame in pb:
+        processed_frames += 1
         image_file = dir_frames + "/" + frame
         image_pil = Image.open(image_file)
         image_np = np.array(image_pil)
@@ -187,6 +199,8 @@
         masked_image = image_np * mask.reshape(h, w, 1)
         masked_image = masked_image + mask_image
         output_frames.append(masked_image)
+        if not pbar:
+            print(f"Processed frame {processed_frames}/{len(frames)}")
     if not os.path.exists(output_dir):
         os.mkdir(output_dir)

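For reference, a minimal programmatic call to segment_video mirroring the one the Streamlit app makes (the positional argument order follows the call in app.py; paths are illustrative and assume the MobileSAM checkpoint has already been downloaded):

from main import segment_video

segment_video(
    "temp_video.mp4",          # input video
    "temp_images",             # directory receiving the extracted frames
    0,                         # image_start
    0,                         # image_end (0 processes frames up to the end)
    "temp_bbox.txt",           # file with "xmin ymin xmax ymax"
    False,                     # skip_vid2im: frames still need to be extracted
    "./models/mobile_sam.pt",  # MobileSAM weights
    output_video="video_segmented.mp4",
    pbar=True,                 # new flag in this commit: show a tqdm progress bar
)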
205 changes: 205 additions & 0 deletions redirect.py
@@ -0,0 +1,205 @@
import streamlit as st
import io
import contextlib
import sys
import re
import threading


class _Redirect:
    class IOStuff(io.StringIO):
        def __init__(
            self, trigger, max_buffer, buffer_separator, regex, dup, need_dup, on_thread
        ):
            super().__init__()
            self._trigger = trigger
            self._max_buffer = max_buffer
            self._buffer_separator = buffer_separator
            self._regex = regex and re.compile(regex)
            self._dup = dup
            self._need_dup = need_dup
            self._on_thread = on_thread

        def write(self, __s: str) -> int:
            res = None
            if self._on_thread == threading.get_ident():
                if self._max_buffer:
                    concatenated_len = super().tell() + len(__s)
                    if concatenated_len > self._max_buffer:
                        rest = self.get_filtered_output()[
                            concatenated_len - self._max_buffer :
                        ]
                        if self._buffer_separator is not None:
                            rest = rest.split(self._buffer_separator, 1)[-1]
                        super().seek(0)
                        super().write(rest)
                        super().truncate(super().tell() + len(__s))
                res = super().write(__s)
                self._trigger(self.get_filtered_output())
            if self._on_thread != threading.get_ident() or self._need_dup:
                self._dup.write(__s)
            return res

        def get_filtered_output(self):
            if self._regex is None or self._buffer_separator is None:
                return self.getvalue()

            return self._buffer_separator.join(
                filter(
                    self._regex.search, self.getvalue().split(self._buffer_separator)
                )
            )

        def print_at_end(self):
            self._trigger(self.get_filtered_output())

    def __init__(
        self,
        stdout=None,
        stderr=False,
        format=None,
        to=None,
        max_buffer=None,
        buffer_separator="\n",
        regex=None,
        duplicate_out=False,
    ):
        self.io_args = {
            "trigger": self._write,
            "max_buffer": max_buffer,
            "buffer_separator": buffer_separator,
            "regex": regex,
            "on_thread": threading.get_ident(),
        }
        self.redirections = []
        self.st = None
        self.stderr = stderr is True
        self.stdout = stdout is True or (stdout is None and not self.stderr)
        self.format = format or "code"
        self.to = to
        self.fun = None
        self.duplicate_out = duplicate_out or None
        self.active_nested = None

        if not self.stdout and not self.stderr:
            raise ValueError("one of stdout or stderr must be True")

        if self.format not in ["text", "markdown", "latex", "code", "write"]:
            raise ValueError(
                f"format must be one of the following: {', '.join(['text', 'markdown', 'latex', 'code', 'write'])}"
            )

        if self.to and (not hasattr(self.to, "text") or not hasattr(self.to, "empty")):
            raise ValueError("'to' is not a streamlit container object")

    def __enter__(self):
        if self.st is not None:
            if self.to is None:
                if self.active_nested is None:
                    self.active_nested = self(
                        format=self.format,
                        max_buffer=self.io_args["max_buffer"],
                        buffer_separator=self.io_args["buffer_separator"],
                        regex=self.io_args["regex"],
                        duplicate_out=self.duplicate_out,
                    )
                return self.active_nested.__enter__()
            else:
                raise Exception("Already entered")
        to = self.to or st

        to.text(
            f"Redirected output from "
            f"{'stdout and stderr' if self.stdout and self.stderr else 'stdout' if self.stdout else 'stderr'}"
            f"{' [' + self.io_args['regex'] + ']' if self.io_args['regex'] else ''}"
            f":"
        )
        self.st = to.empty()
        self.fun = getattr(self.st, self.format)

        io_obj = None

        def redirect(to_duplicate, context_redirect):
            nonlocal io_obj
            io_obj = _Redirect.IOStuff(
                need_dup=self.duplicate_out and True, dup=to_duplicate, **self.io_args
            )
            redirection = context_redirect(io_obj)
            self.redirections.append((redirection, io_obj))
            redirection.__enter__()

        if self.stderr:
            redirect(sys.stderr, contextlib.redirect_stderr)
        if self.stdout:
            redirect(sys.stdout, contextlib.redirect_stdout)

        return io_obj

    def __call__(
        self,
        to=None,
        format=None,
        max_buffer=None,
        buffer_separator="\n",
        regex=None,
        duplicate_out=False,
    ):
        return _Redirect(
            self.stdout,
            self.stderr,
            format=format,
            to=to,
            max_buffer=max_buffer,
            buffer_separator=buffer_separator,
            regex=regex,
            duplicate_out=duplicate_out,
        )

    def __exit__(self, *exc):
        if self.active_nested is not None:
            nested = self.active_nested
            if nested.active_nested is None:
                self.active_nested = None
            return nested.__exit__(*exc)

        res = None
        for redirection, io_obj in reversed(self.redirections):
            res = redirection.__exit__(*exc)
            io_obj.print_at_end()

        self.redirections = []
        self.st = None
        self.fun = None
        return res

    def _write(self, data):
        self.fun(data)


stdout = _Redirect()
stderr = _Redirect(stderr=True)
stdouterr = _Redirect(stdout=True, stderr=True)

"""
# can be used as
import time
import sys
from random import getrandbits
import streamlit.redirect as rd
st.text('Suboutput:')
so = st.empty()
with rd.stdout, rd.stderr(format='markdown', to=st.sidebar):
print("hello ")
time.sleep(1)
i = 5
while i > 0:
print("**M**izu? ", file=sys.stdout if getrandbits(1) else sys.stderr)
i -= 1
with rd.stdout(to=so):
print(f" cica {i}")
if i:
time.sleep(1)
# """
12 changes: 8 additions & 4 deletions video_to_images.py
@@ -5,7 +5,7 @@


 class ImageCreator:
-    def __init__(self, filename, imgs_dir, image_start=0, image_end=0):
+    def __init__(self, filename, imgs_dir, image_start=0, image_end=0, pbar=True):
         """
         :param str filename: The video's filename.
         :param str imgs_dir: The directory in which to store the image files.
@@ -17,6 +17,7 @@ def __init__(self, filename, imgs_dir, image_start=0, image_end=0):
         self.imgs_dir = imgs_dir
         self.image_start = image_start
         self.image_end = image_end
+        self.pbar = pbar
         if not os.path.exists(imgs_dir):
             os.makedirs(imgs_dir)

@@ -30,15 +31,18 @@ def get_images(self):
         zfill_max = len(str(total_frames))
         ok_count = 0
         print("Writing images...")
-        pbar = tqdm(total=total_frames)
+        if self.pbar:
+            pb = tqdm(total=total_frames)
         while success:
             if count >= self.image_start and count <= self.image_end:
                 cv2.imwrite(
                     f"{self.imgs_dir}/frame_{str(ok_count).zfill(zfill_max)}.png", image
                 )
                 ok_count += 1
             success, image = vid.read()
-            pbar.update(1)
+            if self.pbar:
+                pb.update(1)
             count += 1
-        pbar.close()
+        if self.pbar:
+            pb.close()
         print("Wrote {} image files.".format(ok_count))
