Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Instrument GitHub source with a ChunkReporter #3296

Merged
merged 1 commit into from
Sep 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 50 additions & 60 deletions pkg/sources/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,12 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) source_metada

// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, targets ...sources.ChunkingTarget) error {
chunksReporter := sources.ChanReporter{Ch: chunksChan}
// If targets are provided, we're only scanning the data in those targets.
// Otherwise, we're scanning all data.
// This allows us to only scan the commit where a vulnerability was found.
if len(targets) > 0 {
errs := s.scanTargets(ctx, targets, chunksChan)
errs := s.scanTargets(ctx, targets, chunksReporter)
return errors.Join(errs...)
}

Expand All @@ -335,7 +336,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, tar
return fmt.Errorf("error enumerating: %w", err)
}

return s.scan(ctx, chunksChan)
return s.scan(ctx, chunksReporter)
}

func (s *Source) enumerate(ctx context.Context) error {
Expand Down Expand Up @@ -564,7 +565,7 @@ func createGitHubClient(httpClient *http.Client, apiEndpoint string) (*github.Cl
return github.NewClient(httpClient).WithEnterpriseURLs(apiEndpoint, apiEndpoint)
}

func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error {
func (s *Source) scan(ctx context.Context, reporter sources.ChunkReporter) error {
var scannedCount uint64 = 1

ctx.Logger().V(2).Info("Found repos to scan", "count", len(s.repos))
Expand Down Expand Up @@ -609,7 +610,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error
return nil
}
repoCtx := context.WithValues(ctx, "repo", repoURL)
duration, err := s.cloneAndScanRepo(repoCtx, repoURL, repoInfo, chunksChan)
duration, err := s.cloneAndScanRepo(repoCtx, repoURL, repoInfo, reporter)
if err != nil {
scanErrs.Add(err)
return nil
Expand All @@ -620,7 +621,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error
wikiURL := strings.TrimSuffix(repoURL, ".git") + ".wiki.git"
wikiCtx := context.WithValue(ctx, "repo", wikiURL)

_, err := s.cloneAndScanRepo(wikiCtx, wikiURL, repoInfo, chunksChan)
_, err := s.cloneAndScanRepo(wikiCtx, wikiURL, repoInfo, reporter)
if err != nil {
// Ignore "Repository not found" errors.
// It's common for GitHub's API to say a repo has a wiki when it doesn't.
Expand All @@ -634,7 +635,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error

// Scan comments, if enabled.
if s.includeGistComments || s.includeIssueComments || s.includePRComments {
if err = s.scanComments(repoCtx, repoURL, repoInfo, chunksChan); err != nil {
if err = s.scanComments(repoCtx, repoURL, repoInfo, reporter); err != nil {
scanErrs.Add(fmt.Errorf("error scanning comments in repo %s: %w", repoURL, err))
return nil
}
Expand All @@ -656,7 +657,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error
return nil
}

func (s *Source) cloneAndScanRepo(ctx context.Context, repoURL string, repoInfo repoInfo, chunksChan chan *sources.Chunk) (time.Duration, error) {
func (s *Source) cloneAndScanRepo(ctx context.Context, repoURL string, repoInfo repoInfo, reporter sources.ChunkReporter) (time.Duration, error) {
var duration time.Duration

ctx.Logger().V(2).Info("attempting to clone repo")
Expand All @@ -679,7 +680,7 @@ func (s *Source) cloneAndScanRepo(ctx context.Context, repoURL string, repoInfo
logger.V(2).Info("scanning repo")

start := time.Now()
if err = s.git.ScanRepo(ctx, repo, path, s.scanOptions, sources.ChanReporter{Ch: chunksChan}); err != nil {
if err = s.git.ScanRepo(ctx, repo, path, s.scanOptions, reporter); err != nil {
return duration, fmt.Errorf("error scanning repo %s: %w", repoURL, err)
}
duration = time.Since(start)
Expand Down Expand Up @@ -948,16 +949,16 @@ func (s *Source) setProgressCompleteWithRepo(index int, offset int, repoURL stri
s.SetProgressComplete(index+offset, len(s.repos)+offset, fmt.Sprintf("Repo: %s", repoURL), encodedResumeInfo)
}

func (s *Source) scanComments(ctx context.Context, repoPath string, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) scanComments(ctx context.Context, repoPath string, repoInfo repoInfo, reporter sources.ChunkReporter) error {
urlString, urlParts, err := getRepoURLParts(repoPath)
if err != nil {
return err
}

if s.includeGistComments && isGistUrl(urlParts) {
return s.processGistComments(ctx, urlString, urlParts, repoInfo, chunksChan)
return s.processGistComments(ctx, urlString, urlParts, repoInfo, reporter)
} else if s.includeIssueComments || s.includePRComments {
return s.processRepoComments(ctx, repoInfo, chunksChan)
return s.processRepoComments(ctx, repoInfo, reporter)
}
return nil
}
Expand Down Expand Up @@ -1017,7 +1018,7 @@ func getRepoURLParts(repoURLString string) (string, []string, error) {

const initialPage = 1 // page to start listing from

func (s *Source) processGistComments(ctx context.Context, gistURL string, urlParts []string, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) processGistComments(ctx context.Context, gistURL string, urlParts []string, repoInfo repoInfo, reporter sources.ChunkReporter) error {
ctx.Logger().V(2).Info("Scanning GitHub Gist comments")

// GitHub Gist URL.
Expand All @@ -1036,7 +1037,7 @@ func (s *Source) processGistComments(ctx context.Context, gistURL string, urlPar
return err
}

if err = s.chunkGistComments(ctx, gistURL, repoInfo, comments, chunksChan); err != nil {
if err = s.chunkGistComments(ctx, gistURL, repoInfo, comments, reporter); err != nil {
return err
}

Expand All @@ -1056,10 +1057,10 @@ func isGistUrl(urlParts []string) bool {
return strings.EqualFold(urlParts[0], "gist.github.com") || (len(urlParts) == 4 && strings.EqualFold(urlParts[1], "gist"))
}

func (s *Source) chunkGistComments(ctx context.Context, gistURL string, gistInfo repoInfo, comments []*github.GistComment, chunksChan chan *sources.Chunk) error {
func (s *Source) chunkGistComments(ctx context.Context, gistURL string, gistInfo repoInfo, comments []*github.GistComment, reporter sources.ChunkReporter) error {
for _, comment := range comments {
// Create the chunk and hand it to the reporter.
chunk := &sources.Chunk{
chunk := sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
SourceType: s.Type(),
Expand All @@ -1080,10 +1081,8 @@ func (s *Source) chunkGistComments(ctx context.Context, gistURL string, gistInfo
Verify: s.verify,
}

select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
if err := reporter.ChunkOk(ctx, chunk); err != nil {
return err
}
}
return nil
Expand All @@ -1104,23 +1103,23 @@ var (
state = "all"
)

func (s *Source) processRepoComments(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) processRepoComments(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error {
if s.includeIssueComments {
ctx.Logger().V(2).Info("Scanning issues")
if err := s.processIssues(ctx, repoInfo, chunksChan); err != nil {
if err := s.processIssues(ctx, repoInfo, reporter); err != nil {
return err
}
if err := s.processIssueComments(ctx, repoInfo, chunksChan); err != nil {
if err := s.processIssueComments(ctx, repoInfo, reporter); err != nil {
return err
}
}

if s.includePRComments {
ctx.Logger().V(2).Info("Scanning pull requests")
if err := s.processPRs(ctx, repoInfo, chunksChan); err != nil {
if err := s.processPRs(ctx, repoInfo, reporter); err != nil {
return err
}
if err := s.processPRComments(ctx, repoInfo, chunksChan); err != nil {
if err := s.processPRComments(ctx, repoInfo, reporter); err != nil {
return err
}
}
Expand All @@ -1129,7 +1128,7 @@ func (s *Source) processRepoComments(ctx context.Context, repoInfo repoInfo, chu

}

func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error {
bodyTextsOpts := &github.IssueListByRepoOptions{
Sort: sortType,
Direction: directionType,
Expand All @@ -1150,7 +1149,7 @@ func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, chunksCha
return err
}

if err = s.chunkIssues(ctx, repoInfo, issues, chunksChan); err != nil {
if err = s.chunkIssues(ctx, repoInfo, issues, reporter); err != nil {
return err
}

Expand All @@ -1163,7 +1162,7 @@ func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, chunksCha
return nil
}

func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*github.Issue, chunksChan chan *sources.Chunk) error {
func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*github.Issue, reporter sources.ChunkReporter) error {
for _, issue := range issues {

// Skip pull requests since covered by processPRs.
Expand All @@ -1172,7 +1171,7 @@ func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*g
}

// Create the chunk and hand it to the reporter.
chunk := &sources.Chunk{
chunk := sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
JobID: s.JobID(),
Expand All @@ -1193,16 +1192,14 @@ func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*g
Verify: s.verify,
}

select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
if err := reporter.ChunkOk(ctx, chunk); err != nil {
return err
}
}
return nil
}

func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error {
issueOpts := &github.IssueListCommentsOptions{
Sort: &sortType,
Direction: &directionType,
Expand All @@ -1221,7 +1218,7 @@ func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, ch
return err
}

if err = s.chunkIssueComments(ctx, repoInfo, issueComments, chunksChan); err != nil {
if err = s.chunkIssueComments(ctx, repoInfo, issueComments, reporter); err != nil {
return err
}

Expand All @@ -1233,10 +1230,10 @@ func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, ch
return nil
}

func (s *Source) chunkIssueComments(ctx context.Context, repoInfo repoInfo, comments []*github.IssueComment, chunksChan chan *sources.Chunk) error {
func (s *Source) chunkIssueComments(ctx context.Context, repoInfo repoInfo, comments []*github.IssueComment, reporter sources.ChunkReporter) error {
for _, comment := range comments {
// Create the chunk and hand it to the reporter.
chunk := &sources.Chunk{
chunk := sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
JobID: s.JobID(),
Expand All @@ -1257,16 +1254,14 @@ func (s *Source) chunkIssueComments(ctx context.Context, repoInfo repoInfo, comm
Verify: s.verify,
}

select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
if err := reporter.ChunkOk(ctx, chunk); err != nil {
return err
}
}
return nil
}

func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error {
prOpts := &github.PullRequestListOptions{
Sort: sortType,
Direction: directionType,
Expand All @@ -1286,7 +1281,7 @@ func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, chunksChan c
return err
}

if err = s.chunkPullRequests(ctx, repoInfo, prs, chunksChan); err != nil {
if err = s.chunkPullRequests(ctx, repoInfo, prs, reporter); err != nil {
return err
}

Expand All @@ -1299,7 +1294,7 @@ func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, chunksChan c
return nil
}

func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error {
prOpts := &github.PullRequestListCommentsOptions{
Sort: sortType,
Direction: directionType,
Expand All @@ -1318,7 +1313,7 @@ func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, chunk
return err
}

if err = s.chunkPullRequestComments(ctx, repoInfo, prComments, chunksChan); err != nil {
if err = s.chunkPullRequestComments(ctx, repoInfo, prComments, reporter); err != nil {
return err
}

Expand All @@ -1331,10 +1326,10 @@ func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, chunk
return nil
}

func (s *Source) chunkPullRequests(ctx context.Context, repoInfo repoInfo, prs []*github.PullRequest, chunksChan chan *sources.Chunk) error {
func (s *Source) chunkPullRequests(ctx context.Context, repoInfo repoInfo, prs []*github.PullRequest, reporter sources.ChunkReporter) error {
for _, pr := range prs {
// Create the chunk and hand it to the reporter.
chunk := &sources.Chunk{
chunk := sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
SourceType: s.Type(),
Expand All @@ -1355,19 +1350,17 @@ func (s *Source) chunkPullRequests(ctx context.Context, repoInfo repoInfo, prs [
Verify: s.verify,
}

select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
if err := reporter.ChunkOk(ctx, chunk); err != nil {
return err
}
}
return nil
}

func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo repoInfo, comments []*github.PullRequestComment, chunksChan chan *sources.Chunk) error {
func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo repoInfo, comments []*github.PullRequestComment, reporter sources.ChunkReporter) error {
for _, comment := range comments {
// Create the chunk and hand it to the reporter.
chunk := &sources.Chunk{
chunk := sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
SourceType: s.Type(),
Expand All @@ -1388,19 +1381,17 @@ func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo repoInfo
Verify: s.verify,
}

select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
if err := reporter.ChunkOk(ctx, chunk); err != nil {
return err
}
}
return nil
}

func (s *Source) scanTargets(ctx context.Context, targets []sources.ChunkingTarget, chunksChan chan *sources.Chunk) []error {
func (s *Source) scanTargets(ctx context.Context, targets []sources.ChunkingTarget, reporter sources.ChunkReporter) []error {
var errs []error
for _, tgt := range targets {
if err := s.scanTarget(ctx, tgt, chunksChan); err != nil {
if err := s.scanTarget(ctx, tgt, reporter); err != nil {
ctx.Logger().Error(err, "error scanning target")
errs = append(errs, &sources.TargetedScanError{Err: err, SecretID: tgt.SecretID})
}
Expand All @@ -1409,7 +1400,7 @@ func (s *Source) scanTargets(ctx context.Context, targets []sources.ChunkingTarg
return errs
}

func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget, chunksChan chan *sources.Chunk) error {
func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget, reporter sources.ChunkReporter) error {
metaType, ok := target.QueryCriteria.GetData().(*source_metadatapb.MetaData_Github)
if !ok {
return fmt.Errorf("unable to cast metadata type for targeted scan")
Expand Down Expand Up @@ -1446,7 +1437,6 @@ func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget,
return fmt.Errorf("unexpected HTTP response status when trying to download file for scan: %v", resp.Status)
}

reporter := sources.ChanReporter{Ch: chunksChan}
chunkSkel := sources.Chunk{
SourceType: s.Type(),
SourceName: s.name,
Expand Down
Loading