Fix protected branch files detection on pre_receive hook (#31778)

Fix #31738

When pushing a new branch, the old commit is zero. Most git commands
cannot recognize the zero commit id. To get the changed files in the
push, we need to get the first diverge commit of this branch. In most
situations, we could check commits one by one until one commit is
contained by another branch. Then we will think that commit is the
diverge point.

And in a pre-receive hook, this will be more difficult because all
commits haven't been merged and they actually stored in a temporary
place by git. So we need to bring some envs to let git know the commit
exist.
This commit is contained in:
Lunny Xiao 2024-08-06 21:32:49 +08:00 committed by GitHub
parent de175e3b06
commit df7f1c2ead
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 81 additions and 14 deletions

View File

@ -4,6 +4,8 @@
package git package git
import ( import (
"context"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
@ -345,3 +347,18 @@ func TestGetCommitFileStatusMerges(t *testing.T) {
assert.Equal(t, commitFileStatus.Removed, expected.Removed) assert.Equal(t, commitFileStatus.Removed, expected.Removed)
assert.Equal(t, commitFileStatus.Modified, expected.Modified) assert.Equal(t, commitFileStatus.Modified, expected.Modified)
} }
func Test_GetCommitBranchStart(t *testing.T) {
bareRepo1Path := filepath.Join(testReposDir, "repo1_bare")
repo, err := OpenRepository(context.Background(), bareRepo1Path)
assert.NoError(t, err)
defer repo.Close()
commit, err := repo.GetBranchCommit("branch1")
assert.NoError(t, err)
assert.EqualValues(t, "2839944139e0de9737a044f78b0e4b40d989a9e3", commit.ID.String())
startCommitID, err := repo.GetCommitBranchStart(os.Environ(), "branch1", commit.ID.String())
assert.NoError(t, err)
assert.NotEmpty(t, startCommitID)
assert.EqualValues(t, "9c9aef8dd84e02bc7ec12641deb4c930a7c30185", startCommitID)
}

View File

@ -271,7 +271,17 @@ func CutDiffAroundLine(originalDiff io.Reader, line int64, old bool, numbersOfLi
} }
// GetAffectedFiles returns the affected files between two commits // GetAffectedFiles returns the affected files between two commits
func GetAffectedFiles(repo *Repository, oldCommitID, newCommitID string, env []string) ([]string, error) { func GetAffectedFiles(repo *Repository, branchName, oldCommitID, newCommitID string, env []string) ([]string, error) {
if oldCommitID == emptySha1ObjectID.String() || oldCommitID == emptySha256ObjectID.String() {
startCommitID, err := repo.GetCommitBranchStart(env, branchName, newCommitID)
if err != nil {
return nil, err
}
if startCommitID == "" {
return nil, fmt.Errorf("cannot find the start commit of %s", newCommitID)
}
oldCommitID = startCommitID
}
stdoutReader, stdoutWriter, err := os.Pipe() stdoutReader, stdoutWriter, err := os.Pipe()
if err != nil { if err != nil {
log.Error("Unable to create os.Pipe for %s", repo.Path) log.Error("Unable to create os.Pipe for %s", repo.Path)

View File

@ -7,6 +7,7 @@ package git
import ( import (
"bytes" "bytes"
"io" "io"
"os"
"strconv" "strconv"
"strings" "strings"
@ -414,7 +415,7 @@ func (repo *Repository) commitsBefore(id ObjectID, limit int) ([]*Commit, error)
commits := make([]*Commit, 0, len(formattedLog)) commits := make([]*Commit, 0, len(formattedLog))
for _, commit := range formattedLog { for _, commit := range formattedLog {
branches, err := repo.getBranches(commit, 2) branches, err := repo.getBranches(os.Environ(), commit.ID.String(), 2)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -437,12 +438,15 @@ func (repo *Repository) getCommitsBeforeLimit(id ObjectID, num int) ([]*Commit,
return repo.commitsBefore(id, num) return repo.commitsBefore(id, num)
} }
func (repo *Repository) getBranches(commit *Commit, limit int) ([]string, error) { func (repo *Repository) getBranches(env []string, commitID string, limit int) ([]string, error) {
if DefaultFeatures().CheckVersionAtLeast("2.7.0") { if DefaultFeatures().CheckVersionAtLeast("2.7.0") {
stdout, _, err := NewCommand(repo.Ctx, "for-each-ref", "--format=%(refname:strip=2)"). stdout, _, err := NewCommand(repo.Ctx, "for-each-ref", "--format=%(refname:strip=2)").
AddOptionFormat("--count=%d", limit). AddOptionFormat("--count=%d", limit).
AddOptionValues("--contains", commit.ID.String(), BranchPrefix). AddOptionValues("--contains", commitID, BranchPrefix).
RunStdString(&RunOpts{Dir: repo.Path}) RunStdString(&RunOpts{
Dir: repo.Path,
Env: env,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -451,7 +455,10 @@ func (repo *Repository) getBranches(commit *Commit, limit int) ([]string, error)
return branches, nil return branches, nil
} }
stdout, _, err := NewCommand(repo.Ctx, "branch").AddOptionValues("--contains", commit.ID.String()).RunStdString(&RunOpts{Dir: repo.Path}) stdout, _, err := NewCommand(repo.Ctx, "branch").AddOptionValues("--contains", commitID).RunStdString(&RunOpts{
Dir: repo.Path,
Env: env,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -513,3 +520,35 @@ func (repo *Repository) AddLastCommitCache(cacheKey, fullName, sha string) error
} }
return nil return nil
} }
func (repo *Repository) GetCommitBranchStart(env []string, branch, endCommitID string) (string, error) {
cmd := NewCommand(repo.Ctx, "log", prettyLogFormat)
cmd.AddDynamicArguments(endCommitID)
stdout, _, runErr := cmd.RunStdBytes(&RunOpts{
Dir: repo.Path,
Env: env,
})
if runErr != nil {
return "", runErr
}
parts := bytes.Split(bytes.TrimSpace(stdout), []byte{'\n'})
var startCommitID string
for _, commitID := range parts {
branches, err := repo.getBranches(env, string(commitID), 2)
if err != nil {
return "", err
}
for _, b := range branches {
if b != branch {
return startCommitID, nil
}
}
startCommitID = string(commitID)
}
return "", nil
}

View File

@ -4,6 +4,7 @@
package git package git
import ( import (
"os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -31,7 +32,7 @@ func TestRepository_GetCommitBranches(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
commit, err := bareRepo1.GetCommit(testCase.CommitID) commit, err := bareRepo1.GetCommit(testCase.CommitID)
assert.NoError(t, err) assert.NoError(t, err)
branches, err := bareRepo1.getBranches(commit, 2) branches, err := bareRepo1.getBranches(os.Environ(), commit.ID.String(), 2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, testCase.ExpectedBranches, branches) assert.Equal(t, testCase.ExpectedBranches, branches)
} }

View File

@ -235,7 +235,7 @@ func preReceiveBranch(ctx *preReceiveContext, oldCommitID, newCommitID string, r
globs := protectBranch.GetProtectedFilePatterns() globs := protectBranch.GetProtectedFilePatterns()
if len(globs) > 0 { if len(globs) > 0 {
_, err := pull_service.CheckFileProtection(gitRepo, oldCommitID, newCommitID, globs, 1, ctx.env) _, err := pull_service.CheckFileProtection(gitRepo, branchName, oldCommitID, newCommitID, globs, 1, ctx.env)
if err != nil { if err != nil {
if !models.IsErrFilePathProtected(err) { if !models.IsErrFilePathProtected(err) {
log.Error("Unable to check file protection for commits from %s to %s in %-v: %v", oldCommitID, newCommitID, repo, err) log.Error("Unable to check file protection for commits from %s to %s in %-v: %v", oldCommitID, newCommitID, repo, err)
@ -293,7 +293,7 @@ func preReceiveBranch(ctx *preReceiveContext, oldCommitID, newCommitID string, r
// Allow commits that only touch unprotected files // Allow commits that only touch unprotected files
globs := protectBranch.GetUnprotectedFilePatterns() globs := protectBranch.GetUnprotectedFilePatterns()
if len(globs) > 0 { if len(globs) > 0 {
unprotectedFilesOnly, err := pull_service.CheckUnprotectedFiles(gitRepo, oldCommitID, newCommitID, globs, ctx.env) unprotectedFilesOnly, err := pull_service.CheckUnprotectedFiles(gitRepo, branchName, oldCommitID, newCommitID, globs, ctx.env)
if err != nil { if err != nil {
log.Error("Unable to check file protection for commits from %s to %s in %-v: %v", oldCommitID, newCommitID, repo, err) log.Error("Unable to check file protection for commits from %s to %s in %-v: %v", oldCommitID, newCommitID, repo, err)
ctx.JSON(http.StatusInternalServerError, private.Response{ ctx.JSON(http.StatusInternalServerError, private.Response{

View File

@ -503,11 +503,11 @@ func checkConflicts(ctx context.Context, pr *issues_model.PullRequest, gitRepo *
} }
// CheckFileProtection check file Protection // CheckFileProtection check file Protection
func CheckFileProtection(repo *git.Repository, oldCommitID, newCommitID string, patterns []glob.Glob, limit int, env []string) ([]string, error) { func CheckFileProtection(repo *git.Repository, branchName, oldCommitID, newCommitID string, patterns []glob.Glob, limit int, env []string) ([]string, error) {
if len(patterns) == 0 { if len(patterns) == 0 {
return nil, nil return nil, nil
} }
affectedFiles, err := git.GetAffectedFiles(repo, oldCommitID, newCommitID, env) affectedFiles, err := git.GetAffectedFiles(repo, branchName, oldCommitID, newCommitID, env)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -533,11 +533,11 @@ func CheckFileProtection(repo *git.Repository, oldCommitID, newCommitID string,
} }
// CheckUnprotectedFiles check if the commit only touches unprotected files // CheckUnprotectedFiles check if the commit only touches unprotected files
func CheckUnprotectedFiles(repo *git.Repository, oldCommitID, newCommitID string, patterns []glob.Glob, env []string) (bool, error) { func CheckUnprotectedFiles(repo *git.Repository, branchName, oldCommitID, newCommitID string, patterns []glob.Glob, env []string) (bool, error) {
if len(patterns) == 0 { if len(patterns) == 0 {
return false, nil return false, nil
} }
affectedFiles, err := git.GetAffectedFiles(repo, oldCommitID, newCommitID, env) affectedFiles, err := git.GetAffectedFiles(repo, branchName, oldCommitID, newCommitID, env)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -574,7 +574,7 @@ func checkPullFilesProtection(ctx context.Context, pr *issues_model.PullRequest,
return nil return nil
} }
pr.ChangedProtectedFiles, err = CheckFileProtection(gitRepo, pr.MergeBase, "tracking", pb.GetProtectedFilePatterns(), 10, os.Environ()) pr.ChangedProtectedFiles, err = CheckFileProtection(gitRepo, pr.HeadBranch, pr.MergeBase, "tracking", pb.GetProtectedFilePatterns(), 10, os.Environ())
if err != nil && !models.IsErrFilePathProtected(err) { if err != nil && !models.IsErrFilePathProtected(err) {
return err return err
} }