Refactor rules

Using refactor rules to programmatically rewrite programs via AST

The refactor library

I've heard of the 'rope' library but it seems somewhat abandoned and does not cover this use case.

Python core developer Batuhan Taskaya has a library for "fragmental" AST refactors, named refactor (see docs for an overview).

After using it I'd describe the interface as a combination of autoflake and black (if you're familiar with those tools).

refactor is an end-to-end refactoring framework that is built on top of the 'simple but effective refactorings' assumption...

Every refactoring rule offers a single entrypoint, match(), where they accept an AST node (from the ast module in the standard library) and respond with either returning an action to refactor or nothing. If the rule succeeds on the input, then the returned action will build a replacement node and refactor will simply replace the code segment that belongs to the input with the new version.
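
As a minimal sketch of that interface (the rule and the toy 42 ⇒ 24 rewrite here are invented for illustration, but the imports and calls are the same ones used by the real rules later in this post):

import ast

from refactor import ReplacementAction, Rule, run


class FortyTwoToTwentyFour(Rule):
    def match(self, node):
        # The assertions act as preconditions: any failure means "no match" for this node
        assert isinstance(node, ast.Constant)
        assert node.value == 42
        # A successful match returns an action that knows how to build the replacement
        return ReplacementAction(node, ast.Constant(value=24))


if __name__ == "__main__":
    run(rules=[FortyTwoToTwentyFour])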

The most important module to understand is the core, which uses nice dataclass patterns to define the Actions that Rules return when run in Sessions:

The source code for an Action shows that it unparses the AST node returned by its build() method, then replaces a sliced view of the source lines with the replacement lines (an elegant solution that takes indentation and mid-line nodes into account).

@dataclass
class Action:
    """Base class for all actions.
    Override the `build()` method to programmatically build
    the replacement nodes.
    """

    node: ast.AST

    def apply(self, context: Context, source: str) -> str:
        """Refactor a source segment in the given string."""
        lines = split_lines(source)
        view = slice(self.node.lineno - 1, self.node.end_lineno)

        target_lines = lines[view]
        indentation, start_prefix = find_indent(
            target_lines[0][: self.node.col_offset]
        )
        end_prefix = target_lines[-1][self.node.end_col_offset :]

        replacement = split_lines(context.unparse(self.build()))
        replacement.apply_indentation(
            indentation, start_prefix=start_prefix, end_suffix=end_prefix
        )

        lines[view] = replacement
        return lines.join()

    def build(self) -> ast.AST:
        """Create the replacement node."""
        raise NotImplementedError

    def branch(self) -> ast.AST:
        """Return a copy view of the original node."""
        return copy.deepcopy(self.node)
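
To make the line-slicing above concrete, these are the location attributes that apply() reads off a node which starts partway through a line (a quick check in plain ast, nothing refactor-specific):

import ast

tree = ast.parse("x = 1 if flag else 2\n")
ternary = tree.body[0].value  # the ast.IfExp node, which begins mid-line

# 1-based line numbers and 0-based column offsets delimit the exact source segment
print(ternary.lineno, ternary.col_offset)          # 1 4
print(ternary.end_lineno, ternary.end_col_offset)  # 1 20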

The core module defines two subclasses of Action, ReplacementAction and NewStatementAction (plus TargetedNewStatementAction, which combines them).

These don't quite fit the use cases of the rules sketched out in the previous section:

  1. An ast.NamedExpr walrus assignment ⇒ a preceding ast.Assign.
  2. An ast.Assign from an ast.IfExp ternary value ⇒ an ast.If if/else statement with 2 ast.Assign statements (in body and orelse).
  3. Unnamed ast.Return value + name ⇒ ast.Return by id (named value)
  4. An ast.Compare if block with an ast.Return in the body ⇒ an ast.orelse clause for everything that follows.

The 1st rule is the reverse of NewStatementAction, as it involves prepending an AST node.

The 2nd rule seems most straightforward to define as 'replacement', though the target in ReplacementAction must be built from the source node, not an invariant/predefined node (we can do this in the Rule).

Likewise, the 4th is not a straight exchange (as "the rest of the function body" is not a single node target but a sequence of one or more). Again, we can define this in the Rule.

The 1st and 4th seem almost symmetrical (one pulls a node out of the expression so it moves 'above', the other pulls a node into the expression so it moves from 'below').

However, reviewing the source code for these actions gives an idea of how to go about all of them (except the 3rd, which is more of a comparison — I'll return to that as the final piece of the puzzle).

@dataclass
class ReplacementAction(Action):
    """An action for replacing the `node` with
    the given `target` node."""

    node: ast.AST
    target: ast.AST

    def build(self) -> ast.AST:
        return self.target


class NewStatementAction(Action):
    """An action base for adding a new statement right after
    the given `node`."""

    def apply(self, context: Context, source: str) -> str:
        lines = split_lines(source)

        start_line = lines[self.node.lineno - 1]
        indentation, start_prefix = find_indent(
            start_line[: self.node.col_offset]
        )

        replacement = split_lines(context.unparse(self.build()))
        replacement.apply_indentation(indentation, start_prefix=start_prefix)

        end_line = cast(int, self.node.end_lineno)
        for line in reversed(replacement):
            lines.insert(end_line, line)

        return lines.join()


class TargetedNewStatementAction(ReplacementAction, NewStatementAction):
    """An action for appending the given `target` node
    right after the `node`."""

The ReplacementAction build()s its target, while the NewStatementAction redefines apply() to insert() the output of build() after the source node's lines, instead of substituting the source node's lines view for the replacement from build(). The TargetedNewStatementAction combines these two components using multiple inheritance.

refactor.Action in action

Since the 2nd rule looks most straightforward, let's start by writing a refactor.Rule subclass that gives us the desired behaviour it describes:

  2. An ast.Assign from an ast.IfExp ternary value ⇒ an ast.If if/else statement with 2 ast.Assign statements (in body and orelse).

To test this rule out on its own I'll write the simplest possible test case, which means (unfortunately) yet another version of our function, with only the ternary assignment (as if rule 1 had already been done on the ternary handler). We'll build it up with the other parts later.

def ternary_assigned_handler(event: dict) -> int | None:
    payload = event.get("payload")
    output = None if payload is None else len(payload)
    return output

As before, we can pretty-print its simplified form to aid debugging.

The pprinted node this rule targets:
{'targets': [{'id': 'output', '__type__': 'Name'}],
 'value': {'test': {'left': {'id': 'payload', '__type__': 'Name'},
                    'ops': [{'__type__': 'Is'}],
                    'comparators': [{'value': None,
                                     'kind': None,
                                     '__type__': 'Constant'}],
                    '__type__': 'Compare'},
           'body': {'value': None, 'kind': None, '__type__': 'Constant'},
           'orelse': {'func': {'id': 'len', '__type__': 'Name'},
                      'args': [{'id': 'payload', '__type__': 'Name'}],
                      '__type__': 'Call'},
           '__type__': 'IfExp'},
 '__type__': 'Assign'}
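
For reference, the simplify() helper came from the previous section; a rough sketch of the idea is below (the real version's filtering of noisy fields may differ slightly from the hard-coded skip list here):

import ast
from pprint import pprint


def simplify(node):
    # Dump an AST node as nested dicts of its fields plus a __type__ tag,
    # skipping the noisiest fields to keep the output readable
    if isinstance(node, ast.AST):
        out = {
            field: simplify(value)
            for field, value in ast.iter_fields(node)
            if field not in ("ctx", "type_comment", "keywords")
        }
        out["__type__"] = type(node).__name__
        return out
    if isinstance(node, list):
        return [simplify(item) for item in node]
    return node


tree = ast.parse("output = None if payload is None else len(payload)")
pprint(simplify(tree.body[0]), sort_dicts=False)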

After a couple false starts (I actually thought I'd have to abandon this entire exercise!) it turned out to be much simpler than I expected.

import ast

from refactor import ReplacementAction, Rule, run


class ReplaceTernaryAssignment(Rule):
    def match(self, node):
        assert isinstance(node, ast.Assign)
        assert isinstance(node.value, ast.IfExp)
        replacement = ast.If(
            test=ast.Compare(
                left=node.value.test.left,
                ops=node.value.test.ops,
                comparators=node.value.test.comparators,
            ),
            body=[
                ast.Assign(
                    targets=node.targets,
                    value=node.value.body,
                    lineno=None,
                )
            ],
            orelse=[
                ast.Assign(
                    targets=node.targets,
                    value=node.value.orelse,
                    lineno=None,
                )
            ],
        )
        return ReplacementAction(node, replacement)


if __name__ == "__main__":
    run(rules=[ReplaceTernaryAssignment])

Running this script on the command line (from the ast-rewriting working directory):

python refactor_rule_2.py ../early-returns/event_handler/ternary_assigned_handler.py

You get a nice readable diff that confirms the result was as expected: we have our rule 2!

--- ../early-returns/event_handler/ternary_assigned_handler.py

+++ ../early-returns/event_handler/ternary_assigned_handler.py

@@ -3,5 +3,8 @@


 def ternary_assigned_handler(event: dict) -> int | None:
     payload = event.get("payload")
-    output = None if payload is None else len(payload)
+    if payload is None:
+        output = None
+    else:
+        output = len(payload)
     return output
All done!
1 file reformatted

Note: the file wasn't actually changed by this, it's a preview of what would happen. This is because the refactor.runner parser's --apply flag defaults to False.

Passing -a/--apply on the command line will rewrite the file (and not print out the diff):

cp ../early-returns/event_handler/ternary_assigned_handler.py ./linearised_handler.py
python refactor_rule_2.py -a linearised_handler.py

reformatted linearised_handler.py
All done!
1 file reformatted

For clarity I renamed the function to match its module name, linearised_handler:

from __future__ import annotations


def linearised_handler(event: dict) -> int | None:
    payload = event.get("payload")
    if payload is None:
        output = None
    else:
        output = len(payload)
    return output

It's clear here that if we can generate the diff for a refactor without actually overwriting the file, then we can generate as many consecutive diffs as we want to 'preview' a series of refactors, and ultimately compare their collective outcome against the refactor we did by hand to verify its correctness.

That said, it's not so clear why you wouldn't just automate refactors if you established a sufficiently confident routine for verifying their correctness.


Since we did the 2nd rule, it makes sense to do the 1st rule next so that we have half our transformation completed.

  1. An ast.NamedExpr walrus assignment ⇒ a preceding ast.Assign.

This rule pulls out the walrus assignment in the ternary_handler function:

def ternary_handler(event: dict) -> int | None:
    return None if (payload := event.get("payload")) is None else len(payload)

into a separate assignment, which we'll call ternary_walrusless_handler.

def ternary_walrusless_handler(event: dict) -> int | None:
    payload = event.get("payload")
    return None if payload is None else len(payload)

Again, we can handle all of the logic in a custom Rule and pass the appropriate parts into the ReplacementAction without needing to write our own action.

It's worth noting here that the replacement is a list of two AST nodes, and that rules are applied repeatedly until no more nodes can be found that pass the assertions in match().

The 3rd assertion says that the left side of the ternary condition in the return statement must include a 'walrus' assignment, or named expression. If this 3rd assertion is removed, then the rule will match the node both before and after the replacements are applied.
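
To see what that 3rd assertion guards against, compare test.left in the return statement before and after the rewrite (a quick check against the two function bodies above):

import ast

before = 'def f(event):\n    return None if (payload := event.get("payload")) is None else len(payload)'
after = 'def f(event):\n    payload = event.get("payload")\n    return None if payload is None else len(payload)'

ret_before = ast.parse(before).body[0].body[-1]
ret_after = ast.parse(after).body[0].body[-1]

print(type(ret_before.value.test.left).__name__)  # NamedExpr: the rule matches
print(type(ret_after.value.test.left).__name__)   # Name: the 3rd assertion fails, so matching stops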

import ast

from refactor import ReplacementAction, Rule, run


class ReplaceTernaryWalrusAssignment(Rule):
    def match(self, node):
        assert isinstance(node, ast.Return)
        assert isinstance(node.value, ast.IfExp)
        assert isinstance(node.value.test.left, ast.NamedExpr)
        replacement = [
            ast.Assign(
                targets=[node.value.test.left.target],
                value=node.value.test.left.value,
                lineno=None,
            ),
            ast.Return(
                value=ast.IfExp(
                    test=ast.Compare(
                        left=node.value.test.left.target,
                        ops=node.value.test.ops,
                        comparators=node.value.test.comparators,
                    ),
                    body=node.value.body,
                    orelse=node.value.orelse,
                )
            ),
        ]
        return ReplacementAction(node, replacement)


if __name__ == "__main__":
    run(rules=[ReplaceTernaryWalrusAssignment])

et voilà, rule 1 is complete:

--- ../early-returns/event_handler/ternary_handler.py

+++ ../early-returns/event_handler/ternary_handler.py

@@ -2,4 +2,5 @@



 def ternary_handler(event: dict) -> int | None:
-    return None if (payload := event.get("payload")) is None else len(payload)
+    payload = event.get("payload")
+    return None if payload is None else len(payload)
All done!
1 file reformatted

This is almost the source for the ternary_assigned_handler function, except we have an anonymous return value instead of one named output.


Since we had to use two replacement nodes here, our intuition to write rule 2 (with just one replacement node) first was a good choice!

The next logical step would be to proceed to rule 3, but this was the awkward not-really-a-rule rule...

  3. Unnamed ast.Return value + name ⇒ ast.Return by id (named value)

So if rule 1 took the ternary_handler function into the ternary_walrusless_handler function by pulling the assignment out of the named expression in the return value, and rule 2 expanded the remaining ternary condition in ternary_assigned_handler into an if/else block (which we then called linearised_handler), then rule 3 takes a function with an unnamed return value and gives that value a name.

Awkwardly, I implemented this as the bridge between steps 1 and 2, so we should really rename rule 3 as rule 1.5... That would get too messy, so we'll leave it as is for now: this is a much simpler rule to write, and we can integrate the 4 rules together later.

We can specify the convention that the default name for a return value produced from an unnamed return value will be output. Since this rule should operate after rule 1, we can use it on the output of that, which was ternary_walrusless_handler.

def ternary_walrusless_handler(event: dict) -> int | None:
    payload = event.get("payload")
    return None if payload is None else len(payload)

This rule is much simpler: the return statement becomes an assignment (to the default name ID), and the return value becomes that name.

import ast

from refactor import ReplacementAction, Rule, run


class ReplaceAnonRetVal(Rule):
    def match(self, node):
        assert isinstance(node, ast.Return)
        assert isinstance(node.value, ast.IfExp)
        replacement = [
            ast.Assign(
                targets=[ast.Name(id="output")],
                value=node.value,
                lineno=None,
            ),
            ast.Return(value=ast.Name(id="output")),
        ]
        return ReplacementAction(node, replacement)


if __name__ == "__main__":
    run(rules=[ReplaceAnonRetVal])

1.5'th time's the charm: the output matches the function we referred to as ternary_assigned_handler (as the input to rule 2).

python refactor_rule_3.py ../early-returns/event_handler/ternary_walrusless_handler.py 

--- ../early-returns/event_handler/ternary_walrusless_handler.py

+++ ../early-returns/event_handler/ternary_walrusless_handler.py

@@ -3,4 +3,5 @@


 def ternary_walrusless_handler(event: dict) -> int | None:
     payload = event.get("payload")
-    return None if payload is None else len(payload)
+    output = None if payload is None else len(payload)
+    return output
All done!
1 file reformatted

Note that by this point I'd got the hang of the fastest way to write these rules: rapidly flicking between ast.parse() on the code string you're trying to target, ast.unparse() on the AST node or tree you're working with, and help() on the ast class you're trying to instantiate. This workflow brought the time to write a rule down by more than I care to admit!
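
One round of that loop, for the ast.Assign node built in this rule, looks something like this (the exact calls vary with whatever node you're trying to build):

import ast

# Parse the code you want to end up with, dump it to see the constructor-style kwargs,
# then check the signature of the node class you need to instantiate
tree = ast.parse("output = None if payload is None else len(payload)")
print(ast.dump(tree.body[0], indent=2))  # the target shape, field by field
print(ast.unparse(tree.body[0]))         # round-trips back to the original line
help(ast.Assign)                         # which fields does ast.Assign take?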

That just leaves us with the final rule, rule 4, which must operate on the output of rule 2 (since we mislabelled rule 3, which should really be 1.5), the linearised_handler function:

def linearised_handler(event: dict) -> int | None:
    payload = event.get("payload")
    if payload is None:
        output = None
    else:
        output = len(payload)
    return output

The final rule to implement is:

  4. An ast.Compare if block with an ast.Return in the body ⇒ an ast.orelse clause for everything that follows.

...but since we moved rule 3 backwards accidentally, it turns out we avoided the need for this!

Rewriting early returns with 3 refactor.Rules

We're now ready to implement a refactor.Rule routine that transforms the guarded handler into a linear handler, via an intermediate ternary form. We've shown that all are equivalent, and we need just 3 rules to convert them programmatically (in doing so verifying their equivalence).

It'd be nice to just chain them one after the other, and we've already seen that they act recursively, so we should be able to set them up in a chain reaction like a chemical process.

The numbers weren't much help to us anyway, and since our rules run in series, we can just use the list argument passed to run() to order them as:

[ReplaceTernaryWalrusAssignment, ReplaceAnonRetVal, ReplaceTernaryAssignment]

Our full program to refactor all in one go is therefore:

import ast

from refactor import ReplacementAction, Rule, run


class ReplaceTernaryWalrusAssignment(Rule):
    def match(self, node):
        assert isinstance(node, ast.Return)
        assert isinstance(node.value, ast.IfExp)
        assert isinstance(node.value.test.left, ast.NamedExpr)
        replacement = [
            ast.Assign(
                targets=[node.value.test.left.target],
                value=node.value.test.left.value,
                lineno=None,
            ),
            ast.Return(
                value=ast.IfExp(
                    test=ast.Compare(
                        left=node.value.test.left.target,
                        ops=node.value.test.ops,
                        comparators=node.value.test.comparators,
                    ),
                    body=node.value.body,
                    orelse=node.value.orelse,
                )
            ),
        ]
        return ReplacementAction(node, replacement)


class ReplaceAnonRetVal(Rule):
    def match(self, node):
        assert isinstance(node, ast.Return)
        assert isinstance(node.value, ast.IfExp)
        replacement = [
            ast.Assign(
                targets=[ast.Name(id="output")],
                value=node.value,
                lineno=None,
            ),
            ast.Return(value=ast.Name(id="output")),
        ]
        return ReplacementAction(node, replacement)


class ReplaceTernaryAssignment(Rule):
    def match(self, node):
        assert isinstance(node, ast.Assign)
        assert isinstance(node.value, ast.IfExp)
        replacement = ast.If(
            test=ast.Compare(
                left=node.value.test.left,
                ops=node.value.test.ops,
                comparators=node.value.test.comparators,
            ),
            body=[
                ast.Assign(
                    targets=node.targets,
                    value=node.value.body,
                    lineno=None,
                )
            ],
            orelse=[
                ast.Assign(
                    targets=node.targets,
                    value=node.value.orelse,
                    lineno=None,
                )
            ],
        )
        return ReplacementAction(node, replacement)


if __name__ == "__main__":
    run(rules=[ReplaceTernaryWalrusAssignment, ReplaceAnonRetVal, ReplaceTernaryAssignment])

...and when we run them:

python refactor_rules_all.py ../early-returns/event_handler/ternary_handler.py

--- ../early-returns/event_handler/ternary_handler.py

+++ ../early-returns/event_handler/ternary_handler.py

@@ -2,4 +2,9 @@



 def ternary_handler(event: dict) -> int | None:
-    return None if (payload := event.get("payload")) is None else len(payload)
+    payload = event.get("payload")
+    if payload is None:
+        output = None
+    else:
+        output = len(payload)
+    return output
All done!
1 file reformatted

We're done!

A schema for refactor rules

When reviewing the result here, and considering how unintuitive it was to eyeball each step, I realised that the simplify() function from the previous section on AST rewriting, which we used to 'pretty print' the ASTs, actually had a less ornamental purpose.

There's something crucial to how lineno=None was passed to the calls to ast.Assign in the rules we wrote in this section.
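
A minimal demonstration of why that kwarg is needed at all (assuming CPython 3.9–3.12 behaviour, where ast.unparse looks up node.lineno on an Assign while deciding whether to emit a type comment):

import ast

synth = ast.Assign(targets=[ast.Name(id="x")], value=ast.Constant(value=1))
try:
    ast.unparse(synth)
except AttributeError as err:
    print(err)               # complains about the missing lineno attribute

synth.lineno = None          # equivalent to passing lineno=None in the constructor
print(ast.unparse(synth))    # x = 1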

If I were to summarise what I think this crucial idea is: the proper representation for AST comparisons ought to be nested dicts of minimal kwargs per node.
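
Concretely, two structurally identical nodes never compare equal as ast objects, but they do compare equal structurally (the last line reuses the simplify() sketch from above, so that part is an assumption about the helper rather than a property of the real one):

import ast

a = ast.parse("x = 1").body[0]
b = ast.parse("x = 1").body[0]

print(a == b)                      # False: ast nodes fall back to identity comparison
print(ast.dump(a) == ast.dump(b))  # True: structural dumps compare by value
print(simplify(a) == simplify(b))  # True with the simplify() sketch above, too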

On the basis of this realisation I would suggest that a much better way to write these AST rules is to translate the rules above out of Python code and into a JSON schema (more specifically a recursive schema, since AST rules may specify arbitrarily deep replacements etc.).

For example, rule 2 changed x = "foo" if y is None else "bar" into

if y is None:
    x = "foo"
else:
    x = "bar"

and we called this the ReplaceTernaryAssignment rule:

class ReplaceTernaryAssignment(Rule):
    def match(self, node):
        assert isinstance(node, ast.Assign)
        assert isinstance(node.value, ast.IfExp)
        replacement = ast.If(
            test=ast.Compare(
                left=node.value.test.left,
                ops=node.value.test.ops,
                comparators=node.value.test.comparators,
            ),
            body=[
                ast.Assign(
                    targets=node.targets,
                    value=node.value.body,
                    lineno=None,
                )
            ],
            orelse=[
                ast.Assign(
                    targets=node.targets,
                    value=node.value.orelse,
                    lineno=None,
                )
            ],
        )
        return ReplacementAction(node, replacement)

The structure of this rule is two parts: the preconditions (the isinstance assertions at the top of match()) and the replacement node that is built and handed to the action.

You can of course have other actions, but as shown above the ReplacementAction was sufficient to implement the entire pattern in this case.

This could be rewritten as a JSON schema. First, a small helper to resolve a dotted attribute path relative to a node:

import ast
import functools


def rel(qualname: str, node: type[ast.stmt]):
    # Walk a dotted attribute path (e.g. "value.test.left") relative to the given node
    return functools.reduce(getattr, qualname.split("."), node)
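
Although the annotation above is written for the precondition case (where the second argument is an AST class), the same reduce-over-getattr walk resolves paths against a matched node too; for example, against the ast.Assign that rule 2 matches (continuing from the definition above):

node = ast.parse("output = None if payload is None else len(payload)").body[0]

print(ast.unparse(rel("value.test.left", node)))  # payload
print(type(rel("value", node)).__name__)          # IfExp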

Since a function call cannot be a JSON key, the precondition of rel(qualname="", node=ast.Assign) is encoded as a dict {"rel": {"": {"ast": "Assign"}}} with rel, reltype and ast becoming reserved words in the schema (i.e. we define namespaces on them when used as keys).

{
    "preconditions": {
        "reltype": {
            "": "ast.Assign",
            "value": "ast.IfExp",
        }
    },
    "replacement": {
         "ast.If": {
             "test": {
                 "ast.Compare": {
                     "left": "value.test.left",
                     "ops": "value.test.ops",
                     "comparators": "value.test.comparators",
                 }
             },
             "body": [
                 {
                     "ast.Assign": {
                         "targets": "targets",
                         "value": "value.body",
                     }
                 }
             ],
             "orelse": [
                 {
                     "ast.Assign": {
                         "targets": "targets",
                         "value": "value.orelse",
                     }
                 }
             ],
        }
    }
}

Once written out this way there's immediately a clearer separation of concerns, and the self-similarity stands out more clearly (so much so you could introduce shorthand formats).

{
    "preconditions": {
        "reltype": {
            "": "ast.Assign",
            "value": "ast.IfExp",
        }
    },
    "replacement": {
         "ast.If": {
             "test": {
                 "ast.Compare": {
                     "left": "value.@2.@1",
                     "ops": "value.@2.@1",
                     "comparators": "value.@2.@1",
                 }
             },
             "body": [
                 {
                     "ast.Assign": {
                         "targets": "@1",
                         "value": "@1.@2",
                     }
                 }
             ],
             "orelse": [
                 {
                     "ast.Assign": {
                         "targets": "@1",
                         "value": "@1.@2",
                     }
                 }
             ],
        }
    }
}

which once deserialised from JSON would store the Python standard library objects and function calls:

{
    "preconditions": {
        "reltype": {
            "": ast.Assign,
            "value": ast.IfExp,
        }
    },
    "replacement": {
         ast.If: {
             "test": {
                 ast.Compare: {
                     "left": rel("value.@2.@1"),
                     "ops": rel("value.@2.@1"),
                     "comparators": rel("value.@2.@1"),
                 }
             },
             "body": [
                 {
                     ast.Assign: {
                         "targets": rel("@1"),
                         "value": rel("@1.@2"),
                     }
                 }
             ],
             "orelse": [
                 {
                     ast.Assign: {
                         "targets": rel("@1"),
                         "value": rel("@1.@2"),
                     }
                 }
             ],
        }
    }
}
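
Deserialising into that form only needs the reserved ast namespace to be resolved against the standard library module; a hypothetical helper for that step (the name resolve is mine, not part of refactory) could be:

import ast


def resolve(name: str) -> type[ast.AST]:
    # Map a schema string like "ast.Assign" onto the class in the stdlib ast module
    namespace, _, clsname = name.partition(".")
    if namespace != "ast":
        raise ValueError(f"unknown namespace in {name!r}")
    return getattr(ast, clsname)


assert resolve("ast.Assign") is ast.Assign
assert resolve("ast.IfExp") is ast.IfExp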

For an even more compact representation you could alias the rel calls in the ast.Compare node and the entire ast.Assign node:

{
    "aliases": {
        "!1": "value.@2.@1",
        "!2": [
            {
                "ast.Assign": {
                    "targets": "@1",
                    "value": "@1.@2",
               }
            }
        ]
    },
    "preconditions": {
        "reltype": {
            "": "ast.Assign",
            "value": "ast.IfExp",
        }
    },
    "replacement": {
         "ast.If": {
             "test": {
                 "ast.Compare": {
                     "left": "!1",
                     "ops": "!1",
                     "comparators": "!1",
                 }
             },
             "body": "!2",
             "orelse": "!2",
        }
    }
}

A representation like this could be much easier to read and write (depending on the level at which it's exposed to a user, of course), and more compact. The aliasing ought to be automated away.

The namespace representation of the AST classes and relative paths to the target node rely on the assumptions that:

  1. anything under the ast namespace refers to a class in the standard library ast module, and
  2. any other dotted string is a relative attribute path, resolved against the node being matched (as rel does).

I'd also use the $ symbol to 'switch off' the interpolation expected from rel in these qualnames; that way the DSL could be open to use with actual string literals in future (for example, to introduce a variable name you'd pass a string literal for the name ID, written as ${myvariablename}, since of course a qualname cannot contain $). The reverse would not be workable, as of course a string literal can contain $.
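
A hypothetical reader for that convention (both the function and its name are mine, purely to illustrate the ${...} escape) might look like:

def parse_value(raw: str) -> tuple[str, str]:
    # ${...} switches off qualname interpolation: the contents are a plain string literal
    if raw.startswith("${") and raw.endswith("}"):
        return ("literal", raw[2:-1])
    return ("qualname", raw)


print(parse_value("value.test.left"))  # ('qualname', 'value.test.left')
print(parse_value("${output}"))        # ('literal', 'output')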

With this clarified, we can translate the ReplaceAnonRetVal rule, which notably contains the string literal "output" which is introduced as the variable name for the previously anonymous return value:

class ReplaceAnonRetVal(Rule):
    def match(self, node):
        assert isinstance(node, ast.Return)
        assert isinstance(node.value, ast.IfExp)
        replacement = [
            ast.Assign(
                targets=[ast.Name(id="output")],
                value=node.value,
                lineno=None,
            ),
            ast.Return(value=ast.Name(id="output")),
        ]
        return ReplacementAction(node, replacement)

into:

{
    "aliases": {
        "!1": {"ast.Name": {"id": "${output}"}},
        "!2": "ast.Return",
    },
    "preconditions": {
        "reltype": {
            "": "!2",
            "value": "ast.IfExp",
        }
    },
    "replacement": [
        {
            "ast.Assign": {
                "targets": ["!1"],
                "value": "@1",
            }
        },
        {
            "!2": {
                "value": "!1",
            }
        },
    ]
}

This is an undeniably concise representation of the pattern, and more importantly it shows where information is shared within the pattern (in this case the aliases are the 'key information').

We can type-annotate this schema and then validate it; the code for this I've put in a separate library, refactory.

A few further requirements were clarified while writing it, and for the time being the depth of a replacement was limited to 2 AST nodes deep (as my goal in writing it was to illustrate this example case, and it turned out recursive types don't seem to work with pydantic validation).

Somewhat inconsistently perhaps, I made the aliases arbitrarily deep (meaning a replacement can be arbitrarily deep too by using an alias).

>>> import refactory
>>> rule_spec = refactory.patterns.early_returns.ReplaceAnonRetVal
>>> rs = refactory.load_spec(rule_spec)
>>> rs.aliases
{!1: {astName: {id**: `${output}`}}, !2: astReturn}
>>> rs.preconditions
Preconditions(reltype={:: !2, :value: astIfExp})
>>> rs.replacement
[{astAssign: {targets**: [!1], value**: @1}}, {!2: {value**: !1}}]

Alternatively:

>>> import refactory
>>> rule_spec = refactory.patterns.early_returns.ReplaceAnonRetVal
>>> rs = refactory.load_spec(rule_spec)
>>> rs.aliases
{?1: {ast.Name: {id**: "output"}}, ?2: ast.Return}
>>> rs.preconditions
Preconditions(reltype={:: ?2, :value: ast.IfExp})
>>> rs.replacement
[{ast.Assign: {targets**: [?1], value**: @1}}, {?2: {value**: ?1}}]

This post is the 4th of a series on Refactor verification, investigating how to verify the correctness of refactors (or automating the human error away). In the next section, we'll leave the cosy embrace of these toy examples and venture off to the real world with some case studies.