Skip to content

Commit

Permalink
[spv-in] patch the SPIR-Vs wacky CFG cases by hand
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark committed May 9, 2021
1 parent 1aa57b3 commit 851f504
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 34 deletions.
16 changes: 8 additions & 8 deletions src/front/spv/flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ impl FlowGraph {
}
Terminator::Switch {
selector: _,
default,
default_id,
ref targets,
} => {
let default_node_index = block_to_node[&default];
let default_node_index = block_to_node[&default_id];

self.flow.add_edge(
source_node_index,
Expand Down Expand Up @@ -479,12 +479,12 @@ impl FlowGraph {
}
Terminator::Switch {
selector: _,
default,
default_id,
ref targets,
} => {
self.compute_postorder_traverse(Some(self.block_to_node[&default]));
for target in targets.iter() {
self.compute_postorder_traverse(Some(self.block_to_node[&target.1]));
self.compute_postorder_traverse(Some(self.block_to_node[&default_id]));
for &(_, target_id) in targets.iter() {
self.compute_postorder_traverse(Some(self.block_to_node[&target_id]));
}
}
_ => {}
Expand Down Expand Up @@ -619,7 +619,7 @@ impl FlowGraph {
}
Terminator::Switch {
selector,
default,
default_id,
ref targets,
} => {
let merge_node_index =
Expand Down Expand Up @@ -660,7 +660,7 @@ impl FlowGraph {
selector,
cases,
default: self.convert_to_naga_traverse(
self.block_to_node[&default],
self.block_to_node[&default_id],
stop_nodes_cases,
)?,
});
Expand Down
2 changes: 1 addition & 1 deletion src/front/spv/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub enum Terminator {
///
selector: Handle<crate::Expression>,
/// Default block of the switch case.
default: BlockId,
default_id: BlockId,
/// Tuples of (literal, target block)
targets: Vec<(i32, BlockId)>,
},
Expand Down
64 changes: 39 additions & 25 deletions src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use crate::{

use num_traits::cast::FromPrimitive;
use petgraph::graphmap::GraphMap;
use std::{convert::TryInto, num::NonZeroU32, path::PathBuf};
use std::{convert::TryInto, mem, num::NonZeroU32, path::PathBuf};

pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[
spirv::Capability::Shader,
Expand Down Expand Up @@ -1973,7 +1973,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {
Op::Switch => {
inst.expect_at_least(3)?;
let selector = self.next()?;
let default = self.next()?;
let default_id = self.next()?;

let selector_lexp = &self.lookup_expression[&selector];
let selector_lty = self.lookup_type.lookup(selector_lexp.type_id)?;
Expand Down Expand Up @@ -2005,7 +2005,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {

break Terminator::Switch {
selector,
default,
default_id,
targets,
};
}
Expand Down Expand Up @@ -2205,41 +2205,55 @@ impl<I: Iterator<Item = u32>> Parser<I> {
}
}

fn patch_function_call_statements(
&self,
statements: &mut [crate::Statement],
) -> Result<(), Error> {
/// Walk the statement tree and patch it in the following cases:
/// 1. Function call targets are replaced by `deferred_function_calls` map
/// 2. Lift the contents of "If" that only breaks on rejection, onto the parent after it.
/// 3. Lift the contents of "Switch" that only has a default, onto the parent after it.
fn patch_statements(&self, statements: &mut crate::Block) -> Result<(), Error> {
use crate::Statement as S;
for statement in statements.iter_mut() {
match *statement {
let mut i = 0usize;
while i < statements.len() {
match statements[i] {
S::Emit(_) => {}
S::Block(ref mut block) => {
self.patch_function_call_statements(block)?;
self.patch_statements(block)?;
}
S::If {
condition: _,
ref mut accept,
ref mut reject,
} => {
self.patch_function_call_statements(accept)?;
self.patch_function_call_statements(reject)?;
if let [S::Break] = reject[..] {
// uplift "accept" into the parent
let extracted = mem::replace(accept, Vec::new());
statements.splice(i + 1..i + 1, extracted.into_iter());
} else {
self.patch_statements(reject)?;
self.patch_statements(accept)?;
}
}
S::Switch {
selector: _,
ref mut cases,
ref mut default,
} => {
for case in cases.iter_mut() {
self.patch_function_call_statements(&mut case.body)?;
if cases.is_empty() {
// uplift "default" into the parent
let extracted = mem::replace(default, Vec::new());
statements.splice(i + 1..i + 1, extracted.into_iter());
} else {
for case in cases.iter_mut() {
self.patch_statements(&mut case.body)?;
}
self.patch_statements(default)?;
}
self.patch_function_call_statements(default)?;
}
S::Loop {
ref mut body,
ref mut continuing,
} => {
self.patch_function_call_statements(body)?;
self.patch_function_call_statements(continuing)?;
self.patch_statements(body)?;
self.patch_statements(continuing)?;
}
S::Break
| S::Continue
Expand All @@ -2255,18 +2269,19 @@ impl<I: Iterator<Item = u32>> Parser<I> {
*function = *self.lookup_function.lookup(fun_id)?;
}
}
i += 1;
}
Ok(())
}

fn patch_function_calls(&self, fun: &mut crate::Function) -> Result<(), Error> {
fn patch_function(&self, fun: &mut crate::Function) -> Result<(), Error> {
for (_, expr) in fun.expressions.iter_mut() {
if let crate::Expression::Call(ref mut function) = *expr {
let fun_id = self.deferred_function_calls[function.index()];
*function = *self.lookup_function.lookup(fun_id)?;
}
}
self.patch_function_call_statements(&mut fun.body)?;
self.patch_statements(&mut fun.body)?;
Ok(())
}

Expand Down Expand Up @@ -2357,29 +2372,28 @@ impl<I: Iterator<Item = u32>> Parser<I> {

log::info!("Patching...");
{
use std::mem::take;
let mut nodes = petgraph::algo::toposort(&self.function_call_graph, None)
.map_err(|cycle| Error::FunctionCallCycle(cycle.node_id()))?;
nodes.reverse(); // we need dominated first
let mut functions = take(&mut module.functions).into_inner();
let mut functions = mem::take(&mut module.functions).into_inner();
for fun_id in nodes {
if fun_id > !(functions.len() as u32) {
// skip all the fake IDs registered for the entry points
continue;
}
let handle = self.lookup_function.get_mut(&fun_id).unwrap();
// take out the function from the old array
let fun = take(&mut functions[handle.index()]);
let fun = mem::take(&mut functions[handle.index()]);
// add it to the newly formed arena, and adjust the lookup
*handle = module.functions.append(fun);
}
}
// patch all the function calls
// patch all the functions
for (_, fun) in module.functions.iter_mut() {
self.patch_function_calls(fun)?;
self.patch_function(fun)?;
}
for ep in module.entry_points.iter_mut() {
self.patch_function_calls(&mut ep.function)?;
self.patch_function(&mut ep.function)?;
}

// Check all the images and samplers to have consistent comparison property.
Expand Down

0 comments on commit 851f504

Please sign in to comment.