Skip to content

Commit

Permalink
Add callgraphutil.WriteDOT
Browse files Browse the repository at this point in the history
  • Loading branch information
picatz committed Jan 5, 2024
1 parent eebdf3f commit b107a34
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 0 deletions.
41 changes: 41 additions & 0 deletions callgraphutil/dot.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package callgraphutil

import (
"bufio"
"fmt"
"io"

"golang.org/x/tools/go/callgraph"
)

// WriteDOT writes the given callgraph.Graph to the given io.Writer in the
// DOT format, which can be used to generate a visual representation of the
// call graph using Graphviz.
func WriteDOT(w io.Writer, g *callgraph.Graph) error {
b := bufio.NewWriter(w)
defer b.Flush()

b.WriteString("digraph callgraph {\n")
b.WriteString("\tgraph [fontname=\"Helvetica\"];\n")
b.WriteString("\tnode [fontname=\"Helvetica\"];\n")
b.WriteString("\tedge [fontname=\"Helvetica\"];\n")

edges := []*callgraph.Edge{}

// Write nodes.
for _, n := range g.Nodes {
b.WriteString(fmt.Sprintf("\t%q [label=%q];\n", fmt.Sprintf("%d", n.ID), n.Func))

// Add edges
edges = append(edges, n.Out...)
}

// Write edges.
for _, e := range edges {
b.WriteString(fmt.Sprintf("\t%q -> %q [label=%q];\n", fmt.Sprintf("%d", e.Caller.ID), fmt.Sprintf("%d", e.Callee.ID), e.Site))
}

b.WriteString("}\n")

return nil
}
196 changes: 196 additions & 0 deletions callgraphutil/dot_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package callgraphutil_test

import (
"bytes"
"context"
"fmt"
"go/ast"
"go/parser"
"go/token"
"os"
"path/filepath"
"testing"

"github.com/go-git/go-git/v5"
"github.com/picatz/taint/callgraphutil"
"golang.org/x/tools/go/callgraph"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/ssa"
"golang.org/x/tools/go/ssa/ssautil"
)

func cloneGitHubRepository(ctx context.Context, ownerName, repoName string) (string, string, error) {
// Get the owner and repo part of the URL.
ownerAndRepo := ownerName + "/" + repoName

// Get the directory path.
dir := filepath.Join(os.TempDir(), "taint", "github", ownerAndRepo)

// Check if the directory exists.
_, err := os.Stat(dir)
if err == nil {
// If the directory exists, we'll assume it's a valid repository,
// and return the directory. Open the directory to
repo, err := git.PlainOpen(dir)
if err != nil {
return dir, "", fmt.Errorf("%w", err)
}

// Get the repository's HEAD.
head, err := repo.Head()
if err != nil {
return dir, "", fmt.Errorf("%w", err)
}

return dir, head.Hash().String(), nil
}

// Clone the repository.
repo, err := git.PlainCloneContext(ctx, dir, false, &git.CloneOptions{
URL: fmt.Sprintf("https://github.com/%s", ownerAndRepo),
Depth: 1,
Tags: git.NoTags,
SingleBranch: true,
})
if err != nil {
return dir, "", fmt.Errorf("%w", err)
}

// Get the repository's HEAD.
head, err := repo.Head()
if err != nil {
return dir, "", fmt.Errorf("%w", err)
}

return dir, head.Hash().String(), nil
}

func loadPackages(ctx context.Context, dir, pattern string) ([]*packages.Package, error) {
loadMode :=
packages.NeedName |
packages.NeedDeps |
packages.NeedFiles |
packages.NeedModule |
packages.NeedTypes |
packages.NeedImports |
packages.NeedSyntax |
packages.NeedTypesInfo
// packages.NeedTypesSizes |
// packages.NeedCompiledGoFiles |
// packages.NeedExportFile |
// packages.NeedEmbedPatterns

// parseMode := parser.ParseComments
parseMode := parser.SkipObjectResolution

// patterns := []string{dir}
patterns := []string{pattern}
// patterns := []string{"all"}

pkgs, err := packages.Load(&packages.Config{
Mode: loadMode,
Context: ctx,
Env: os.Environ(),
Dir: dir,
Tests: false,
ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) {
return parser.ParseFile(fset, filename, src, parseMode)
},
}, patterns...)
if err != nil {
return nil, err
}

return pkgs, nil

}

func loadSSA(ctx context.Context, pkgs []*packages.Package) (mainFn *ssa.Function, srcFns []*ssa.Function, err error) {
ssaBuildMode := ssa.InstantiateGenerics // ssa.SanityCheckFunctions | ssa.GlobalDebug

// Analyze the package.
ssaProg, ssaPkgs := ssautil.Packages(pkgs, ssaBuildMode)

ssaProg.Build()

for _, pkg := range ssaPkgs {
pkg.Build()
}

mainPkgs := ssautil.MainPackages(ssaPkgs)

mainFn = mainPkgs[0].Members["main"].(*ssa.Function)

for _, pkg := range ssaPkgs {
for _, fn := range pkg.Members {
if fn.Object() == nil {
continue
}

if fn.Object().Name() == "_" {
continue
}

pkgFn := pkg.Func(fn.Object().Name())
if pkgFn == nil {
continue
}

var addAnons func(f *ssa.Function)
addAnons = func(f *ssa.Function) {
srcFns = append(srcFns, f)
for _, anon := range f.AnonFuncs {
addAnons(anon)
}
}
addAnons(pkgFn)
}
}

if mainFn == nil {
err = fmt.Errorf("failed to find main function")
return
}

return
}

func loadCallGraph(ctx context.Context, mainFn *ssa.Function, srcFns []*ssa.Function) (*callgraph.Graph, error) {
cg, err := callgraphutil.NewGraph(mainFn, srcFns...)
if err != nil {
return nil, fmt.Errorf("failed to create new callgraph: %w", err)
}

return cg, nil
}

func TestWriteDOT(t *testing.T) {
repo, _, err := cloneGitHubRepository(context.Background(), "picatz", "taint")
if err != nil {
t.Fatal(err)
}

pkgs, err := loadPackages(context.Background(), repo, "./...")
if err != nil {
t.Fatal(err)
}

mainFn, srcFns, err := loadSSA(context.Background(), pkgs)
if err != nil {
t.Fatal(err)
}

cg, err := loadCallGraph(context.Background(), mainFn, srcFns)
if err != nil {
t.Fatal(err)
}

output := &bytes.Buffer{}

err = callgraphutil.WriteDOT(output, cg)
if err != nil {
t.Fatal(err)
}

fmt.Println(output.String())
}

0 comments on commit b107a34

Please sign in to comment.