Skip to content

Commit

Permalink
Do not recompute write sets for equal child nodes belonging to IR::Ve…
Browse files Browse the repository at this point in the history
…ctors

Signed-off-by: Kyle Cripps <[email protected]>
  • Loading branch information
kfcripps committed Jul 16, 2024
1 parent 2e76c92 commit b8d1d8e
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 10 deletions.
30 changes: 20 additions & 10 deletions frontends/p4/def_use.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,11 +636,10 @@ bool ComputeWriteSet::preorder(const IR::Mux *expression) {
bool ComputeWriteSet::preorder(const IR::SelectExpression *expression) {
BUG_CHECK(!lhs, "%1%: unexpected in lhs", expression);
visit(expression->select);
visit(&expression->selectCases);
expression->selectCases.visit_unique_children(*this);
auto l = getWrites(expression->select);
const loc_t *selectCasesLoc = getLoc(&expression->selectCases, getChildContext());
for (auto *c : expression->selectCases) {
const loc_t *selectCaseLoc = getLoc(c, selectCasesLoc);
const loc_t *selectCaseLoc = getLoc(c, getChildContext());
auto s = getWrites(c->keyset, selectCaseLoc);
l = l->join(s);
}
Expand All @@ -649,7 +648,7 @@ bool ComputeWriteSet::preorder(const IR::SelectExpression *expression) {
}

bool ComputeWriteSet::preorder(const IR::ListExpression *expression) {
visit(expression->components, "components");
expression->components.visit_unique_children(*this);
auto l = LocationSet::empty;
for (auto c : expression->components) {
auto cl = getWrites(c);
Expand Down Expand Up @@ -874,7 +873,7 @@ bool ComputeWriteSet::preorder(const IR::IfStatement *statement) {
bool ComputeWriteSet::preorder(const IR::ForStatement *statement) {
LOG3("CWS Visiting " << dbp(statement));
if (currentDefinitions->isUnreachable()) return setDefinitions(currentDefinitions);
visit(statement->init, "init");
statement->init.visit_unique_children(*this);

auto saveBreak = breakDefinitions;
auto saveContinue = continueDefinitions;
Expand All @@ -892,7 +891,7 @@ bool ComputeWriteSet::preorder(const IR::ForStatement *statement) {
(void)setDefinitions(exitDefs, statement->condition, true);
visit(statement->body, "body");
currentDefinitions = currentDefinitions->joinDefinitions(continueDefinitions);
visit(statement->updates, "updates");
statement->updates.visit_unique_children(*this);
currentDefinitions = currentDefinitions->joinDefinitions(startDefs);
} while (!(*startDefs == *currentDefinitions));

Expand Down Expand Up @@ -936,7 +935,7 @@ bool ComputeWriteSet::preorder(const IR::ForInStatement *statement) {

bool ComputeWriteSet::preorder(const IR::BlockStatement *statement) {
if (currentDefinitions->isUnreachable()) return setDefinitions(currentDefinitions);
visit(statement->components, "components");
statement->components.visit_unique_children(*this);
return setDefinitions(currentDefinitions);
}

Expand Down Expand Up @@ -991,6 +990,12 @@ bool ComputeWriteSet::preorder(const IR::AssignmentStatement *statement) {
return setDefinitions(defs);
}

bool ComputeWriteSet::preorder(const IR::SwitchCase *c) {
visit(c->statement);
// Do not visit c->label, as it cannot write anything.
return false;
}

bool ComputeWriteSet::preorder(const IR::SwitchStatement *statement) {
LOG3("CWS Visiting " << dbp(statement));
if (currentDefinitions->isUnreachable()) return setDefinitions(currentDefinitions);
Expand All @@ -1001,11 +1006,16 @@ bool ComputeWriteSet::preorder(const IR::SwitchStatement *statement) {
auto save = currentDefinitions;
auto result = new Definitions();
bool seenDefault = false;
for (auto s : statement->cases) {
std::unordered_set<const IR::Node *> visitedCases;
for (auto *c : statement->cases) {
if (visitedCases.find(c) != visitedCases.end()) continue;

currentDefinitions = save;
if (s->label->is<IR::DefaultExpression>()) seenDefault = true;
visit(s->statement);
if (c->label->is<IR::DefaultExpression>()) seenDefault = true;
visit(c);
result = result->joinDefinitions(currentDefinitions);

visitedCases.emplace(c);
}
auto table = TableApplySolver::isActionRun(statement->expression, storageMap->refMap,
storageMap->typeMap);
Expand Down
1 change: 1 addition & 0 deletions frontends/p4/def_use.h
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint {
bool preorder(const IR::ForStatement *statement) override;
bool preorder(const IR::ForInStatement *statement) override;
bool preorder(const IR::BlockStatement *statement) override;
bool preorder(const IR::SwitchCase *c) override;
bool preorder(const IR::SwitchStatement *statement) override;
bool preorder(const IR::EmptyStatement *statement) override;
bool preorder(const IR::MethodCallStatement *statement) override;
Expand Down
14 changes: 14 additions & 0 deletions ir/ir-inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ limitations under the License.
#ifndef IR_IR_INLINE_H_
#define IR_IR_INLINE_H_

#include <unordered_set>

#include "ir/id.h"
#include "ir/indexed_vector.h"
#include "ir/json_generator.h"
#include "ir/namemap.h"
#include "ir/nodemap.h"
#include "ir/visitor.h"

#define DEFINE_APPLY_FUNCTIONS(CLASS, TEMPLATE, TT, INLINE) \
TEMPLATE INLINE bool IR::CLASS TT::apply_visitor_preorder(Modifier &v) { \
Node::traceVisit("Mod pre"); \
Expand Down Expand Up @@ -131,6 +134,17 @@ template <class T>
void IR::Vector<T>::parallel_visit_children(Visitor &v) const {
SplitFlowVisitVector<T>(v, *this).run_visit();
}
template <class T>
void IR::Vector<T>::visit_unique_children(Visitor &v) const {
std::unordered_set<const T *> visited;
for (const auto *node : vec) {
// Visit each child component only once.
if (visited.find(node) != visited.end()) continue;
v.visit(node);
visited.emplace(node);
}
}

IRNODE_DEFINE_APPLY_OVERLOAD(Vector, template <class T>, <T>)
template <class T>
void IR::Vector<T>::toJSON(JSONGenerator &json) const {
Expand Down
1 change: 1 addition & 0 deletions ir/vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ class Vector : public VectorBase {
cstring node_type_name() const override { return "Vector<" + T::static_type_name() + ">"; }
static cstring static_type_name() { return "Vector<" + T::static_type_name() + ">"; }
void visit_children(Visitor &v) override;
void visit_unique_children(Visitor &v) const;
void visit_children(Visitor &v) const override;
virtual void parallel_visit_children(Visitor &v);
virtual void parallel_visit_children(Visitor &v) const;
Expand Down

0 comments on commit b8d1d8e

Please sign in to comment.