diff --git a/modules/context/context.go b/modules/context/context.go index 21bae9112..7e986b011 100644 --- a/modules/context/context.go +++ b/modules/context/context.go @@ -47,7 +47,7 @@ const CookieNameFlash = "gitea_flash" // Render represents a template render type Render interface { - TemplateLookup(tmpl string) (*template.Template, error) + TemplateLookup(tmpl string) (templates.TemplateExecutor, error) HTML(w io.Writer, status int, name string, data interface{}) error } diff --git a/modules/templates/htmlrenderer.go b/modules/templates/htmlrenderer.go index 833c2acdc..2cecef5f8 100644 --- a/modules/templates/htmlrenderer.go +++ b/modules/templates/htmlrenderer.go @@ -9,7 +9,6 @@ import ( "context" "errors" "fmt" - "html/template" "io" "net/http" "path/filepath" @@ -22,13 +21,16 @@ import ( "code.gitea.io/gitea/modules/assetfs" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/templates/scopedtmpl" "code.gitea.io/gitea/modules/util" ) var rendererKey interface{} = "templatesHtmlRenderer" +type TemplateExecutor scopedtmpl.TemplateExecutor + type HTMLRender struct { - templates atomic.Pointer[template.Template] + templates atomic.Pointer[scopedtmpl.ScopedTemplate] } var ErrTemplateNotInitialized = errors.New("template system is not initialized, check your log for errors") @@ -47,22 +49,20 @@ func (h *HTMLRender) HTML(w io.Writer, status int, name string, data interface{} return t.Execute(w, data) } -func (h *HTMLRender) TemplateLookup(name string) (*template.Template, error) { +func (h *HTMLRender) TemplateLookup(name string) (TemplateExecutor, error) { tmpls := h.templates.Load() if tmpls == nil { return nil, ErrTemplateNotInitialized } - tmpl := tmpls.Lookup(name) - if tmpl == nil { - return nil, util.ErrNotExist - } - return tmpl, nil + + return tmpls.Executor(name, NewFuncMap()[0]) } func (h *HTMLRender) CompileTemplates() error { - extSuffix := ".tmpl" - tmpls := template.New("") assets := AssetFS() + extSuffix := ".tmpl" + tmpls := scopedtmpl.NewScopedTemplate() + tmpls.Funcs(NewFuncMap()[0]) files, err := ListWebTemplateAssetNames(assets) if err != nil { return nil @@ -73,9 +73,6 @@ func (h *HTMLRender) CompileTemplates() error { } name := strings.TrimSuffix(file, extSuffix) tmpl := tmpls.New(filepath.ToSlash(name)) - for _, fm := range NewFuncMap() { - tmpl.Funcs(fm) - } buf, err := assets.ReadFile(file) if err != nil { return err @@ -84,6 +81,7 @@ func (h *HTMLRender) CompileTemplates() error { return err } } + tmpls.Freeze() h.templates.Store(tmpls) return nil } diff --git a/modules/templates/scopedtmpl/scopedtmpl.go b/modules/templates/scopedtmpl/scopedtmpl.go new file mode 100644 index 000000000..a8b67ad0f --- /dev/null +++ b/modules/templates/scopedtmpl/scopedtmpl.go @@ -0,0 +1,239 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package scopedtmpl + +import ( + "fmt" + "html/template" + "io" + "reflect" + "sync" + texttemplate "text/template" + "text/template/parse" + "unsafe" +) + +type TemplateExecutor interface { + Execute(wr io.Writer, data interface{}) error +} + +type ScopedTemplate struct { + all *template.Template + parseFuncs template.FuncMap // this func map is only used for parsing templates + frozen bool + + scopedMu sync.RWMutex + scopedTemplateSets map[string]*scopedTemplateSet +} + +func NewScopedTemplate() *ScopedTemplate { + return &ScopedTemplate{ + all: template.New(""), + parseFuncs: template.FuncMap{}, + scopedTemplateSets: map[string]*scopedTemplateSet{}, + } +} + +func (t *ScopedTemplate) Funcs(funcMap template.FuncMap) { + if t.frozen { + panic("cannot add new functions to frozen template set") + } + t.all.Funcs(funcMap) + for k, v := range funcMap { + t.parseFuncs[k] = v + } +} + +func (t *ScopedTemplate) New(name string) *template.Template { + if t.frozen { + panic("cannot add new template to frozen template set") + } + return t.all.New(name) +} + +func (t *ScopedTemplate) Freeze() { + t.frozen = true + // reset the exec func map, then `escapeTemplate` is safe to call `Execute` to do escaping + m := template.FuncMap{} + for k := range t.parseFuncs { + m[k] = func(v ...any) any { return nil } + } + t.all.Funcs(m) +} + +func (t *ScopedTemplate) Executor(name string, funcMap template.FuncMap) (TemplateExecutor, error) { + t.scopedMu.RLock() + scopedTmplSet, ok := t.scopedTemplateSets[name] + t.scopedMu.RUnlock() + + if !ok { + var err error + t.scopedMu.Lock() + if scopedTmplSet, ok = t.scopedTemplateSets[name]; !ok { + if scopedTmplSet, err = newScopedTemplateSet(t.all, name); err == nil { + t.scopedTemplateSets[name] = scopedTmplSet + } + } + t.scopedMu.Unlock() + if err != nil { + return nil, err + } + } + + if scopedTmplSet == nil { + return nil, fmt.Errorf("template %s not found", name) + } + return scopedTmplSet.newExecutor(funcMap), nil +} + +type scopedTemplateSet struct { + name string + htmlTemplates map[string]*template.Template + textTemplates map[string]*texttemplate.Template + execFuncs map[string]reflect.Value +} + +func escapeTemplate(t *template.Template) error { + // force the Golang HTML template to complete the escaping work + err := t.Execute(io.Discard, nil) + if _, ok := err.(*template.Error); ok { + return err + } + return nil +} + +//nolint:unused +type htmlTemplate struct { + escapeErr error + text *texttemplate.Template +} + +//nolint:unused +type textTemplateCommon struct { + tmpl map[string]*template.Template // Map from name to defined templates. + muTmpl sync.RWMutex // protects tmpl + option struct { + missingKey int + } + muFuncs sync.RWMutex // protects parseFuncs and execFuncs + parseFuncs texttemplate.FuncMap + execFuncs map[string]reflect.Value +} + +//nolint:unused +type textTemplate struct { + name string + *parse.Tree + *textTemplateCommon + leftDelim string + rightDelim string +} + +func ptr[T, P any](ptr *P) *T { + // https://pkg.go.dev/unsafe#Pointer + // (1) Conversion of a *T1 to Pointer to *T2. + // Provided that T2 is no larger than T1 and that the two share an equivalent memory layout, + // this conversion allows reinterpreting data of one type as data of another type. + return (*T)(unsafe.Pointer(ptr)) +} + +func newScopedTemplateSet(all *template.Template, name string) (*scopedTemplateSet, error) { + targetTmpl := all.Lookup(name) + if targetTmpl == nil { + return nil, fmt.Errorf("template %q not found", name) + } + if err := escapeTemplate(targetTmpl); err != nil { + return nil, fmt.Errorf("template %q has an error when escaping: %w", name, err) + } + + ts := &scopedTemplateSet{ + name: name, + htmlTemplates: map[string]*template.Template{}, + textTemplates: map[string]*texttemplate.Template{}, + } + + htmlTmpl := ptr[htmlTemplate](all) + textTmpl := htmlTmpl.text + textTmplPtr := ptr[textTemplate](textTmpl) + + textTmplPtr.muFuncs.Lock() + ts.execFuncs = map[string]reflect.Value{} + for k, v := range textTmplPtr.execFuncs { + ts.execFuncs[k] = v + } + textTmplPtr.muFuncs.Unlock() + + var collectTemplates func(nodes []parse.Node) + var collectErr error // only need to collect the one error + collectTemplates = func(nodes []parse.Node) { + for _, node := range nodes { + if node.Type() == parse.NodeTemplate { + nodeTemplate := node.(*parse.TemplateNode) + subName := nodeTemplate.Name + if ts.htmlTemplates[subName] == nil { + subTmpl := all.Lookup(subName) + if subTmpl == nil { + // HTML template will add some internal templates like "$delimDoubleQuote" into the text template + ts.textTemplates[subName] = textTmpl.Lookup(subName) + } else if subTmpl.Tree == nil || subTmpl.Tree.Root == nil { + collectErr = fmt.Errorf("template %q has no tree, it's usually caused by broken templates", subName) + } else { + ts.htmlTemplates[subName] = subTmpl + if err := escapeTemplate(subTmpl); err != nil { + collectErr = fmt.Errorf("template %q has an error when escaping: %w", subName, err) + return + } + collectTemplates(subTmpl.Tree.Root.Nodes) + } + } + } else if node.Type() == parse.NodeList { + nodeList := node.(*parse.ListNode) + collectTemplates(nodeList.Nodes) + } else if node.Type() == parse.NodeIf { + nodeIf := node.(*parse.IfNode) + collectTemplates(nodeIf.BranchNode.List.Nodes) + if nodeIf.BranchNode.ElseList != nil { + collectTemplates(nodeIf.BranchNode.ElseList.Nodes) + } + } else if node.Type() == parse.NodeRange { + nodeRange := node.(*parse.RangeNode) + collectTemplates(nodeRange.BranchNode.List.Nodes) + if nodeRange.BranchNode.ElseList != nil { + collectTemplates(nodeRange.BranchNode.ElseList.Nodes) + } + } else if node.Type() == parse.NodeWith { + nodeWith := node.(*parse.WithNode) + collectTemplates(nodeWith.BranchNode.List.Nodes) + if nodeWith.BranchNode.ElseList != nil { + collectTemplates(nodeWith.BranchNode.ElseList.Nodes) + } + } + } + } + ts.htmlTemplates[name] = targetTmpl + collectTemplates(targetTmpl.Tree.Root.Nodes) + return ts, collectErr +} + +func (ts *scopedTemplateSet) newExecutor(funcMap map[string]any) TemplateExecutor { + tmpl := texttemplate.New("") + tmplPtr := ptr[textTemplate](tmpl) + tmplPtr.execFuncs = map[string]reflect.Value{} + for k, v := range ts.execFuncs { + tmplPtr.execFuncs[k] = v + } + if funcMap != nil { + tmpl.Funcs(funcMap) + } + // after escapeTemplate, the html templates are also escaped text templates, so it could be added to the text template directly + for _, t := range ts.htmlTemplates { + _, _ = tmpl.AddParseTree(t.Name(), t.Tree) + } + for _, t := range ts.textTemplates { + _, _ = tmpl.AddParseTree(t.Name(), t.Tree) + } + + // now the text template has all necessary escaped templates, so we can safely execute, just like what the html template does + return tmpl.Lookup(ts.name) +} diff --git a/modules/templates/scopedtmpl/scopedtmpl_test.go b/modules/templates/scopedtmpl/scopedtmpl_test.go new file mode 100644 index 000000000..774b8c7d4 --- /dev/null +++ b/modules/templates/scopedtmpl/scopedtmpl_test.go @@ -0,0 +1,98 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package scopedtmpl + +import ( + "bytes" + "html/template" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestScopedTemplateSetFuncMap(t *testing.T) { + all := template.New("") + + all.Funcs(template.FuncMap{"CtxFunc": func(s string) string { + return "default" + }}) + + _, err := all.New("base").Parse(`{{CtxFunc "base"}}`) + assert.NoError(t, err) + + _, err = all.New("test").Parse(strings.TrimSpace(` +{{template "base"}} +{{CtxFunc "test"}} +{{template "base"}} +{{CtxFunc "test"}} +`)) + assert.NoError(t, err) + + ts, err := newScopedTemplateSet(all, "test") + assert.NoError(t, err) + + // try to use different CtxFunc to render concurrently + + funcMap1 := template.FuncMap{ + "CtxFunc": func(s string) string { + time.Sleep(100 * time.Millisecond) + return s + "1" + }, + } + + funcMap2 := template.FuncMap{ + "CtxFunc": func(s string) string { + time.Sleep(100 * time.Millisecond) + return s + "2" + }, + } + + out1 := bytes.Buffer{} + out2 := bytes.Buffer{} + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + err := ts.newExecutor(funcMap1).Execute(&out1, nil) + assert.NoError(t, err) + wg.Done() + }() + go func() { + err := ts.newExecutor(funcMap2).Execute(&out2, nil) + assert.NoError(t, err) + wg.Done() + }() + wg.Wait() + assert.Equal(t, "base1\ntest1\nbase1\ntest1", out1.String()) + assert.Equal(t, "base2\ntest2\nbase2\ntest2", out2.String()) +} + +func TestScopedTemplateSetEscape(t *testing.T) { + all := template.New("") + _, err := all.New("base").Parse(`{{.text}}`) + assert.NoError(t, err) + + _, err = all.New("test").Parse(`{{template "base" .}}
`) + assert.NoError(t, err) + + ts, err := newScopedTemplateSet(all, "test") + assert.NoError(t, err) + + out := bytes.Buffer{} + err = ts.newExecutor(nil).Execute(&out, map[string]string{"param": "/", "text": "<"}) + assert.NoError(t, err) + + assert.Equal(t, `<`, out.String()) +} + +func TestScopedTemplateSetUnsafe(t *testing.T) { + all := template.New("") + _, err := all.New("test").Parse(``) + assert.NoError(t, err) + + _, err = newScopedTemplateSet(all, "test") + assert.ErrorContains(t, err, "appears in an ambiguous context within a URL") +} diff --git a/modules/test/context_tests.go b/modules/test/context_tests.go index e558bf1b9..4af765125 100644 --- a/modules/test/context_tests.go +++ b/modules/test/context_tests.go @@ -5,7 +5,6 @@ package test import ( scontext "context" - "html/template" "io" "net/http" "net/http/httptest" @@ -18,6 +17,7 @@ import ( user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/git" + "code.gitea.io/gitea/modules/templates" "code.gitea.io/gitea/modules/translation" "code.gitea.io/gitea/modules/web/middleware" @@ -120,7 +120,7 @@ func (rw *mockResponseWriter) Push(target string, opts *http.PushOptions) error type mockRender struct{} -func (tr *mockRender) TemplateLookup(tmpl string) (*template.Template, error) { +func (tr *mockRender) TemplateLookup(tmpl string) (templates.TemplateExecutor, error) { return nil, nil }