Go module-aware AST Rewriting

Bottom line, upfront

In this blog post, you will learn what AST rewriting is and how to write your own simple, Go module-aware AST-rewriting tool.

Linting - A starting place

Code linters are a popular option to detect problems in code bases. Linters like gometalinter and golangci-lint provide a huge number of generic checks for all kinds of issues.

Whilst these generic linters are very useful, organizations often want to enforce code standards specific to their own codebase, which are not necessarily relevant or applicable elsewhere. Examples include enforcing a standardized logging format or preventing certain libraries from being used. This is where custom linters come in.

Go lends itself well to writing custom linters, as its standard library includes many of the essential packages for doing so, such as go/ast, go/parser and go/types. More recently, Go has gained go/analysis, which works with modules and lets you build linters in a composable way.

There are a number of excellent blog posts on writing custom linters in Go, and they are well worth reading, so I won't be covering how to use go/analysis or similar packages here.

AST Rewrite - The problem

AST rewriting is a natural extension to linting. If you have a linter that can identify problems, the next step is a tool that automatically fixes the issues it flags. This is quite common in the JavaScript/TypeScript world, where tools like tslint and eslint accept a --fix flag that automatically fixes the flagged issues. However, this seems to be less common in Go, although it looks like support is coming in go/analysis.

In this post, I will describe how to write a simple tool that rewrites log.Printf calls into calls to a Printf function that takes a context.Context as its first argument. The behaviour we want is: if there is a context.Context in scope at the call site, the tool passes in that context; otherwise, it passes a new context.TODO(). The tool will not deal with renaming the imports, but that can be done by replacing the "log" import with an import of your own logging package under the alias "log".

So given the following code:

func FunctionWithCtx(ctx context.Context) {
    log.Printf("This function has a context")
}

func FunctionWithoutCtx() {
    log.Printf("This function doesn't have a context")
}

The tool should replace it with:

func FunctionWithCtx(ctx context.Context) {
    log.Printf(ctx, "This function has a context")
}

func FunctionWithoutCtx() {
    log.Printf(context.TODO(), "This function doesn't have a context")
}

A simple search and replace is not smart enough for this, as the tool needs to know about function scopes and the types of each function's parameters. And whilst the logic is simple, applying it by hand quickly becomes tedious on even a moderately sized codebase.

Tools for the Tool

Whilst there are other posts on AST rewriting in general, the tool we will write focuses on rewriting in a Go module environment. It is based on two packages:

The first is github.com/fatih/astrewrite, which is widely used for Go AST rewriting. With astrewrite, you implement a normal AST visitor, except that the visitor can return a new or modified AST node. This means that once you have identified the nodes of interest, you can modify them as needed and write the AST back out to a file.
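Since astrewrite is a third-party package, here is a rough, standard-library-only illustration of the same find, modify and print cycle. It is a swapped-in technique, not astrewrite: it uses ast.Inspect and mutates matching nodes in place instead of returning replacements, and it matches log.Printf purely by syntax, with no type information. prependArg is a hypothetical helper name.

```go
package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
)

// prependArg parses src, prepends an identifier named arg to every call
// written as log.Printf(...), and returns the re-printed source.
// Matching is purely syntactic: no type information is used.
func prependArg(src, arg string) (string, error) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "input.go", src, parser.ParseComments)
	if err != nil {
		return "", err
	}
	ast.Inspect(file, func(n ast.Node) bool {
		call, ok := n.(*ast.CallExpr)
		if !ok {
			return true
		}
		sel, ok := call.Fun.(*ast.SelectorExpr)
		if !ok || sel.Sel.Name != "Printf" {
			return true
		}
		if pkg, ok := sel.X.(*ast.Ident); ok && pkg.Name == "log" {
			// Modify the node in place: the new argument goes first.
			call.Args = append([]ast.Expr{ast.NewIdent(arg)}, call.Args...)
		}
		return true
	})
	// Print the modified AST back out in gofmt style.
	var buf bytes.Buffer
	if err := format.Node(&buf, fset, file); err != nil {
		return "", err
	}
	return buf.String(), nil
}

func main() {
	src := "package p\n\nimport \"log\"\n\nfunc f() {\n\tlog.Printf(\"hi\")\n}\n"
	out, err := prependArg(src, "ctx")
	if err != nil {
		panic(err)
	}
	fmt.Print(out)
}
```

Because the match is syntactic, this sketch cannot tell whether a ctx is actually in scope or what type it has; that is the gap the type information from go/packages fills.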

The second package needed is golang.org/x/tools/go/packages. This package provides the module-aware type information, which is essential if you want to know the types of the variables and functions in the AST. Without it, we couldn't tell whether a ctx in scope is a context.Context or something else named ctx.

There is also a very honourable mention for the GoAst Viewer, which is extremely valuable for figuring out the shape of the AST for a given piece of Go code. Without this tool, it is very hard to know what type of AST node you are looking for, and what it should be rewritten to.

At a high level, the tool will work as follows:

  • Use go/packages to load package files and type information
  • Feed the AST to astrewrite along with the type information
  • Walk the AST and identify all function scopes which contain a context.Context
  • Walk the AST and find all log.Printf calls.
    • If a log.Printf call is within a function that has a context in scope, replace it with log.Printf(ctx, ...); otherwise, replace it with log.Printf(context.TODO(), ...)

Step One - Main

We start with a main function that takes the package pattern as its only argument. If you are running the tool in your current directory, you will normally pass in "./...". The pattern is passed to packages.Load, which does all the module-aware magic. It returns a slice of *packages.Package. You can find more in the go/packages documentation, but the TL;DR is that each Package contains all the data we need to rewrite the AST. Specifically, the Syntax field holds the parsed AST (one *ast.File per file), and the TypesInfo field holds type information about the nodes in those ASTs.

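package main

import (
    "fmt"
    "os"

    "github.com/pkg/errors"
    "golang.org/x/tools/go/packages"
)

func main() {
    if len(os.Args) != 2 {
        fmt.Fprintf(os.Stderr, "usage: %s <package pattern, e.g. ./...>\n", os.Args[0])
        os.Exit(2)
    }
    if err := load(os.Args[1]); err != nil {
        panic(err)
    }
}

func load(packageName string) error {
    config := &packages.Config{
        // LoadSyntax asks for parsed files (Syntax) and type information (TypesInfo).
        Mode: packages.LoadSyntax,
    }
    pkgs, err := packages.Load(config, packageName)
    if err != nil {
        return errors.Wrapf(err, "Error loading package %s", packageName)
    }
    // Load only returns an error for problems such as an invalid pattern;
    // parse and type errors are recorded in each package's Errors field.
    if packages.PrintErrors(pkgs) > 0 {
        return errors.Errorf("errors while loading %s", packageName)
    }

    for _, pkg := range pkgs {
        rewriter := &Rewriter{pkg: pkg}
        if err := rewriter.Rewrite(); err != nil {
            return err
        }
    }
    return nil
}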

Step Two - Rewriter Struct

Now that we have this information, it's time to turn our attention to the rewriter. The Rewriter struct is responsible for rewriting all instances of log.Printf in a single package. For every package returned by packages.Load, we initialize a Rewriter and call its Rewrite method.

In the Rewrite method, we iterate over the pkg.Syntax field, which holds the parsed AST (an *ast.File) for each of the package's Go files. Each AST is passed to astrewrite.Walk along with a visitor function. The visitor is called for every node in the AST and gives us a hook for modifying the AST in a manageable way.

type scope struct {
    s *types.Scope // the function's scope, taken from TypesInfo.Scopes
    i *ast.Ident   // the name of the context.Context parameter
}

type Rewriter struct {
    pkg                 *packages.Package
    contextScopes       []*scope // functions found to have a context.Context parameter
    fileHasLogReWritten bool     // set once a log.Printf call in the current file is rewritten
}

func (r *Rewriter) Rewrite() error {
    // For now, only run the first pass over each file's AST.
    for _, file := range r.pkg.Syntax {
        astrewrite.Walk(file, r.visitForContextFuncs)
    }
    return nil
}

Step Three - Finding Contexts

The first part of the rewrite is to identify all the functions that take a context.Context as an argument. When we find such a function, we record its scope in the contextScopes field of the Rewriter struct. This lets us easily check, in a second pass, whether a log statement has a context.Context in its scope.

visitForContextFuncs is the visitor function satisfying astrewrite.WalkFunc. It checks whether the node is a function declaration and, if the function has a context in scope, appends that scope to contextScopes.

func (r *Rewriter) visitForContextFuncs(n ast.Node) (ast.Node, bool) {
    switch v := n.(type) {
    case *ast.FuncDecl:
        if scope := r.checkHasContext(v); scope != nil {
            r.contextScopes = append(r.contextScopes, scope)
        }
    }
    return n, true
}

checkHasContext checks whether the function has a context parameter and, if so, creates a scope struct. This struct holds a pointer to the function's types.Scope as well as the ast.Ident of the context parameter, both of which will be useful later.

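func (r *Rewriter) checkHasContext(node *ast.FuncDecl) *scope {
    if id, ok := hasContextParam(node.Type.Params); ok {
        // go/types records a function's scope under its *ast.FuncType node.
        funcScope := r.pkg.TypesInfo.Scopes[node.Type]
        if funcScope == nil {
            return nil // defensive: no scope was recorded for this function
        }
        return &scope{s: funcScope, i: id}
    }
    return nil
}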
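To see why the scope looked up via node.Type is useful, here is a small standard-library sketch that type-checks two functions and asks the first function's scope whether each call position falls inside it, which is the same Contains check hasContextInScope performs later. The example source and the helper names funcScopes and callPositions are hypothetical.

```go
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"go/types"
)

const scopeExample = `package p

func withParam(ctx int) { println(ctx) }

func withoutParam() { println(0) }
`

// funcScopes type-checks scopeExample and returns each function's scope by name.
// As in checkHasContext, the scope is looked up via the *ast.FuncType node.
func funcScopes() (map[string]*types.Scope, *ast.File, error) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "p.go", scopeExample, 0)
	if err != nil {
		return nil, nil, err
	}
	info := &types.Info{Scopes: map[ast.Node]*types.Scope{}}
	if _, err := new(types.Config).Check("p", fset, []*ast.File{file}, info); err != nil {
		return nil, nil, err
	}
	scopes := map[string]*types.Scope{}
	for _, decl := range file.Decls {
		if fd, ok := decl.(*ast.FuncDecl); ok {
			scopes[fd.Name.Name] = info.Scopes[fd.Type]
		}
	}
	return scopes, file, nil
}

// callPositions returns the position of every call expression in file, in source order.
func callPositions(file *ast.File) []token.Pos {
	var positions []token.Pos
	ast.Inspect(file, func(n ast.Node) bool {
		if call, ok := n.(*ast.CallExpr); ok {
			positions = append(positions, call.Pos())
		}
		return true
	})
	return positions
}

func main() {
	scopes, file, err := funcScopes()
	if err != nil {
		panic(err)
	}
	calls := callPositions(file)
	// calls[0] is inside withParam's body; calls[1] is inside withoutParam's.
	fmt.Println(scopes["withParam"].Contains(calls[0]), scopes["withParam"].Contains(calls[1]))
}
```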

hasContextParam takes the AST node representing a function's parameter list. For each parameter, it gets the type name with getParamName and, if that type is context.Context, returns the ast.Ident naming the parameter.

func hasContextParam(params *ast.FieldList) (*ast.Ident, bool) {
    for _, param := range params.List {
        if getParamName(param.Type) != "context.Context" {
            continue
        }
        // Skip blank (_) names: there is no variable we could pass on.
        for _, name := range param.Names {
            if name.Name != "_" {
                return name, true
            }
        }
    }

    // no named context.Context parameter
    return nil, false
}

func getParamName(n ast.Expr) string {
    s, ok := n.(*ast.SelectorExpr)
    if !ok {
        return ""
    }

    i, ok := s.X.(*ast.Ident)
    if !ok {
        return ""
    }
    return i.Name + "." + s.Sel.Name
}
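getParamName compares the type as it is written in the source, so it misses a context.Context parameter whose package was imported under another name, for example stdctx "context". Since go/packages already provides TypesInfo, a type-based check is one possible alternative. The sketch below is a swapped-in technique rather than the code used in this post: it type-checks against a hypothetical minimal stub of "context" so that it runs without GOROOT or module lookups, and isContextType, contextParams and the stub are illustrative names.

```go
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"go/types"
)

// stubImporter serves already type-checked packages from a map.
type stubImporter map[string]*types.Package

func (m stubImporter) Import(path string) (*types.Package, error) {
	if pkg, ok := m[path]; ok {
		return pkg, nil
	}
	return nil, fmt.Errorf("package %q not stubbed", path)
}

// contextStub is a hypothetical, minimal stand-in for the real "context" package.
const contextStub = "package context\n\ntype Context interface{}\n"

// aliasExample imports context under another name and declares an unrelated
// local type that is also called Context.
const aliasExample = `package example

import stdctx "context"

type Context struct{}

func aliased(ctx stdctx.Context) {}

func local(ctx Context) {}
`

// isContextType reports whether t is the named type Context declared in the
// package whose import path is "context", however that package was imported.
func isContextType(t types.Type) bool {
	named, ok := t.(*types.Named)
	if !ok {
		return false
	}
	obj := named.Obj()
	return obj.Pkg() != nil && obj.Pkg().Path() == "context" && obj.Name() == "Context"
}

// contextParams type-checks src and reports, for each "function.parameter",
// whether the parameter's type is context.Context.
func contextParams(src string) (map[string]bool, error) {
	fset := token.NewFileSet()
	stub, err := parser.ParseFile(fset, "context.go", contextStub, 0)
	if err != nil {
		return nil, err
	}
	ctxPkg, err := new(types.Config).Check("context", fset, []*ast.File{stub}, nil)
	if err != nil {
		return nil, err
	}
	file, err := parser.ParseFile(fset, "example.go", src, 0)
	if err != nil {
		return nil, err
	}
	info := &types.Info{Types: map[ast.Expr]types.TypeAndValue{}}
	conf := types.Config{Importer: stubImporter{"context": ctxPkg}}
	if _, err := conf.Check("example", fset, []*ast.File{file}, info); err != nil {
		return nil, err
	}
	result := map[string]bool{}
	for _, decl := range file.Decls {
		fd, ok := decl.(*ast.FuncDecl)
		if !ok {
			continue
		}
		for _, field := range fd.Type.Params.List {
			for _, name := range field.Names {
				// TypeOf returns the resolved type, not the spelling in the source.
				result[fd.Name.Name+"."+name.Name] = isContextType(info.TypeOf(field.Type))
			}
		}
	}
	return result, nil
}

func main() {
	got, err := contextParams(aliasExample)
	if err != nil {
		panic(err)
	}
	fmt.Println(got)
}
```

Here the aliased parameter is recognised even though its type is spelled stdctx.Context, while the local Context type is correctly rejected.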

Step Four - Rewrite Log statements!

At this point, you should be able to run the program, and it should populate the contextScopes slice. Now we need to use this information to rewrite the log calls, with a visitor that does the following:

  • Checks if the AST node is a function call
  • If it is a function call, checks if it is a call to log.Printf
  • If it is:
    • And a ctx is in scope, changes it to log.Printf(ctx, ...)
    • If no ctx is in scope, changes it to log.Printf(context.TODO(), ...)

The following code implements this behaviour:

  • visitForLogStatements is a visitor function that is called for every node. It checks whether the node is an *ast.CallExpr, which represents a function call. If it is, it calls checkHasLogStatement to determine whether the call is to the log.Printf function.

  • checkHasLogStatement uses type information from go/packages to determine the function being called. It first checks the name of the function and, if the name is "Printf", checks the package path; here we care about the "Printf" function from the "log" package. If the call is the one we are looking for, it returns the *ast.Ident of the called function and true.

  • Once checkHasLogStatement returns, visitForLogStatements checks the result and, if it is true, checks whether the call has a context.Context in scope. This is the final step in determining what the Printf call is rewritten to.

  • hasContextInScope iterates through the previously found scopes that have a context and checks whether the node's position falls within one of them. If it does, we can pass that context into the Printf call. If it does not, we have to create a context.TODO().

  • Finally, rewriteLog is called with the node and either a valid scope or nil. If the scope is nil, we create a new ast.CallExpr representing a call to context.TODO(). Otherwise, we reference the context parameter that we previously found and stored in our scope struct. In either case, we prepend the new AST node to the argument list and return the modified node.

func (r *Rewriter) visitForLogStatements(n ast.Node) (ast.Node, bool) {
    switch v := n.(type) {
    case *ast.CallExpr:
        if id, ok := r.checkHasLogStatement(v); ok {
            r.fileHasLogReWritten = true
            // scope is nil when no context.Context is in scope, which makes
            // rewriteLog fall back to context.TODO().
            scope, _ := r.hasContextInScope(id)
            return r.rewriteLog(v, scope), true
        }
    }
    return n, true
}

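func (r *Rewriter) checkHasLogStatement(n *ast.CallExpr) (*ast.Ident, bool) {
    s, ok := n.Fun.(*ast.SelectorExpr)
    if !ok || s.Sel.Name != "Printf" {
        return nil, false
    }

    id := s.Sel
    // Uses maps the Printf identifier to the function object it refers to.
    fn, ok := r.pkg.TypesInfo.Uses[id].(*types.Func)
    if !ok || fn.Pkg() == nil || fn.Pkg().Path() != "log" {
        return nil, false
    }
    // (*log.Logger).Printf is also declared in package "log"; only match the
    // package-level function, which has no receiver.
    if fn.Type().(*types.Signature).Recv() != nil {
        return nil, false
    }
    return id, true
}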

func (r *Rewriter) hasContextInScope(id *ast.Ident) (*scope, bool) {
    for _, s := range r.contextScopes {
        if s.s.Contains(id.Pos()) {
            return s, true
        }
    }
    return nil, false
}

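func (r *Rewriter) rewriteLog(n *ast.CallExpr, scope *scope) ast.Node {
    var ctxArg ast.Expr
    if scope != nil {
        // Use a fresh identifier rather than reusing the parameter's *ast.Ident:
        // its position points at the function signature, which can lead the
        // printer to insert unexpected line breaks in the argument list.
        ctxArg = ast.NewIdent(scope.i.Name)
    } else {
        // context.TODO()
        ctxArg = &ast.CallExpr{
            Fun: &ast.SelectorExpr{
                X:   ast.NewIdent("context"),
                Sel: ast.NewIdent("TODO"),
            },
        }
    }
    // Prepend the context argument to the existing arguments.
    n.Args = append([]ast.Expr{ctxArg}, n.Args...)
    return n
}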
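Matching on the package path alone is not quite enough, because a method's Pkg() is the package that declares it: (*log.Logger).Printf also reports "log", so a call like logger.Printf(...) would be rewritten too. The sketch below compares a path-only check with one that also requires the function to have no receiver. To stay self-contained it type-checks against a hypothetical stub of "log" rather than the real package; isLogPrintf and classify are illustrative names.

```go
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"go/types"
)

// stubImporter serves already type-checked packages from a map.
type stubImporter map[string]*types.Package

func (m stubImporter) Import(path string) (*types.Package, error) {
	if pkg, ok := m[path]; ok {
		return pkg, nil
	}
	return nil, fmt.Errorf("package %q not stubbed", path)
}

// logStub is a hypothetical stand-in for the real "log" package: a
// package-level Printf plus a Logger type with a Printf method.
const logStub = `package log

type Logger struct{}

func (l *Logger) Printf(format string, v ...interface{}) {}

func Printf(format string, v ...interface{}) {}
`

const loggerExample = `package example

import "log"

func f(l *log.Logger) {
	log.Printf("package-level function")
	l.Printf("method on *log.Logger")
}
`

// isLogPrintf reports whether obj is the package-level function log.Printf:
// declared in package "log", named Printf, and without a receiver.
func isLogPrintf(obj types.Object) bool {
	fn, ok := obj.(*types.Func)
	if !ok || fn.Pkg() == nil || fn.Pkg().Path() != "log" || fn.Name() != "Printf" {
		return false
	}
	return fn.Type().(*types.Signature).Recv() == nil
}

// classify reports, for each Printf call in loggerExample (in source order),
// whether a package-path-only check matches it and whether isLogPrintf does.
func classify() (pathOnly, strict []bool, err error) {
	fset := token.NewFileSet()
	stub, err := parser.ParseFile(fset, "log.go", logStub, 0)
	if err != nil {
		return nil, nil, err
	}
	logPkg, err := new(types.Config).Check("log", fset, []*ast.File{stub}, nil)
	if err != nil {
		return nil, nil, err
	}
	file, err := parser.ParseFile(fset, "example.go", loggerExample, 0)
	if err != nil {
		return nil, nil, err
	}
	info := &types.Info{Uses: map[*ast.Ident]types.Object{}}
	conf := types.Config{Importer: stubImporter{"log": logPkg}}
	if _, err := conf.Check("example", fset, []*ast.File{file}, info); err != nil {
		return nil, nil, err
	}
	ast.Inspect(file, func(n ast.Node) bool {
		if call, ok := n.(*ast.CallExpr); ok {
			if sel, ok := call.Fun.(*ast.SelectorExpr); ok && sel.Sel.Name == "Printf" {
				obj := info.Uses[sel.Sel]
				pathOnly = append(pathOnly, obj != nil && obj.Pkg() != nil && obj.Pkg().Path() == "log")
				strict = append(strict, isLogPrintf(obj))
			}
		}
		return true
	})
	return pathOnly, strict, nil
}

func main() {
	pathOnly, strict, err := classify()
	if err != nil {
		panic(err)
	}
	fmt.Println("package path only:  ", pathOnly)
	fmt.Println("with receiver check:", strict)
}
```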



... One more thing

Right now the code should compile, but if you try it on one of your own Go packages, it doesn't actually do anything! That is because we still need to call our visitForLogStatements visitor and write the new AST back to the input file. Luckily for us, this only requires some small modifications to our (r *Rewriter) Rewrite() method:

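func (r *Rewriter) Rewrite() error {
    for _, file := range r.pkg.Syntax {
        // The file set maps each AST back to the path it was parsed from.
        filename := r.pkg.Fset.File(file.Pos()).Name()

        astrewrite.Walk(file, r.visitForContextFuncs)
        r.fileHasLogReWritten = false
        // rewritten is the new AST
        rewritten := astrewrite.Walk(file, r.visitForLogStatements)
        if !r.fileHasLogReWritten {
            continue // nothing changed; leave the file untouched
        }

        // format is the "go/format" package: unlike printer.Fprint,
        // format.Node produces gofmt-style output.
        var buf bytes.Buffer
        if err := format.Node(&buf, r.pkg.Fset, rewritten); err != nil {
            return errors.Wrapf(err, "formatting %s", filename)
        }
        if err := os.WriteFile(filename, buf.Bytes(), 0644); err != nil {
            return errors.Wrapf(err, "writing %s", filename)
        }
    }
    return nil
}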
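One gap worth noting: rewriteLog can introduce context.TODO(), so a rewritten file that did not already import "context" will no longer compile, and the rewriter does not add that import. golang.org/x/tools/go/ast/astutil provides AddImport(fset, file, path) for this, which is likely the better fit for a tool that already depends on x/tools. As a dependency-free sketch of the same idea, the hypothetical ensureImport below adds a separate import declaration unless the file already imports the path without an alias.

```go
package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"strconv"
)

// ensureImport adds `import "<path>"` as a new declaration at the top of file
// unless the file already imports path without an alias (an aliased import
// would not make the plain identifier, e.g. context, available).
// It reports whether an import was added.
func ensureImport(file *ast.File, path string) bool {
	for _, imp := range file.Imports {
		p, err := strconv.Unquote(imp.Path.Value)
		if err == nil && p == path && imp.Name == nil {
			return false
		}
	}
	spec := &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(path)}}
	decl := &ast.GenDecl{Tok: token.IMPORT, Specs: []ast.Spec{spec}}
	file.Decls = append([]ast.Decl{decl}, file.Decls...)
	file.Imports = append(file.Imports, spec)
	return true
}

// addContextImport parses src, ensures it imports "context", and prints it back.
func addContextImport(src string) (string, bool, error) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "input.go", src, parser.ParseComments)
	if err != nil {
		return "", false, err
	}
	added := ensureImport(file, "context")
	var buf bytes.Buffer
	if err := format.Node(&buf, fset, file); err != nil {
		return "", false, err
	}
	return buf.String(), added, nil
}

func main() {
	out, added, err := addContextImport("package p\n\nimport \"log\"\n\nfunc f() { log.Println() }\n")
	if err != nil {
		panic(err)
	}
	fmt.Println("added:", added)
	fmt.Print(out)
}
```

Unlike astutil.AddImport, this sketch does not merge the new path into an existing parenthesized import block; it only guarantees the import is present.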

That's it! If you build and run the binary, you should have a Go module-aware AST rewriting tool! The full code can be found on GitHub.