diff --git a/src/front/spv/flow.rs b/src/front/spv/flow.rs index e8c6a0a856..65f8ca4a43 100644 --- a/src/front/spv/flow.rs +++ b/src/front/spv/flow.rs @@ -143,10 +143,10 @@ impl FlowGraph { } Terminator::Switch { selector: _, - default, + default_id, ref targets, } => { - let default_node_index = block_to_node[&default]; + let default_node_index = block_to_node[&default_id]; self.flow.add_edge( source_node_index, @@ -479,12 +479,12 @@ impl FlowGraph { } Terminator::Switch { selector: _, - default, + default_id, ref targets, } => { - self.compute_postorder_traverse(Some(self.block_to_node[&default])); - for target in targets.iter() { - self.compute_postorder_traverse(Some(self.block_to_node[&target.1])); + self.compute_postorder_traverse(Some(self.block_to_node[&default_id])); + for &(_, target_id) in targets.iter() { + self.compute_postorder_traverse(Some(self.block_to_node[&target_id])); } } _ => {} @@ -619,7 +619,7 @@ impl FlowGraph { } Terminator::Switch { selector, - default, + default_id, ref targets, } => { let merge_node_index = @@ -660,7 +660,7 @@ impl FlowGraph { selector, cases, default: self.convert_to_naga_traverse( - self.block_to_node[&default], + self.block_to_node[&default_id], stop_nodes_cases, )?, }); diff --git a/src/front/spv/function.rs b/src/front/spv/function.rs index 952b693e07..ab72543d6d 100644 --- a/src/front/spv/function.rs +++ b/src/front/spv/function.rs @@ -38,7 +38,7 @@ pub enum Terminator { /// selector: Handle, /// Default block of the switch case. - default: BlockId, + default_id: BlockId, /// Tuples of (literal, target block) targets: Vec<(i32, BlockId)>, }, diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index e69f8701c4..dbe39bc0cf 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -45,7 +45,7 @@ use crate::{ use num_traits::cast::FromPrimitive; use petgraph::graphmap::GraphMap; -use std::{convert::TryInto, num::NonZeroU32, path::PathBuf}; +use std::{convert::TryInto, mem, num::NonZeroU32, path::PathBuf}; pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[ spirv::Capability::Shader, @@ -1973,7 +1973,7 @@ impl> Parser { Op::Switch => { inst.expect_at_least(3)?; let selector = self.next()?; - let default = self.next()?; + let default_id = self.next()?; let selector_lexp = &self.lookup_expression[&selector]; let selector_lty = self.lookup_type.lookup(selector_lexp.type_id)?; @@ -2005,7 +2005,7 @@ impl> Parser { break Terminator::Switch { selector, - default, + default_id, targets, }; } @@ -2205,41 +2205,55 @@ impl> Parser { } } - fn patch_function_call_statements( - &self, - statements: &mut [crate::Statement], - ) -> Result<(), Error> { + /// Walk the statement tree and patch it in the following cases: + /// 1. Function call targets are replaced by `deferred_function_calls` map + /// 2. Lift the contents of "If" that only breaks on rejection, onto the parent after it. + /// 3. Lift the contents of "Switch" that only has a default, onto the parent after it. + fn patch_statements(&self, statements: &mut crate::Block) -> Result<(), Error> { use crate::Statement as S; - for statement in statements.iter_mut() { - match *statement { + let mut i = 0usize; + while i < statements.len() { + match statements[i] { S::Emit(_) => {} S::Block(ref mut block) => { - self.patch_function_call_statements(block)?; + self.patch_statements(block)?; } S::If { condition: _, ref mut accept, ref mut reject, } => { - self.patch_function_call_statements(accept)?; - self.patch_function_call_statements(reject)?; + if let [S::Break] = reject[..] { + // uplift "accept" into the parent + let extracted = mem::replace(accept, Vec::new()); + statements.splice(i + 1..i + 1, extracted.into_iter()); + } else { + self.patch_statements(reject)?; + self.patch_statements(accept)?; + } } S::Switch { selector: _, ref mut cases, ref mut default, } => { - for case in cases.iter_mut() { - self.patch_function_call_statements(&mut case.body)?; + if cases.is_empty() { + // uplift "default" into the parent + let extracted = mem::replace(default, Vec::new()); + statements.splice(i + 1..i + 1, extracted.into_iter()); + } else { + for case in cases.iter_mut() { + self.patch_statements(&mut case.body)?; + } + self.patch_statements(default)?; } - self.patch_function_call_statements(default)?; } S::Loop { ref mut body, ref mut continuing, } => { - self.patch_function_call_statements(body)?; - self.patch_function_call_statements(continuing)?; + self.patch_statements(body)?; + self.patch_statements(continuing)?; } S::Break | S::Continue @@ -2255,18 +2269,19 @@ impl> Parser { *function = *self.lookup_function.lookup(fun_id)?; } } + i += 1; } Ok(()) } - fn patch_function_calls(&self, fun: &mut crate::Function) -> Result<(), Error> { + fn patch_function(&self, fun: &mut crate::Function) -> Result<(), Error> { for (_, expr) in fun.expressions.iter_mut() { if let crate::Expression::Call(ref mut function) = *expr { let fun_id = self.deferred_function_calls[function.index()]; *function = *self.lookup_function.lookup(fun_id)?; } } - self.patch_function_call_statements(&mut fun.body)?; + self.patch_statements(&mut fun.body)?; Ok(()) } @@ -2357,11 +2372,10 @@ impl> Parser { log::info!("Patching..."); { - use std::mem::take; let mut nodes = petgraph::algo::toposort(&self.function_call_graph, None) .map_err(|cycle| Error::FunctionCallCycle(cycle.node_id()))?; nodes.reverse(); // we need dominated first - let mut functions = take(&mut module.functions).into_inner(); + let mut functions = mem::take(&mut module.functions).into_inner(); for fun_id in nodes { if fun_id > !(functions.len() as u32) { // skip all the fake IDs registered for the entry points @@ -2369,17 +2383,17 @@ impl> Parser { } let handle = self.lookup_function.get_mut(&fun_id).unwrap(); // take out the function from the old array - let fun = take(&mut functions[handle.index()]); + let fun = mem::take(&mut functions[handle.index()]); // add it to the newly formed arena, and adjust the lookup *handle = module.functions.append(fun); } } - // patch all the function calls + // patch all the functions for (_, fun) in module.functions.iter_mut() { - self.patch_function_calls(fun)?; + self.patch_function(fun)?; } for ep in module.entry_points.iter_mut() { - self.patch_function_calls(&mut ep.function)?; + self.patch_function(&mut ep.function)?; } // Check all the images and samplers to have consistent comparison property.