Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[spv-in] patch the SPIR-Vs wacky CFG cases by hand #844

Merged
merged 1 commit into from
May 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/front/spv/flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ impl FlowGraph {
}
Terminator::Switch {
selector: _,
default,
default_id,
ref targets,
} => {
let default_node_index = block_to_node[&default];
let default_node_index = block_to_node[&default_id];

self.flow.add_edge(
source_node_index,
Expand Down Expand Up @@ -479,12 +479,12 @@ impl FlowGraph {
}
Terminator::Switch {
selector: _,
default,
default_id,
ref targets,
} => {
self.compute_postorder_traverse(Some(self.block_to_node[&default]));
for target in targets.iter() {
self.compute_postorder_traverse(Some(self.block_to_node[&target.1]));
self.compute_postorder_traverse(Some(self.block_to_node[&default_id]));
for &(_, target_id) in targets.iter() {
self.compute_postorder_traverse(Some(self.block_to_node[&target_id]));
}
}
_ => {}
Expand Down Expand Up @@ -619,7 +619,7 @@ impl FlowGraph {
}
Terminator::Switch {
selector,
default,
default_id,
ref targets,
} => {
let merge_node_index =
Expand Down Expand Up @@ -660,7 +660,7 @@ impl FlowGraph {
selector,
cases,
default: self.convert_to_naga_traverse(
self.block_to_node[&default],
self.block_to_node[&default_id],
stop_nodes_cases,
)?,
});
Expand Down
2 changes: 1 addition & 1 deletion src/front/spv/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub enum Terminator {
///
selector: Handle<crate::Expression>,
/// Default block of the switch case.
default: BlockId,
default_id: BlockId,
/// Tuples of (literal, target block)
targets: Vec<(i32, BlockId)>,
},
Expand Down
64 changes: 39 additions & 25 deletions src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use crate::{

use num_traits::cast::FromPrimitive;
use petgraph::graphmap::GraphMap;
use std::{convert::TryInto, num::NonZeroU32, path::PathBuf};
use std::{convert::TryInto, mem, num::NonZeroU32, path::PathBuf};

pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[
spirv::Capability::Shader,
Expand Down Expand Up @@ -1973,7 +1973,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {
Op::Switch => {
inst.expect_at_least(3)?;
let selector = self.next()?;
let default = self.next()?;
let default_id = self.next()?;

let selector_lexp = &self.lookup_expression[&selector];
let selector_lty = self.lookup_type.lookup(selector_lexp.type_id)?;
Expand Down Expand Up @@ -2005,7 +2005,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {

break Terminator::Switch {
selector,
default,
default_id,
targets,
};
}
Expand Down Expand Up @@ -2205,41 +2205,55 @@ impl<I: Iterator<Item = u32>> Parser<I> {
}
}

fn patch_function_call_statements(
&self,
statements: &mut [crate::Statement],
) -> Result<(), Error> {
/// Walk the statement tree and patch it in the following cases:
/// 1. Function call targets are replaced by `deferred_function_calls` map
/// 2. Lift the contents of "If" that only breaks on rejection, onto the parent after it.
/// 3. Lift the contents of "Switch" that only has a default, onto the parent after it.
fn patch_statements(&self, statements: &mut crate::Block) -> Result<(), Error> {
use crate::Statement as S;
for statement in statements.iter_mut() {
match *statement {
let mut i = 0usize;
while i < statements.len() {
match statements[i] {
S::Emit(_) => {}
S::Block(ref mut block) => {
self.patch_function_call_statements(block)?;
self.patch_statements(block)?;
}
S::If {
condition: _,
ref mut accept,
ref mut reject,
} => {
self.patch_function_call_statements(accept)?;
self.patch_function_call_statements(reject)?;
if let [S::Break] = reject[..] {
// uplift "accept" into the parent
let extracted = mem::replace(accept, Vec::new());
statements.splice(i + 1..i + 1, extracted.into_iter());
} else {
self.patch_statements(reject)?;
self.patch_statements(accept)?;
}
}
S::Switch {
selector: _,
ref mut cases,
ref mut default,
} => {
for case in cases.iter_mut() {
self.patch_function_call_statements(&mut case.body)?;
if cases.is_empty() {
// uplift "default" into the parent
let extracted = mem::replace(default, Vec::new());
statements.splice(i + 1..i + 1, extracted.into_iter());
} else {
for case in cases.iter_mut() {
self.patch_statements(&mut case.body)?;
}
self.patch_statements(default)?;
}
self.patch_function_call_statements(default)?;
}
S::Loop {
ref mut body,
ref mut continuing,
} => {
self.patch_function_call_statements(body)?;
self.patch_function_call_statements(continuing)?;
self.patch_statements(body)?;
self.patch_statements(continuing)?;
}
S::Break
| S::Continue
Expand All @@ -2255,18 +2269,19 @@ impl<I: Iterator<Item = u32>> Parser<I> {
*function = *self.lookup_function.lookup(fun_id)?;
}
}
i += 1;
}
Ok(())
}

fn patch_function_calls(&self, fun: &mut crate::Function) -> Result<(), Error> {
fn patch_function(&self, fun: &mut crate::Function) -> Result<(), Error> {
for (_, expr) in fun.expressions.iter_mut() {
if let crate::Expression::Call(ref mut function) = *expr {
let fun_id = self.deferred_function_calls[function.index()];
*function = *self.lookup_function.lookup(fun_id)?;
}
}
self.patch_function_call_statements(&mut fun.body)?;
self.patch_statements(&mut fun.body)?;
Ok(())
}

Expand Down Expand Up @@ -2357,29 +2372,28 @@ impl<I: Iterator<Item = u32>> Parser<I> {

log::info!("Patching...");
{
use std::mem::take;
let mut nodes = petgraph::algo::toposort(&self.function_call_graph, None)
.map_err(|cycle| Error::FunctionCallCycle(cycle.node_id()))?;
nodes.reverse(); // we need dominated first
let mut functions = take(&mut module.functions).into_inner();
let mut functions = mem::take(&mut module.functions).into_inner();
for fun_id in nodes {
if fun_id > !(functions.len() as u32) {
// skip all the fake IDs registered for the entry points
continue;
}
let handle = self.lookup_function.get_mut(&fun_id).unwrap();
// take out the function from the old array
let fun = take(&mut functions[handle.index()]);
let fun = mem::take(&mut functions[handle.index()]);
// add it to the newly formed arena, and adjust the lookup
*handle = module.functions.append(fun);
}
}
// patch all the function calls
// patch all the functions
for (_, fun) in module.functions.iter_mut() {
self.patch_function_calls(fun)?;
self.patch_function(fun)?;
}
for ep in module.entry_points.iter_mut() {
self.patch_function_calls(&mut ep.function)?;
self.patch_function(&mut ep.function)?;
}

// Check all the images and samplers to have consistent comparison property.
Expand Down