From 660b89fbb4b492f878b92133dbb57fec7a1af551 Mon Sep 17 00:00:00 2001 From: Chrstopher Hunter <8398225+crhntr@users.noreply.github.com> Date: Thu, 29 Aug 2024 11:07:52 -0700 Subject: [PATCH] refactor: method nil checks --- generate.go | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/generate.go b/generate.go index 8980da2..cf6aeda 100644 --- a/generate.go +++ b/generate.go @@ -129,7 +129,7 @@ func (def TemplateName) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStm } func (def TemplateName) funcLit(method *ast.FuncType, files []*ast.File) (*ast.FuncLit, []*ast.ImportSpec, error) { - if def.handler == "" { + if method == nil { return def.httpRequestReceiverTemplateHandlerFunc(), nil, nil } lit := &ast.FuncLit{ @@ -137,18 +137,16 @@ func (def TemplateName) funcLit(method *ast.FuncType, files []*ast.File) (*ast.F Body: &ast.BlockStmt{}, } call := &ast.CallExpr{Fun: callReceiverMethod(def.fun)} + if method.Params.NumFields() != len(def.call.Args) { + return nil, nil, errWrongNumberOfArguments(def, method) + } var formStruct *ast.StructType - if method != nil { - if method.Params.NumFields() != len(def.call.Args) { - return nil, nil, errWrongNumberOfArguments(def, method) + for pi, pt := range fieldListTypes(method.Params) { + if err := checkArgument(method, pi, def.call.Args[pi], pt, files); err != nil { + return nil, nil, err } - for pi, pt := range fieldListTypes(method.Params) { - if err := checkArgument(method, pi, def.call.Args[pi], pt, files); err != nil { - return nil, nil, err - } - if s, ok := findFormStruct(pt, files); ok { - formStruct = s - } + if s, ok := findFormStruct(pt, files); ok { + formStruct = s } } const errVarIdent = "err" @@ -258,7 +256,7 @@ func (def TemplateName) funcLit(method *ast.FuncType, files []*ast.File) (*ast.F } const dataVarIdent = "data" - if method != nil && len(method.Results.List) > 1 { + if len(method.Results.List) > 1 { errVar := ast.NewIdent("err") lit.Body.List = append(lit.Body.List,