Warning

This page is under construction. The content may be incomplete or incorrect. Submit an issue on GitHub if you need help or want to contribute.

The Function Dialect

The function dialect provides a set of statements that model the semantics of Python-like functions, namely:

  • def <name>(<args>*, <kwargs>*)-style function declarations
  • nested functions (closures)
  • higher-order functions (functions passed as arguments)
  • calling a function or closure dynamically or statically

func.Return

This statement models the return statement in a function declaration. Although it is simple, note that it accepts exactly one argument of type ir.SSAValue: in Python (and most other languages) a function always has a single return value, and multiple return values are represented by returning a tuple.
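
The single-value convention mirrors plain Python, where a function that appears to return several values actually returns one tuple. A minimal illustration in plain Python (no Kirin API is involved; `min_max` is a name made up for this example):

```python
def min_max(values):
    # Looks like two return values, but Python packs them into one tuple.
    return min(values), max(values)


result = min_max([3, 1, 2])
print(type(result).__name__)  # tuple

# Unpacking into separate names happens on the caller's side.
lo, hi = result
```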

func.Function

This is the most fundamental statement that models a Python function.

Definition: func.Function takes no arguments, but carries a special str attribute (stored as a PyAttr) that serves as a symbolic reference within a symbol table. It also takes a func.Signature attribute that stores the signature of the corresponding function declaration. Finally, it contains an ir.Region that represents the function body. The ir.Region follows the SSACFG convention, in which the blocks in the region form a control flow graph.

Differences with MLIR

Kirin's priority is writing eDSLs as kernel functions in Python. To support higher-order functions, the entry block of a function always takes self, of type types.MethodType, as its first argument. This design is inspired by Julia's IR.

As an example, the following Python function

from kirin.prelude import basic_no_opt

@basic_no_opt
def main(x):
    return x

will be lowered into the following IR, where main_self refers to the function itself.

func.func main(!Any) -> !Any {
  ^0(%main_self, %x):
  │ func.return %x
} // func.func main

The function can be terminated by a func.Return statement. All blocks in the function region must have terminators. During lowering, if a block is not terminated, a func.Return that returns None is appended to it. Thus func.Function always has exactly one return value.
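
The implicit return added during lowering can be pictured with a toy model. The sketch below is not Kirin's lowering code: `Stmt`, `Block`, and `ensure_terminated` are hypothetical names, `toy.add` is a made-up statement, and statements are reduced to a name plus operand names.

```python
from dataclasses import dataclass, field


@dataclass
class Stmt:
    """Toy statement: a name and the names of its operands."""
    name: str
    args: tuple[str, ...] = ()


@dataclass
class Block:
    """Toy block: an ordered list of statements."""
    stmts: list[Stmt] = field(default_factory=list)


def ensure_terminated(block: Block) -> Block:
    """Append `return None` when the block does not end in a return."""
    if not block.stmts or block.stmts[-1].name != "func.return":
        # Materialize None, then return it, as Python would implicitly.
        block.stmts.append(Stmt("func.const.none"))
        block.stmts.append(Stmt("func.return", ("%none",)))
    return block


# A block that falls off the end gains a trailing return.
body = ensure_terminated(Block([Stmt("toy.add", ("%x", "%y"))]))
# A block that already returns is left unchanged.
already = ensure_terminated(Block([Stmt("func.return", ("%x",))]))
```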

func.Call and func.Invoke

These two statements model the most common calling conventions in Python, with compilation in mind:

  • func.Call models dynamic calls, where the callee is unknown at compile time and is therefore an ir.SSAValue
  • func.Invoke models static calls, where the callee is known at compile time and is therefore an ir.Method

Both take inputs, a tuple of ir.SSAValue, as their argument. Because all functions are assumed to return a single value, func.Call and func.Invoke each have a single result.
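
In source terms, the distinction is whether the callee is a runtime value or a name that resolves to a known function. The plain-Python sketch below only illustrates that distinction; the comments on which statement each call site corresponds to follow the description above and are an assumption, not verified lowering output. `double`, `static_site`, and `dynamic_site` are names made up for this example.

```python
def double(x):
    return 2 * x


def static_site(x):
    # The callee `double` is a known function at this call site:
    # the kind of call func.Invoke is described as modeling.
    return double(x)


def dynamic_site(f, x):
    # The callee `f` is an ordinary value supplied at runtime:
    # the kind of call func.Call is described as modeling.
    return f(x)


# Either way, the call produces exactly one result value.
print(static_site(3), dynamic_site(double, 3))  # 6 6
```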

func.Lambda

This statement models nested functions (a.k.a. closures). Most of its definition is similar to func.Function; the key difference is that func.Lambda takes a tuple of ir.SSAValue arguments named captured, which models the variables captured by the nested function.

For example, the following Python function contains a closure that captures the variable x:

from kirin.prelude import basic_no_opt

@basic_no_opt
def main(x):
    def closure():
        return x
    return closure

will be lowered into

func.func main(!Any) -> !Any {
  ^0(%main_self, %x):
  │ %closure = func.lambda closure(%x) -> !Any {
  │            │ ^1(%closure_self):
  │            │ │ %x_1 = func.getfield(%closure_self, 0) : !Any
  │            │ │        func.return %x_1
  │            } // func.lambda closure
  │            func.return %closure
} // func.func main

Unlike func.Function, this statement also has a result value, which refers to the closure itself. Inside the closure body, func.GetField statements are inserted to unpack the captured variables.
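
The capture-then-unpack scheme can be mimicked in plain Python. The sketch below is a toy model, not Kirin's runtime: `ToyClosure` and `get_field` are hypothetical names standing in for the closure value and func.GetField.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyClosure:
    """Toy closure value: a code body plus the values it captured."""
    code: Callable[..., Any]
    fields: tuple[Any, ...] = ()

    def __call__(self, *args: Any) -> Any:
        # The closure passes itself as the first argument, like %closure_self.
        return self.code(self, *args)


def get_field(obj: ToyClosure, field: int) -> Any:
    """Toy counterpart of func.getfield: read a captured value by index."""
    return obj.fields[field]


def closure_body(closure_self: ToyClosure) -> Any:
    x_1 = get_field(closure_self, 0)  # %x_1 = func.getfield(%closure_self, 0)
    return x_1                        # func.return %x_1


def main(x: Any) -> ToyClosure:
    # %closure = func.lambda closure(%x): capture x when the closure is built.
    return ToyClosure(closure_body, fields=(x,))


print(main(42)())  # 42
```

Passing the closure to its own body is what lets the body stay free of references to the enclosing scope: everything it needs arrives through the first argument.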

API Reference

Call kirin-statement

Call(
    callee: SSAValue,
    inputs: tuple[SSAValue, ...],
    *,
    kwargs: tuple[str, ...] = (),
    purity: bool
)

Bases: Statement

callee kirin-argument

callee: SSAValue = argument()

inputs kirin-argument

inputs: tuple[SSAValue, ...] = argument()

kwargs kirin-attribute kw-only

kwargs: tuple[str, ...] = attribute(
    default_factory=lambda: ()
)

name class-attribute instance-attribute

name = 'call'

purity kirin-attribute kw-only

purity: bool = attribute(default=False)

result kirin-result

result: ResultValue = result()

traits class-attribute instance-attribute

traits = frozenset({MaybePure()})

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
def print_impl(self, printer: Printer) -> None:
    pprint_calllike(self, printer.state.ssa_id[self.callee], printer)

ConstantNone kirin-statement

ConstantNone()

Bases: Statement

A constant None value.

This is mainly used to represent the None return value of a function to match Python semantics.

name class-attribute instance-attribute

name = 'const.none'

result kirin-result

result: ResultValue = result(NoneType)

traits class-attribute instance-attribute

traits = frozenset({Pure(), ConstantLike()})

FuncOpCallableInterface dataclass

FuncOpCallableInterface()

Bases: CallableStmtInterface['Function']

get_callable_region classmethod

get_callable_region(stmt: Function) -> Region

Returns the body of the callable region

Source code in src/kirin/dialects/func/stmts.py
@classmethod
def get_callable_region(cls, stmt: "Function") -> Region:
    return stmt.body

Function kirin-statement

Function(
    *, sym_name: str, signature: Signature, body: Region
)

Bases: Statement

body kirin-region kw-only

body: Region = region(multi=True)

name class-attribute instance-attribute

name = 'func'

signature kirin-attribute kw-only

signature: Signature = attribute()

sym_name kirin-attribute kw-only

sym_name: str = attribute()

The symbol name of the function.

traits class-attribute instance-attribute

traits = frozenset(
    {
        IsolatedFromAbove(),
        SymbolOpInterface(),
        HasSignature(),
        FuncOpCallableInterface(),
        SSACFGRegion(),
    }
)

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
def print_impl(self, printer: Printer) -> None:
    with printer.rich(style="keyword"):
        printer.print_name(self)
        printer.plain_print(" ")

    with printer.rich(style="symbol"):
        printer.plain_print(self.sym_name)

    printer.print_seq(self.signature.inputs, prefix="(", suffix=")", delim=", ")

    with printer.rich(style="comment"):
        printer.plain_print(" -> ")
        printer.print(self.signature.output)
        printer.plain_print(" ")

    printer.print(self.body)

    with printer.rich(style="comment"):
        printer.plain_print(f" // func.func {self.sym_name}")

GetField kirin-statement

GetField(obj: SSAValue, *, field: int)

Bases: Statement

field kirin-attribute kw-only

field: int = attribute()

name class-attribute instance-attribute

name = 'getfield'

obj kirin-argument

obj: SSAValue = argument(MethodType)

result kirin-result

result: ResultValue = result(init=False)

traits class-attribute instance-attribute

traits = frozenset({Pure()})

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
def print_impl(self, printer: Printer) -> None:
    printer.print_name(self)
    printer.plain_print(
        "(", printer.state.ssa_id[self.obj], ", ", str(self.field), ")"
    )
    with printer.rich(style="black"):
        printer.plain_print(" : ")
        printer.print(self.result.type)

Invoke kirin-statement

Invoke(
    inputs: tuple[SSAValue, ...],
    *,
    callee: Method,
    kwargs: tuple[str, ...],
    purity: bool
)

Bases: Statement

callee kirin-attribute kw-only

callee: Method = attribute()

inputs kirin-argument

inputs: tuple[SSAValue, ...] = argument()

kwargs kirin-attribute kw-only

kwargs: tuple[str, ...] = attribute()

name class-attribute instance-attribute

name = 'invoke'

purity kirin-attribute kw-only

purity: bool = attribute(default=False)

result kirin-result

result: ResultValue = result()

traits class-attribute instance-attribute

traits = frozenset({MaybePure()})

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
def print_impl(self, printer: Printer) -> None:
    pprint_calllike(self, self.callee.sym_name, printer)

verify

verify() -> None

Run mandatory validation checks. This is not the same as type checking, which may be optional.

Source code in src/kirin/dialects/func/stmts.py
def verify(self) -> None:
    if self.kwargs:
        for name in self.kwargs:
            if name not in self.callee.arg_names:
                raise VerificationError(
                    self,
                    f"method {self.callee.sym_name} does not have argument {name}",
                )
    elif len(self.callee.arg_names) - 1 != len(self.args):
        raise VerificationError(
            self,
            f"expected {len(self.callee.arg_names)} arguments, got {len(self.args)}",
        )
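
The branching of this check can be restated as a standalone function on plain data. This is a sketch, not the Kirin API: `check_invoke` is a hypothetical name, `ValueError` stands in for VerificationError, and it assumes (from the `- 1` in the comparison and the self-argument convention described earlier) that the first entry of `arg_names` is the function's own self argument.

```python
def check_invoke(arg_names, n_args, kwargs=()):
    """Mirror the branching of the verify listing on plain data.

    Returns None when the call is accepted, raises ValueError otherwise.
    """
    if kwargs:
        # With keyword arguments present, only the keyword names are checked.
        for name in kwargs:
            if name not in arg_names:
                raise ValueError(f"no argument named {name}")
    elif len(arg_names) - 1 != n_args:
        # Positional-only call: the count must match, excluding self.
        # This sketch reports the expected count without self.
        raise ValueError(
            f"expected {len(arg_names) - 1} arguments, got {n_args}"
        )


check_invoke(["main_self", "x"], 1)          # accepted: one positional input
check_invoke(["main_self", "x"], 1, ("x",))  # accepted: `x` is a known name
```

Note that, as in the listing, the argument count is not checked when keyword names are supplied.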

Lambda kirin-statement

Lambda(
    captured: tuple[SSAValue, ...],
    *,
    sym_name: str,
    signature: Signature,
    body: Region
)

Bases: Statement

body kirin-region kw-only

body: Region = region(multi=True)

captured kirin-argument

captured: tuple[SSAValue, ...] = argument()

name class-attribute instance-attribute

name = 'lambda'

result kirin-result

result: ResultValue = result(MethodType)

signature kirin-attribute kw-only

signature: Signature = attribute()

sym_name kirin-attribute kw-only

sym_name: str = attribute()

traits class-attribute instance-attribute

traits = frozenset(
    {
        Pure(),
        HasSignature(),
        SymbolOpInterface(),
        FuncOpCallableInterface(),
        SSACFGRegion(),
    }
)

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
def print_impl(self, printer: Printer) -> None:
    with printer.rich(style="keyword"):
        printer.print_name(self)
    printer.plain_print(" ")

    with printer.rich(style="symbol"):
        printer.plain_print(self.sym_name)

    printer.print_seq(self.captured, prefix="(", suffix=")", delim=", ")

    with printer.rich(style="bright_black"):
        printer.plain_print(" -> ")
        printer.print(self.signature.output)

    printer.plain_print(" ")
    printer.print(self.body)

    with printer.rich(style="black"):
        printer.plain_print(f" // func.lambda {self.sym_name}")

verify

verify() -> None

Run mandatory validation checks. This is not the same as type checking, which may be optional.

Source code in src/kirin/dialects/func/stmts.py
def verify(self) -> None:
    if self.body.blocks.isempty():
        raise VerificationError(self, "lambda body must not be empty")

Return kirin-statement

Return(value_or_stmt: SSAValue | Statement | None = None)

Bases: Statement

Source code in src/kirin/dialects/func/stmts.py
def __init__(self, value_or_stmt: SSAValue | Statement | None = None) -> None:
    if isinstance(value_or_stmt, SSAValue):
        args = [value_or_stmt]
    elif isinstance(value_or_stmt, Statement):
        if len(value_or_stmt._results) == 1:
            args = [value_or_stmt._results[0]]
        else:
            raise ValueError(
                f"expected a single result, got {len(value_or_stmt._results)} results from {value_or_stmt.name}"
            )
    elif value_or_stmt is None:
        args = []
    else:
        raise ValueError(f"expected SSAValue or Statement, got {value_or_stmt}")

    super().__init__(args=args, args_slice={"value": 0})

name class-attribute instance-attribute

name = 'return'

traits class-attribute instance-attribute

traits = frozenset({IsTerminator(), HasParent((Function,))})

value kirin-argument

value: SSAValue = argument()

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
def print_impl(self, printer: Printer) -> None:
    with printer.rich(style="keyword"):
        printer.print_name(self)

    if self.args:
        printer.plain_print(" ")
        printer.print_seq(self.args, delim=", ")

verify

verify() -> None

Run mandatory validation checks. This is not the same as type checking, which may be optional.

Source code in src/kirin/dialects/func/stmts.py
def verify(self) -> None:
    if not self.args:
        raise VerificationError(
            self, "return statement must have at least one value"
        )

    if len(self.args) > 1:
        raise VerificationError(
            self,
            "return statement must have at most one value"
            ", wrap multiple values in a tuple",
        )