
Warning

This page is under construction. The content may be incomplete or incorrect. Submit an issue on GitHub if you need help or want to contribute.

SCF Dialects

The structured control flow (SCF) dialect is adopted from the MLIR project, with modifications to better fit Python semantics. This page explains the semantics of the SCF dialect's statements and how they are used.

scf.Yield

The scf.Yield statement marks the end of a block and yields values to the parent statement of the enclosing region. For example, with the scf.if statement:

%value_1 = scf.if %cond {
    // body
    scf.yield %value
} else {
    // body
    scf.yield %value
}

scf.Yield specifies that %value is returned to the parent statement as its result. Unlike in MLIR, the bodies of most statements in the Kirin scf dialect can also be terminated by a func.Return statement, which makes lowering from Python easier.
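The two terminators behave differently. The sketch below shows this with a toy model in plain Python. The Yield, Return, run_region and call_function names are hypothetical stand-ins, not Kirin's API.

- An scf.Yield hands its values back to the enclosing statement, and execution continues.
- A func.Return leaves the whole function.

```python
from dataclasses import dataclass


@dataclass
class Yield:
    """Toy stand-in for scf.Yield: its values go back to the parent statement."""
    values: tuple


@dataclass
class Return:
    """Toy stand-in for func.Return: its value leaves the enclosing function."""
    value: object


class FunctionReturn(Exception):
    """Unwinds out of nested regions when a Return terminator is reached."""

    def __init__(self, value):
        super().__init__(value)
        self.value = value


def run_region(terminator):
    """Finish a region: hand yielded values to the parent, or exit the function."""
    if isinstance(terminator, Yield):
        return terminator.values
    if isinstance(terminator, Return):
        raise FunctionReturn(terminator.value)
    raise TypeError("region must end with Yield or Return")


def call_function(terminator):
    """Toy function whose body is one statement with a single region."""
    try:
        results = run_region(terminator)
    except FunctionReturn as ret:
        return ("function returned", ret.value)
    # The parent statement received the yielded values and execution continues.
    return ("parent received", results)


print(call_function(Yield((1, 2))))  # ('parent received', (1, 2))
print(call_function(Return(42)))     # ('function returned', 42)
```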

scf.If

The scf.If statement is used to conditionally execute a block of code. It is used in the following way:

scf.if %cond {
    // body
} else {
    // body
}

Definition The scf.If statement takes a cond argument, a then_body region with a single block, and optionally an else_body region with a single block. The then_body block is executed if the condition is true, and the else_body block is executed if it is false. Each block has a single block argument for the condition.

Termination then_body must terminate with an scf.Yield or func.Return statement. else_body is optional and can be omitted. If either body terminates with scf.Yield, the other body must also terminate explicitly with scf.Yield or func.Return.
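To make the branch semantics concrete, here is a toy evaluator in plain Python; it is not Kirin code. It runs exactly one branch and passes the condition as that block's single argument. It also assumes that an omitted else branch behaves like one that yields nothing. That assumption is consistent with IfElse.print_impl in the reference below, which hides an else block containing only an empty yield. The names eval_if, empty_else and ToyBlock are hypothetical.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Union


@dataclass
class Yield:
    """Toy stand-in for scf.Yield."""
    values: tuple = ()


@dataclass
class Return:
    """Toy stand-in for func.Return."""
    value: object


Terminator = Union[Yield, Return]
# Toy block: called with its single block argument (the condition)
# and returns the terminator it ends with.
ToyBlock = Callable[[bool], Terminator]


def empty_else(cond: bool) -> Terminator:
    # Assumed stand-in for an omitted else branch: it yields nothing.
    return Yield()


def eval_if(
    cond: bool, then_body: ToyBlock, else_body: Optional[ToyBlock] = None
) -> Terminator:
    """Run exactly one branch of a toy scf.if and return its terminator."""
    body = then_body if cond else (else_body or empty_else)
    term = body(cond)
    if not isinstance(term, (Yield, Return)):
        raise TypeError("branch must terminate with Yield or Return")
    return term


# %value_1 = scf.if %cond { scf.yield 1 } else { scf.yield 2 }
print(eval_if(True, lambda c: Yield((1,)), lambda c: Yield((2,))))   # Yield(values=(1,))
print(eval_if(False, lambda c: Yield((1,)), lambda c: Yield((2,))))  # Yield(values=(2,))
# No else branch and a false condition: nothing is yielded.
print(eval_if(False, lambda c: Yield()))  # Yield(values=())
```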

scf.For

The scf.For statement is used to iterate over a range of values. It is used in the following way:

def simple_loop():
    j = 0.0
    for i in range(10):
        j = j + i
    return j

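After lowering, this becomes the following IR. Note that the IR has also been constant-folded. The function returns the precomputed constant %j_1 = 45.0 (the value of 0.0 + 0 + 1 + ... + 9) rather than the loop result %j_2: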

func.func simple_loop() -> !Any {
  ^0(%simple_loop_self):
     %j = py.constant.constant 0.0
     %0 = py.constant.constant IList(range(0, 10))
   %j_1 = py.constant.constant 45.0
   %j_2 = scf.for %i in %0
           iter_args(%j_3 = %j) {
           %j_4 = py.binop.add(%j_3, %i)
                  scf.yield %j_4
          }
          func.return %j_1
} // func.func simple_loop
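One way to read the iter_args form is to desugar the loop by hand. The sketch below is plain Python written to mirror the IR line by line; it is not generated by Kirin.

- The loop-carried value %j_3 starts as the initializer %j.
- Each scf.yield rebinds it for the next iteration.
- The last yielded value becomes the loop result %j_2.

```python
def simple_loop_desugared() -> float:
    j = 0.0                        # %j = py.constant.constant 0.0
    iterable = list(range(0, 10))  # %0 = py.constant.constant IList(range(0, 10))
    j_3 = j                        # iter_args(%j_3 = %j)
    for i in iterable:             # scf.for %i in %0
        j_4 = j_3 + i              # %j_4 = py.binop.add(%j_3, %i)
        j_3 = j_4                  # scf.yield %j_4 feeds the next iteration
    j_2 = j_3                      # the loop result %j_2 is the last yielded value
    return j_2


print(simple_loop_desugared())  # 45.0
```

If the iterable were empty, j_2 would simply equal the initializer. The result 45.0 matches the constant %j_1 that appears in the IR above.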

Definition The scf.For statement takes an iterable as an argument.

Note

Unlike MLIR, where the loop iterable is restricted to a step range, Kirin allows any Python iterable object to be used as the loop iterable by typing this argument as ir.types.Any. Although any Python iterable is accepted, the loop can only be compiled if the iterable's type is known and supported by the compiler implementation.

scf.For can also take an optional tuple of initializer values that initialize the loop-carried variables (printed as the right-hand sides in the iter_args field).

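Termination The loop body must terminate with an scf.Yield or func.Return statement. When it ends with scf.Yield, the yield must provide exactly one value per initializer, and these values become the results of the loop.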

Scoping The loop body creates a new scope. As a result, any variable defined inside the loop body is not accessible outside it unless it is explicitly yielded.
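The definition, termination and scoping rules above can be combined into one toy model. In this plain-Python sketch, the body is a callable that receives the loop variable followed by the loop-carried values. That matches the reference's requirement of one block argument per initializer plus the loop variable.

- A Yield must carry one value per initializer, and those values feed the next iteration.
- A Return leaves the enclosing function.
- Only the carried values survive the loop.

The names scf_for, EarlyReturn and call_function are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterable, Union


@dataclass
class Yield:
    """Toy stand-in for scf.Yield."""
    values: tuple


@dataclass
class Return:
    """Toy stand-in for func.Return."""
    value: Any


class EarlyReturn(Exception):
    """Carries a Return value out of the loop and the enclosing function."""

    def __init__(self, value: Any):
        super().__init__(value)
        self.value = value


def scf_for(
    iterable: Iterable,
    body: Callable[..., Union[Yield, Return]],
    *initializers: Any,
) -> tuple:
    """Toy scf.for: body(item, *carried) plays the role of the single body block."""
    carried = tuple(initializers)
    for item in iterable:
        term = body(item, *carried)
        if isinstance(term, Return):
            raise EarlyReturn(term.value)
        if not isinstance(term, Yield):
            raise TypeError("loop body must terminate with Yield or Return")
        if len(term.values) != len(initializers):
            raise ValueError("yield must provide one value per initializer")
        carried = term.values
    # One result per initializer; equal to the initializers if the loop never runs.
    return carried


def call_function(thunk: Callable[[], Any]) -> Any:
    """Toy function boundary: a Return inside the loop becomes the function result."""
    try:
        return thunk()
    except EarlyReturn as ret:
        return ret.value


# Two loop-carried values: a running sum and an iteration count.
print(scf_for(range(5), lambda i, s, n: Yield((s + i, n + 1)), 0, 0))  # (10, 5)
# A Return in the body exits the enclosing function on the first matching item.
print(call_function(lambda: scf_for(range(10), lambda i: Return(i) if i == 3 else Yield(()))))  # 3
```

Modelling Return as an exception is only a convenient way to unwind out of the loop in Python. It is not a claim about how Kirin's interpreter implements it.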

Known difference from the Python for loop

The scf.For statement does not follow exactly the same semantics as the Python for loop. The difference comes from the context of compilation versus interpretation. As in many other compiled languages, the loop body introduces a new scope, so a variable first assigned inside the loop body is not accessible outside it. For example, the following Julia function fails:

function simple_loop()
    for i in 1:10
        j = i
    end
    return j # j is local to the loop body, so it is not defined here
end

Calling simple_loop() raises an UndefVarError, because j only exists inside the loop body.

In Python, however, the equivalent code works, because the interpreter does not create a new scope for the loop body:

def simple_loop():
    for i in range(10):
        j = i
    return j # refers to the j assigned in the loop body; returns 9
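The Python example above reads a loop variable after the loop. A pattern that should fit scf.For's scoping is to bind the variable before the loop, as simple_loop at the top of this page does with j = 0.0. The variable then becomes a loop-carried iter_args value instead of one that first appears inside the loop body. This is a sketch under the assumption that lowering treats it like that earlier example; simple_loop_scoped is a hypothetical name.

```python
def simple_loop_scoped(n: int = 10) -> int:
    # Binding j before the loop makes it a loop-carried value (an iter_args
    # entry in the lowered IR, as in simple_loop above) instead of a variable
    # that first appears inside the loop body.
    j = 0
    for i in range(n):
        j = i
    return j  # well defined even when the loop body never runs


print(simple_loop_scoped())   # 9
print(simple_loop_scoped(0))  # 0
```

In plain Python this behaves the same as the scope-leaking version whenever the loop runs at least once. It also avoids the UnboundLocalError that Python raises when the loop runs zero times.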

Reference

For kirin-statement

For(
    iterable: ir.SSAValue,
    body: ir.Region,
    *initializers: ir.SSAValue
)

Bases: Statement

Source code in src/kirin/dialects/scf/stmts.py
def __init__(
    self,
    iterable: ir.SSAValue,
    body: ir.Region,
    *initializers: ir.SSAValue,
):
    stmt = body.blocks[0].last_stmt
    if isinstance(stmt, Yield):
        result_types = tuple(value.type for value in stmt.values)
    else:
        result_types = ()
    super().__init__(
        args=(iterable, *initializers),
        regions=(body,),
        result_types=result_types,
        args_slice={"iterable": 0, "initializers": slice(1, None)},
        attributes={"purity": ir.PyAttr(False)},
    )

body kirin-region kw-only

body: Region = region(multi=False)

initializers kirin-argument

initializers: tuple[SSAValue, ...] = argument(Any)

iterable kirin-argument

iterable: SSAValue = argument(Any)

name class-attribute instance-attribute

name = 'for'

purity kirin-attribute kw-only

purity: bool = attribute(default=False)

traits class-attribute instance-attribute

traits = frozenset({MaybePure(), SSACFGRegion()})

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/scf/stmts.py
def print_impl(self, printer: Printer) -> None:
    printer.print_name(self)
    printer.plain_print(" ")
    block = self.body.blocks[0]
    printer.print(block.args[0])
    printer.plain_print(" in ", style="keyword")
    printer.print(self.iterable)
    if self.results:
        with printer.rich(style="comment"):
            printer.plain_print(" -> ")
            printer.print_seq(
                tuple(result.type for result in self.results),
                delim=", ",
                style="comment",
            )

    with printer.indent():
        if self.initializers:
            printer.print_newline()
            printer.plain_print("iter_args(")
            for idx, (arg, val) in enumerate(
                zip(block.args[1:], self.initializers)
            ):
                printer.print(arg)
                printer.plain_print(" = ")
                printer.print(val)
                if idx < len(self.initializers) - 1:
                    printer.plain_print(", ")
            printer.plain_print(")")

        printer.plain_print(" {")
        if printer.analysis is not None:
            with printer.rich(style="warning"):
                for arg in block.args:
                    printer.print_newline()
                    printer.print_analysis(
                        arg, prefix=f"{printer.state.ssa_id[arg]} --> "
                    )
        with printer.align(printer.result_width(block.stmts)):
            for stmt in block.stmts:
                printer.print_newline()
                printer.print_stmt(stmt)
    printer.print_newline()
    printer.plain_print("}")
    with printer.rich(style="comment"):
        printer.plain_print(f" -> purity={self.purity}")

verify

verify() -> None

Run mandatory validation checks. This is not the same as type checking, which may be optional.

Source code in src/kirin/dialects/scf/stmts.py
def verify(self) -> None:
    from kirin.dialects.func import Return

    if len(self.body.blocks) != 1:
        raise VerificationError(self, "for loop body must have a single block")

    if len(self.body.blocks[0].args) != len(self.initializers) + 1:
        raise VerificationError(
            self,
            "for loop body must have arguments for all initializers and the loop variable",
        )

    stmt = self.body.blocks[0].last_stmt
    if stmt is None or not isinstance(stmt, (Yield, Return)):
        raise VerificationError(
            self, "for loop body must terminate with a yield or return"
        )

    if isinstance(stmt, Return):
        return

    if len(stmt.values) != len(self.initializers):
        raise VerificationError(
            self,
            "for loop body must have the same number of results as initializers",
        )
    if len(self.results) != len(stmt.values):
        raise VerificationError(
            self,
            "for loop must have the same number of results as the yield in the body",
        )

IfElse kirin-statement

IfElse(
    cond: ir.SSAValue,
    then_body: ir.Region | ir.Block,
    else_body: ir.Region | ir.Block | None = None,
)

Bases: Statement

Python-like if-else statement.

This statement has a condition, a then body, and an else body.

The then body terminates with either a yield statement or a func.Return statement.

Source code in src/kirin/dialects/scf/stmts.py
def __init__(
    self,
    cond: ir.SSAValue,
    then_body: ir.Region | ir.Block,
    else_body: ir.Region | ir.Block | None = None,
):
    if isinstance(then_body, ir.Region):
        then_body_region = then_body
        if then_body_region.blocks:
            then_body_block = then_body_region.blocks[-1]
        else:
            then_body_block = None
    elif isinstance(then_body, ir.Block):
        then_body_block = then_body
        then_body_region = ir.Region(then_body)

    if isinstance(else_body, ir.Region):
        if not else_body.blocks:  # empty region
            else_body_region = else_body
            else_body_block = None
        elif len(else_body.blocks) == 0:
            else_body_region = else_body
            else_body_block = None
        else:
            else_body_region = else_body
            else_body_block = else_body_region.blocks[0]
    elif isinstance(else_body, ir.Block):
        else_body_region = ir.Region(else_body)
        else_body_block = else_body
    else:
        else_body_region = ir.Region()
        else_body_block = None

    # if either then or else body has yield, we generate results
    # we assume if both have yields, they have the same number of results
    results = ()
    if then_body_block is not None:
        then_yield = then_body_block.last_stmt
        else_yield = (
            else_body_block.last_stmt if else_body_block is not None else None
        )
        if then_yield is not None and isinstance(then_yield, Yield):
            results = then_yield.values
        elif else_yield is not None and isinstance(else_yield, Yield):
            results = else_yield.values

    result_types = tuple(value.type for value in results)
    super().__init__(
        args=(cond,),
        regions=(then_body_region, else_body_region),
        result_types=result_types,
        args_slice={"cond": 0},
        attributes={"purity": ir.PyAttr(False)},
    )

cond kirin-argument

cond: SSAValue = argument(Any)

else_body kirin-region kw-only

else_body: Region = region(
    multi=False, default_factory=Region
)

name class-attribute instance-attribute

name = 'if'

purity kirin-attribute kw-only

purity: bool = attribute(default=False)

then_body kirin-region kw-only

then_body: Region = region(multi=False)

traits class-attribute instance-attribute

traits = frozenset({MaybePure(), SSACFGRegion()})

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/scf/stmts.py
def print_impl(self, printer: Printer) -> None:
    printer.print_name(self)
    printer.plain_print(" ")
    printer.print(self.cond)
    printer.plain_print(" ")
    printer.print(self.then_body)
    if self.else_body.blocks and not (
        len(self.else_body.blocks[0].stmts) == 1
        and isinstance(else_term := self.else_body.blocks[0].last_stmt, Yield)
        and not else_term.values  # empty yield
    ):
        printer.plain_print(" else ", style="keyword")
        printer.print(self.else_body)

    with printer.rich(style="comment"):
        printer.plain_print(f" -> purity={self.purity}")
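The condition guarding the else keyword in print_impl is easy to misread. Below is a toy re-expression of it in plain Python; the ToyBlock and Yield classes are hypothetical stand-ins, not Kirin types. The else region is printed unless it is absent or its only statement is an empty yield. An empty yield is presumably what an omitted else branch looks like in the IR, which would explain why such an if prints with no else part.

```python
from dataclasses import dataclass, field


@dataclass
class Yield:
    """Toy stand-in for scf.Yield."""
    values: tuple = ()


@dataclass
class ToyBlock:
    """Toy block holding a list of statements; the last one is the terminator."""
    stmts: list = field(default_factory=list)


def prints_else(else_blocks: list) -> bool:
    """Toy mirror of the else-printing condition in IfElse.print_impl."""
    if not else_blocks:
        return False
    stmts = else_blocks[0].stmts
    only_empty_yield = (
        len(stmts) == 1 and isinstance(stmts[-1], Yield) and not stmts[-1].values
    )
    return not only_empty_yield


print(prints_else([]))                            # False: no else block at all
print(prints_else([ToyBlock([Yield()])]))         # False: only an empty yield
print(prints_else([ToyBlock([Yield(("%x",))])]))  # True: the else yields a value
print(prints_else([ToyBlock(["stmt", Yield()])])) # True: the else has other statements
```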

verify

verify() -> None

Run mandatory validation checks. This is not the same as type checking, which may be optional.

Source code in src/kirin/dialects/scf/stmts.py
def verify(self) -> None:
    from kirin.dialects.func import Return

    if len(self.then_body.blocks) != 1:
        raise VerificationError(self, "then region must have a single block")

    if len(self.else_body.blocks) != 1:
        raise VerificationError(self, "else region must have a single block")

    then_block = self.then_body.blocks[0]
    else_block = self.else_body.blocks[0]
    if len(then_block.args) != 1:
        raise VerificationError(
            self, "then block must have a single argument for condition"
        )

    if len(else_block.args) != 1:
        raise VerificationError(
            self, "else block must have a single argument for condition"
        )

    then_stmt = then_block.last_stmt
    else_stmt = else_block.last_stmt
    if then_stmt is None or not isinstance(then_stmt, (Yield, Return)):
        raise VerificationError(
            self, "then block must terminate with a yield or return"
        )

    if else_stmt is None or not isinstance(else_stmt, (Yield, Return)):
        raise VerificationError(
            self, "else block must terminate with a yield or return"
        )

Yield kirin-statement

Yield(*values: ir.SSAValue)

Bases: Statement

Source code in src/kirin/dialects/scf/stmts.py
def __init__(self, *values: ir.SSAValue):
    super().__init__(args=values, args_slice={"values": slice(None)})

name class-attribute instance-attribute

name = 'yield'

traits class-attribute instance-attribute

traits = frozenset({IsTerminator()})

values kirin-argument

values: tuple[SSAValue, ...] = argument(Any)

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/scf/stmts.py
def print_impl(self, printer: Printer) -> None:
    printer.print_name(self)
    printer.print_seq(self.values, prefix=" ", delim=", ")