
Warning

This page is under construction. The content may be incomplete or incorrect. Submit an issue on GitHub if you need help or want to contribute.

The Function Dialect

The function dialect provides a set of statements that model the semantics of Python-like functions, namely:

  • function declarations of the form def <name>(<args>*, <kwargs>*)
  • nested functions (i.e. closures)
  • higher-order functions (functions that can be passed as arguments)
  • dynamic and static calls to a function or closure

func.Return

This statement models the return statement of a function declaration. Although it is simple, note that it accepts exactly one argument of type ir.SSAValue: in Python (and most other languages) a function always has a single return value, and multiple return values are represented by returning a tuple.
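A plain-Python illustration of the point above (no Kirin API involved): a return statement listing several values still produces one tuple object, which is why a single return operand is enough. The function name quot_rem is purely illustrative.

```python
def quot_rem(a: int, b: int):
    # "return q, r" builds ONE tuple object; there is no separate multi-value return
    q, r = a // b, a % b
    return q, r


result = quot_rem(17, 5)
print(type(result).__name__, result)  # tuple (3, 2)
```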

func.Function

This is the most fundamental statement that models a Python function.

Definition: func.Function takes no arguments. It has a str attribute, sym_name (thus stored as a PyAttr), that serves as its symbolic reference within a symbol table, and a func.Signature attribute that stores the signature of the corresponding function declaration. Finally, it contains an ir.Region representing the function body. This region follows the SSACFG convention: the blocks in the region form a control flow graph.
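To make the three parts concrete, here is a toy data model. ToySignature and ToyFunction are hypothetical classes, not Kirin's actual types; they only hold a symbol name, a signature (input types and output type), the argument names, and a body made of blocks. The header() helper mimics the first line of the printed IR shown in the example that follows.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class ToySignature:
    inputs: tuple[str, ...]  # one type name per declared argument
    output: str  # return type name


@dataclass
class ToyFunction:
    sym_name: str  # symbolic name used as the symbol-table key
    signature: ToySignature
    slots: tuple[str, ...]  # argument names, in declaration order
    # body: a list of blocks (each a list of statement strings); the first
    # block is the entry block, and together they form a control flow graph
    body: list[list[str]] = field(default_factory=list)

    def header(self) -> str:
        args = ", ".join(self.signature.inputs)
        return f"func.func {self.sym_name}({args}) -> {self.signature.output}"


fn = ToyFunction("main", ToySignature(("!Any",), "!Any"), ("x",), [["func.return %x"]])
print(fn.header())  # func.func main(!Any) -> !Any
```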

Differences with MLIR

Kirin's priority is writing eDSLs as kernel functions in Python. To support higher-order functions, the entry block of every function always takes self, of type types.MethodType, as its first argument, so the function can refer to itself as a value. This design is inspired by Julia's IR.
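The self convention can be sketched in plain Python with a hypothetical ToyMethod class (a toy model, not Kirin's implementation): the body always receives the callable object itself as its first argument, so a function can recurse through that argument and can also be handed to other functions as an ordinary value.

```python
from typing import Any, Callable


class ToyMethod:
    """Hypothetical stand-in for a compiled kernel (not Kirin's ir.Method)."""

    def __init__(self, name: str, body: Callable[..., Any]) -> None:
        self.name = name
        self.body = body

    def __call__(self, *args: Any) -> Any:
        # Prepend the method object itself, like the %main_self block argument.
        return self.body(self, *args)


def fact_body(self_: ToyMethod, n: int) -> int:
    # Recursion goes through the self argument, not a global name lookup.
    return 1 if n <= 1 else n * self_(n - 1)


def apply_twice(f: ToyMethod, x: int) -> int:
    # Higher-order use: the method is an ordinary value that can be passed around.
    return f(f(x))


fact = ToyMethod("fact", fact_body)
print(fact(5))  # 120
print(apply_twice(fact, 3))  # 720
```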

As an example, the following Python function

from kirin.prelude import basic_no_opt

@basic_no_opt
def main(x):
    return x

will be lowered into the following IR, where %main_self refers to the function itself:

func.func main(!Any) -> !Any {
  ^0(%main_self, %x):
  │ func.return %x
} // func.func main

A function body is terminated by a func.Return statement, and every block in the function region must end with a terminator. During lowering, if a block is not terminated, a func.Return that returns None is appended to it (the None value is typically produced by func.ConstantNone). Because func.Return takes exactly one value, a func.Function always has a single return value.
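The lowering rule can be sketched on a toy block representation. This is an illustration, not Kirin's lowering code: statements are hypothetical (opname, *operands) tuples, py.add is a made-up opname, and const.none is borrowed from the name of the func.ConstantNone statement in the API reference.

```python
Stmt = tuple[str, ...]  # (opname, *operands); a toy encoding, not Kirin's IR


def ensure_terminated(block: list[Stmt]) -> list[Stmt]:
    """Append an implicit 'return None' if the block lacks a terminator."""
    if block and block[-1][0] == "func.return":
        return block  # already terminated: leave it unchanged
    # Materialize None first, then return it (still a single return value).
    return block + [("const.none", "%none"), ("func.return", "%none")]


print(ensure_terminated([("py.add", "%a", "%b")]))
```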

func.Call and func.Invoke

These two statements model the two most common calling conventions in Python, distinguished by whether the callee is known at compile time:

  • func.Call models dynamic calls, where the callee is unknown at compile time; the callee is therefore an ir.SSAValue argument
  • func.Invoke models static calls, where the callee is known at compile time; the callee is therefore an ir.Method stored as an attribute

Both take their positional arguments as inputs, a tuple of ir.SSAValue; func.Call can additionally carry keyword arguments (kwargs, with their names stored in keys). Because every function returns a single value, func.Call and func.Invoke each have exactly one result.
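The difference between the two conventions can be sketched with two toy statement classes and a tiny evaluator. ToyCall, ToyInvoke and run are hypothetical names, not Kirin APIs: the static form carries the callee itself, while the dynamic form carries only the SSA name of a value that must be looked up at run time. Either way, exactly one value comes back.

```python
from dataclasses import dataclass
from typing import Any, Callable, Union


@dataclass(frozen=True)
class ToyInvoke:
    callee: Callable[..., Any]  # fixed when the statement is built (static)
    inputs: tuple[str, ...]  # SSA names of the arguments


@dataclass(frozen=True)
class ToyCall:
    callee: str  # SSA name of the callee, resolved only at run time (dynamic)
    inputs: tuple[str, ...]


def run(stmt: Union[ToyInvoke, ToyCall], env: dict[str, Any]) -> Any:
    args = [env[name] for name in stmt.inputs]
    if isinstance(stmt, ToyInvoke):
        return stmt.callee(*args)  # callee known up front
    return env[stmt.callee](*args)  # look up the callee value first


env = {"%f": abs, "%x": -3}
print(run(ToyCall("%f", ("%x",)), env))  # 3
print(run(ToyInvoke(abs, ("%x",)), env))  # 3
```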

func.Lambda

This statement models nested functions (a.k.a. closures). Most of its definition mirrors func.Function; the key difference is that func.Lambda takes a tuple of ir.SSAValue arguments, captured, which holds the variables captured by the nested function. For example, consider the following Python function, whose inner closure captures the variable x:

from kirin.prelude import basic_no_opt

@basic_no_opt
def main(x):
    def closure():
        return x
    return closure

will be lowered into

func.func main(!Any) -> !Any {
  ^0(%main_self, %x):
  │ %closure = func.lambda closure(%x) -> !Any {
  │            │ ^1(%closure_self):
  │            │ │ %x_1 = func.getfield(%closure_self, 0) : !Any
  │            │ │        func.return %x_1
  │            } // func.lambda closure
  │            func.return %closure
} // func.func main

func.Lambda has a result value that refers to the closure itself, which is how the enclosing function can return it (%closure above). Inside the closure body, func.GetField statements unpack the captured variables from %closure_self by index.
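The capture-then-unpack pattern in the IR above can be mimicked in plain Python. ToyClosure and getfield are hypothetical names for this sketch, not Kirin's runtime: the closure object stores its captured values in a fields tuple, its body receives the closure itself first, and getfield(closure_self, 0) plays the role of func.getfield(%closure_self, 0).

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyClosure:
    body: Callable[..., Any]
    fields: tuple[Any, ...]  # captured values, in capture order

    def __call__(self, *args: Any) -> Any:
        # The body receives the closure itself first (%closure_self).
        return self.body(self, *args)


def getfield(obj: ToyClosure, index: int) -> Any:
    # Analogue of func.getfield(%closure_self, index).
    return obj.fields[index]


def main(x: Any) -> ToyClosure:
    def closure_body(closure_self: ToyClosure) -> Any:
        x_1 = getfield(closure_self, 0)  # unpack the captured x
        return x_1

    # Analogue of func.lambda closure(%x): capture x into field 0.
    return ToyClosure(closure_body, (x,))


print(main(42)())  # 42
```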

API Reference

Call dataclass

Call(
    *,
    args: Sequence[SSAValue] = (),
    regions: Sequence[Region] = (),
    successors: Sequence[Block] = (),
    attributes: Mapping[str, Attribute] = {},
    results: Sequence[ResultValue] = (),
    result_types: Sequence[TypeAttribute] = (),
    args_slice: Mapping[str, int | slice] = {},
    source: SourceInfo | None = None
)

Bases: Statement


Source code in src/kirin/ir/nodes/stmt.py
def __init__(
    self,
    *,
    args: Sequence[SSAValue] = (),
    regions: Sequence[Region] = (),
    successors: Sequence[Block] = (),
    attributes: Mapping[str, Attribute] = {},
    results: Sequence[ResultValue] = (),
    result_types: Sequence[TypeAttribute] = (),
    args_slice: Mapping[str, int | slice] = {},
    source: SourceInfo | None = None,
) -> None:
    """Initialize the Statement.

    Args:
        args (Sequence[SSAValue], optional): The arguments of the Statement. Defaults to ().
        regions (Sequence[Region], optional): The regions contained in the Statement. Defaults to ().
        successors (Sequence[Block], optional): The successors of the Statement. Defaults to ().
        attributes (Mapping[str, Attribute], optional): The attributes of the Statement. Defaults to {}.
        results (Sequence[ResultValue], optional): The result values of the Statement. Defaults to ().
        result_types (Sequence[TypeAttribute], optional): The result types of the Statement. Defaults to ().
        args_slice (Mapping[str, int | slice], optional): The arguments slice of the Statement. Defaults to {}.
        source (SourceInfo | None, optional): The source information of the Statement for debugging/stacktracing. Defaults to None.

    """
    super().__init__()
    self._args = ()
    self._regions = []
    self._name_args_slice = dict(args_slice)
    self.source = source
    self.args = args

    if results:
        # reuse pre-built result values; result_types must then be empty
        self._results = list(results)
        assert (
            len(result_types) == 0
        ), "expect either results or result_types specified, got both"

    if result_types:
        # create one fresh ResultValue per requested type
        self._results = [
            ResultValue(self, idx, type=type)
            for idx, type in enumerate(result_types)
        ]

    if not results and not result_types:
        # the statement has no results
        self._results = list(results)

    self.successors = list(successors)
    self.attributes = dict(attributes)
    self.regions = list(regions)

    self.parent = None
    self._next_stmt = None
    self._prev_stmt = None
    self.__post_init__()

callee class-attribute instance-attribute

callee: SSAValue = argument()

inputs class-attribute instance-attribute

inputs: tuple[SSAValue, ...] = argument()

keys class-attribute instance-attribute

keys: tuple[str, ...] = attribute(default=())

kwargs class-attribute instance-attribute

kwargs: tuple[SSAValue, ...] = argument()

name class-attribute instance-attribute

name = 'call'

purity class-attribute instance-attribute

purity: bool = attribute(default=False)

result class-attribute instance-attribute

result: ResultValue = result()

traits class-attribute instance-attribute

traits = frozenset({MaybePure()})

check_type

check_type() -> None

Check the types of the Statement. Raises Exception if the types are not correct. This method is called by the verify_type method, which will detect the source of the error in the IR. One should always call the verify_type method to verify the types of the IR.

Note

This method is generated by the @statement decorator, but can be overridden if needed.

Source code in src/kirin/dialects/func/stmts.py
def check_type(self) -> None:
    if not self.callee.type.is_subseteq(types.MethodType):
        if self.callee.type.is_subseteq(types.PyClass(PyFunctionType)):
            raise ir.TypeCheckError(
                self,
                f"callee must be a method type, got {self.callee.type}",
                help="did you call a Python function directly? "
                "consider decorating it with kernel decorator",
            )

        if self.callee.type.is_subseteq(types.PyClass(PyClassMethodType)):
            raise ir.TypeCheckError(
                self,
                "callee must be a method type, got class method",
                help="did you try to call a Python class method within a kernel? "
                "consider rewriting it with a captured variable instead of calling it inside the kernel",
            )

        if self.callee.type is types.Any:
            return
        raise ir.TypeCheckError(
            self,
            f"callee must be a method type, got {self.callee.type}",
            help="did you forget to decorate the function with kernel decorator?",
        )

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
def print_impl(self, printer: Printer) -> None:
    with printer.rich(style="red"):
        printer.print_name(self)
    printer.plain_print(" ")
    printer.plain_print(printer.state.ssa_id[self.callee])

    printer.plain_print("(")
    printer.print_seq(self.inputs, delim=", ")
    if self.kwargs and self.inputs:
        printer.plain_print(", ")

    kwargs = dict(zip(self.keys, self.kwargs))
    printer.print_mapping(kwargs, delim=", ")
    printer.plain_print(")")

    with printer.rich(style="comment"):
        printer.plain_print(" : ")
        printer.print_seq(
            [result.type for result in self._results],
            delim=", ",
        )
        printer.plain_print(f" maybe_pure={self.purity}")

ConstantNone dataclass

ConstantNone(
    *,
    args: Sequence[SSAValue] = (),
    regions: Sequence[Region] = (),
    successors: Sequence[Block] = (),
    attributes: Mapping[str, Attribute] = {},
    results: Sequence[ResultValue] = (),
    result_types: Sequence[TypeAttribute] = (),
    args_slice: Mapping[str, int | slice] = {},
    source: SourceInfo | None = None
)

Bases: Statement



A constant None value.

This is mainly used to represent the None return value of a function to match Python semantics.


name class-attribute instance-attribute

name = 'const.none'

result class-attribute instance-attribute

result: ResultValue = result(NoneType)

traits class-attribute instance-attribute

traits = frozenset({Pure(), ConstantLike()})

FuncOpCallableInterface dataclass

FuncOpCallableInterface()

Bases: CallableStmtInterface['Function']



ValueType class-attribute instance-attribute

ValueType = TypeVar('ValueType')

align_input_args classmethod

align_input_args(
    stmt: Function, *args: ValueType, **kwargs: ValueType
) -> tuple[ValueType, ...]

Permute the arguments and keyword arguments of the statement to match the execution order of the callable region input.

Source code in src/kirin/dialects/func/stmts.py
@classmethod
def align_input_args(
    cls, stmt: Function, *args: ValueType, **kwargs: ValueType
) -> tuple[ValueType, ...]:
    inputs = [*args]
    for name in stmt.slots:
        if name in kwargs:
            inputs.append(kwargs[name])
    return tuple(inputs)
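The same alignment rule, restated as a standalone function for experimentation. This is a sketch, not the library method: it takes the slot names directly instead of a Function statement. Positional arguments keep their order, keyword arguments are appended in declaration (slot) order regardless of how the caller ordered them, and keyword names that are not slots are silently dropped, as in the method above.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def align_input_args(slots: Sequence[str], *args: T, **kwargs: T) -> tuple[T, ...]:
    inputs = [*args]  # positional arguments first, unchanged
    for name in slots:  # then keyword arguments in declaration order
        if name in kwargs:
            inputs.append(kwargs[name])
    return tuple(inputs)


print(align_input_args(("x", "y", "z"), 1, z=3, y=2))  # (1, 2, 3)
```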

get_callable_region classmethod

get_callable_region(stmt: 'Function') -> ir.Region

Returns the callable region of the statement, i.e. the function body.

Source code in src/kirin/dialects/func/stmts.py
@classmethod
def get_callable_region(cls, stmt: "Function") -> ir.Region:
    return stmt.body

Function dataclass

Function(
    *,
    args: Sequence[SSAValue] = (),
    regions: Sequence[Region] = (),
    successors: Sequence[Block] = (),
    attributes: Mapping[str, Attribute] = {},
    results: Sequence[ResultValue] = (),
    result_types: Sequence[TypeAttribute] = (),
    args_slice: Mapping[str, int | slice] = {},
    source: SourceInfo | None = None
)

Bases: Statement



body class-attribute instance-attribute

body: Region = region(multi=True)

The body of the function.

name class-attribute instance-attribute

name = 'func'

result class-attribute instance-attribute

result: ResultValue = result(MethodType)

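The result value of the statement, of type MethodType: it represents the function itself as a value, not the value the function returns.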

signature class-attribute instance-attribute

signature: Signature = attribute()

The signature of the function at declaration.

slots class-attribute instance-attribute

slots: tuple[str, ...] = attribute(default=())

The argument names of the function.

sym_name class-attribute instance-attribute

sym_name: str = attribute()

The symbol name of the function.

traits class-attribute instance-attribute

traits = frozenset(
    {
        IsolatedFromAbove(),
        SymbolOpInterface(),
        HasSignature(),
        FuncOpCallableInterface(),
        HasCFG(),
        SSACFG(),
    }
)

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
def print_impl(self, printer: Printer) -> None:
    with printer.rich(style="keyword"):
        printer.print_name(self)
        printer.plain_print(" ")

    with printer.rich(style="symbol"):
        printer.plain_print("@", self.sym_name)

    def print_arg(pair: tuple[str, types.TypeAttribute]):
        with printer.rich(style="symbol"):
            printer.plain_print(pair[0])
        with printer.rich(style="black"):
            printer.plain_print(" : ")
            printer.print(pair[1])

    printer.print_seq(
        zip(self.slots, self.signature.inputs),
        emit=print_arg,
        prefix="(",
        suffix=")",
        delim=", ",
    )

    with printer.rich(style="comment"):
        printer.plain_print(" -> ")
        printer.print(self.signature.output)
        printer.plain_print(" ")

    printer.print(self.body)

    with printer.rich(style="comment"):
        printer.plain_print(f" // func.func {self.sym_name}")

GetField dataclass

GetField(
    *,
    args: Sequence[SSAValue] = (),
    regions: Sequence[Region] = (),
    successors: Sequence[Block] = (),
    attributes: Mapping[str, Attribute] = {},
    results: Sequence[ResultValue] = (),
    result_types: Sequence[TypeAttribute] = (),
    args_slice: Mapping[str, int | slice] = {},
    source: SourceInfo | None = None
)

Bases: Statement



field class-attribute instance-attribute

field: int = attribute()

name class-attribute instance-attribute

name = 'getfield'

obj class-attribute instance-attribute

obj: SSAValue = argument(MethodType)

result class-attribute instance-attribute

result: ResultValue = result(init=False)

traits class-attribute instance-attribute

traits = frozenset({Pure()})

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
def print_impl(self, printer: Printer) -> None:
    printer.print_name(self)
    printer.plain_print(
        "(", printer.state.ssa_id[self.obj], ", ", str(self.field), ")"
    )
    with printer.rich(style="black"):
        printer.plain_print(" : ")
        printer.print(self.result.type)

Invoke dataclass

Invoke(
    *,
    args: Sequence[SSAValue] = (),
    regions: Sequence[Region] = (),
    successors: Sequence[Block] = (),
    attributes: Mapping[str, Attribute] = {},
    results: Sequence[ResultValue] = (),
    result_types: Sequence[TypeAttribute] = (),
    args_slice: Mapping[str, int | slice] = {},
    source: SourceInfo | None = None
)

Bases: Statement



callee class-attribute instance-attribute

callee: Method = attribute()

inputs class-attribute instance-attribute

inputs: tuple[SSAValue, ...] = argument()

name class-attribute instance-attribute

name = 'invoke'

purity class-attribute instance-attribute

purity: bool = attribute(default=False)

result class-attribute instance-attribute

result: ResultValue = result()

traits class-attribute instance-attribute

traits = frozenset({MaybePure(), InvokeCall()})

check

check() -> None

Check the statement. Raises Exception if the statement is not correct. This method is called by the verify method, which will detect the source of the error in the IR. One should always call the verify method to verify the IR.

The difference between check and check_type is that check verifies the structure of the IR and can be called by verify at any time, while check_type is called after type inference to verify the types of the IR.

Source code in src/kirin/dialects/func/stmts.py
def check(self) -> None:
    if self.callee.nargs - 1 != len(self.args):
        raise ValueError(
            f"expected {self.callee.nargs - 1} arguments, got {len(self.args)}"
        )
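Why nargs - 1: assuming, as the subtraction suggests, that ir.Method.nargs counts the implicit self argument described in the func.Function section, an invoke site supplies one fewer argument than the method declares. A toy reproduction of the check follows; ToyMethod and check_invoke are hypothetical names, not Kirin APIs.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyMethod:
    # Hypothetical stand-in for ir.Method; arg_names includes the implicit self.
    sym_name: str
    arg_names: tuple[str, ...]

    @property
    def nargs(self) -> int:
        return len(self.arg_names)


def check_invoke(callee: ToyMethod, args: tuple[Any, ...]) -> None:
    expected = callee.nargs - 1  # the caller never passes self explicitly
    if expected != len(args):
        raise ValueError(f"expected {expected} arguments, got {len(args)}")


main = ToyMethod("main", ("main_self", "x"))
check_invoke(main, ("%x",))  # ok: one explicit argument
try:
    check_invoke(main, ())
except ValueError as err:
    print(err)  # expected 1 arguments, got 0
```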

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
def print_impl(self, printer: Printer) -> None:
    with printer.rich(style="red"):
        printer.print_name(self)
    printer.plain_print(" ")
    printer.plain_print(self.callee.sym_name)

    printer.plain_print("(")
    printer.print_seq(self.inputs, delim=", ")
    printer.plain_print(")")

    with printer.rich(style="comment"):
        printer.plain_print(" : ")
        printer.print_seq(
            [result.type for result in self._results],
            delim=", ",
        )
        printer.plain_print(f" maybe_pure={self.purity}")
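Ignoring the rich styling, `print_impl` emits the statement name, the callee's symbol name, the parenthesized inputs, and a trailing comment with the result types and purity flag. A rough plain-text approximation follows; `format_invoke` is hypothetical, and the `"func.invoke"` name string in the usage is an assumption (the exact text produced by `printer.print_name` may differ).

```python
from typing import Sequence


def format_invoke(
    stmt_name: str,
    callee_name: str,
    inputs: Sequence[str],
    result_types: Sequence[str],
    purity: bool,
) -> str:
    """Plain-text approximation of Invoke.print_impl (styling omitted)."""
    # statement name, a space, then the callee's symbol name
    head = f"{stmt_name} {callee_name}"
    # inputs are printed as a comma-separated, parenthesized list
    args = "(" + ", ".join(inputs) + ")"
    # the trailing comment lists result types and the purity flag
    tail = " : " + ", ".join(result_types) + f" maybe_pure={purity}"
    return head + args + tail


print(format_invoke("func.invoke", "main", ["%x", "%y"], ["!Any"], False))
# func.invoke main(%x, %y) : !Any maybe_pure=False
```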

InvokeCall dataclass

InvokeCall()

Bases: StaticCall['Invoke']


Inheritance (subclass to base): InvokeCall → StaticCall (kirin.ir.traits.callable) → StmtTrait → Trait (kirin.ir.traits.abc)

get_callee classmethod

get_callee(stmt: Invoke) -> ir.Method

Returns the callee of the static call statement.

Source code in src/kirin/dialects/func/stmts.py
@classmethod
def get_callee(cls, stmt: Invoke) -> ir.Method:
    return stmt.callee
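Exposing the callee through a trait lets generic code ask "which method does this statement call?" without depending on `Invoke` itself. The standard-library sketch below illustrates that pattern under stated assumptions: `StaticCallLike`, `FakeInvoke`, `FakeInvokeCall` and `collect_callees` are hypothetical stand-ins, not Kirin's real `StaticCall` trait or any actual pass.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Sequence, Type, TypeVar

StmtT = TypeVar("StmtT")


class StaticCallLike(ABC, Generic[StmtT]):
    """Hypothetical stand-in for a StaticCall-style trait."""

    @classmethod
    @abstractmethod
    def get_callee(cls, stmt: StmtT) -> str: ...


@dataclass
class FakeInvoke:
    """Hypothetical stand-in for func.Invoke; the callee is just a name here."""

    callee: str
    inputs: tuple[str, ...]


class FakeInvokeCall(StaticCallLike[FakeInvoke]):
    @classmethod
    def get_callee(cls, stmt: FakeInvoke) -> str:
        # Same shape as InvokeCall.get_callee: read the callee attribute.
        return stmt.callee


def collect_callees(stmts: Sequence[FakeInvoke], trait: Type[StaticCallLike]) -> list:
    """Generic helper: gathers static callees via the trait only."""
    return [trait.get_callee(s) for s in stmts]


calls = [FakeInvoke("foo", ("%x",)), FakeInvoke("bar", ())]
print(collect_callees(calls, FakeInvokeCall))  # ['foo', 'bar']
```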

Lambda dataclass

Lambda(
    *,
    args: Sequence[SSAValue] = (),
    regions: Sequence[Region] = (),
    successors: Sequence[Block] = (),
    attributes: Mapping[str, Attribute] = {},
    results: Sequence[ResultValue] = (),
    result_types: Sequence[TypeAttribute] = (),
    args_slice: Mapping[str, int | slice] = {},
    source: SourceInfo | None = None
)

Bases: Statement


Inheritance (subclass to base): Lambda → Statement (kirin.ir.nodes.stmt) → IRNode (kirin.ir.nodes.base) → Printable (kirin.print.printable)

Source code in src/kirin/ir/nodes/stmt.py
def __init__(
    self,
    *,
    args: Sequence[SSAValue] = (),
    regions: Sequence[Region] = (),
    successors: Sequence[Block] = (),
    attributes: Mapping[str, Attribute] = {},
    results: Sequence[ResultValue] = (),
    result_types: Sequence[TypeAttribute] = (),
    args_slice: Mapping[str, int | slice] = {},
    source: SourceInfo | None = None,
) -> None:
    """Initialize the Statement.

    Args:
        args (Sequence[SSAValue], optional): The arguments of the Statement. Defaults to ().
        regions (Sequence[Region], optional): The regions owned by the Statement. Defaults to ().
        successors (Sequence[Block], optional): The successor blocks of the Statement. Defaults to ().
        attributes (Mapping[str, Attribute], optional): The attributes of the Statement. Defaults to {}.
        results (Sequence[ResultValue], optional): The result values of the Statement. Defaults to ().
        result_types (Sequence[TypeAttribute], optional): The result types of the Statement. Defaults to ().
        args_slice (Mapping[str, int | slice], optional): Maps argument field names to their index or slice in args. Defaults to {}.
        source (SourceInfo | None, optional): The source information of the Statement for debugging/stack traces. Defaults to None.
    """
    super().__init__()
    self._args = ()
    self._regions = []
    self._name_args_slice = dict(args_slice)
    self.source = source
    self.args = args

    if results:
        self._results = list(results)
        assert (
            len(result_types) == 0
        ), "expect either results or result_types specified, got both"

    if result_types:
        self._results = [
            ResultValue(self, idx, type=type)
            for idx, type in enumerate(result_types)
        ]

    if not results and not result_types:
        self._results = list(results)

    self.successors = list(successors)
    self.attributes = dict(attributes)
    self.regions = list(regions)

    self.parent = None
    self._next_stmt = None
    self._prev_stmt = None
    self.__post_init__()
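The constructor accepts either ready-made `results` or a list of `result_types`, never both: pre-built result values are used as-is, otherwise one fresh result value is created per declared type, and with neither the statement has no results. A minimal sketch of just that branch, assuming a hypothetical `FakeResult` stand-in for `ir.ResultValue` and plain strings for types:

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class FakeResult:
    """Hypothetical stand-in for ir.ResultValue."""

    owner: object
    index: int
    type: str


def init_results(
    owner: object,
    results: Sequence[FakeResult] = (),
    result_types: Sequence[str] = (),
) -> list:
    """Mirror of the result-handling branch of Statement.__init__."""
    if results:
        # Pre-built result values are used as-is; result_types must then be empty.
        assert len(result_types) == 0, "expect either results or result_types specified, got both"
        return list(results)
    if result_types:
        # One fresh result value per declared type, indexed in order.
        return [FakeResult(owner, idx, type=t) for idx, t in enumerate(result_types)]
    # Neither given: the statement has no results.
    return []


rs = init_results("stmt", result_types=["!Any", "!Int"])
print([(r.index, r.type) for r in rs])  # [(0, '!Any'), (1, '!Int')]
```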

body class-attribute instance-attribute

body: Region = region(multi=True)

captured class-attribute instance-attribute

captured: tuple[SSAValue, ...] = argument()

name class-attribute instance-attribute

name = 'lambda'

result class-attribute instance-attribute

result: ResultValue = result(MethodType)

signature class-attribute instance-attribute

signature: Signature = attribute()

The signature of the function at declaration.

slots class-attribute instance-attribute

slots: tuple[str, ...] = attribute(default=())

The argument names of the function.

sym_name class-attribute instance-attribute

sym_name: str = attribute()

traits class-attribute instance-attribute

traits = frozenset(
    {
        Pure(),
        HasSignature(),
        SymbolOpInterface(),
        FuncOpCallableInterface(),
        HasCFG(),
        SSACFG(),
    }
)

check

check() -> None

Check the statement's structure and raise an exception if it is malformed. This method is called by the verify method, which also reports where in the IR the error originates. Always use verify to verify the IR.

The difference between check and check_type is that check can run at any time (via verify) to validate the structure of the IR, whereas check_type runs after type inference to validate the types in the IR.

Source code in src/kirin/dialects/func/stmts.py
def check(self) -> None:
    assert self.body.blocks, "lambda body must not be empty"

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
def print_impl(self, printer: Printer) -> None:
    with printer.rich(style="keyword"):
        printer.print_name(self)
    printer.plain_print(" ")

    with printer.rich(style="symbol"):
        printer.plain_print(self.sym_name)

    printer.print_seq(self.captured, prefix="(", suffix=")", delim=", ")

    with printer.rich(style="bright_black"):
        printer.plain_print(" -> ")
        printer.print(self.signature.output)

    printer.plain_print(" ")
    printer.print(self.body)

    with printer.rich(style="black"):
        printer.plain_print(f" // func.lambda {self.sym_name}")
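A `func.Lambda` separates the values a closure closes over (`captured`) from its own parameter names (`slots`). Plain Python exposes a similar split on code objects, which may help build intuition. This is only an analogy using CPython's `co_freevars`, `co_varnames` and `__closure__`; it is not how Kirin lowers closures.

```python
def make_adder(x):
    # `inner` closes over `x`; by analogy, a func.Lambda for `inner`
    # would list `x` among its captured values.
    def inner(y):
        return x + y

    return inner


add3 = make_adder(3)
code = add3.__code__
# Names captured from the enclosing scope (analogous to `captured`).
print(code.co_freevars)  # ('x',)
# Names of the closure's own parameters (analogous to `slots`).
print(code.co_varnames[: code.co_argcount])  # ('y',)
# The captured value itself lives in a closure cell.
print(add3.__closure__[0].cell_contents)  # 3
```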

Return

Return(value_or_stmt: SSAValue | Statement | None = None)

Bases: Statement


Inheritance (subclass to base): Return → Statement (kirin.ir.nodes.stmt) → IRNode (kirin.ir.nodes.base) → Printable (kirin.print.printable)

Source code in src/kirin/dialects/func/stmts.py
def __init__(self, value_or_stmt: ir.SSAValue | ir.Statement | None = None) -> None:
    if isinstance(value_or_stmt, ir.SSAValue):
        args = [value_or_stmt]
    elif isinstance(value_or_stmt, ir.Statement):
        if len(value_or_stmt._results) == 1:
            args = [value_or_stmt._results[0]]
        else:
            raise ValueError(
                f"expected a single result, got {len(value_or_stmt._results)} results from {value_or_stmt.name}"
            )
    elif value_or_stmt is None:
        args = []
    else:
        raise ValueError(f"expected SSAValue or Statement, got {value_or_stmt}")

    super().__init__(args=args, args_slice={"value": 0})
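The constructor normalizes its single argument into the statement's argument list: an SSA value is used directly, a statement is unwrapped to its only result, and `None` yields no arguments. The sketch below reproduces that logic with hypothetical `FakeValue` and `FakeStmt` stand-ins for `ir.SSAValue` and `ir.Statement`.

```python
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class FakeValue:
    """Hypothetical stand-in for ir.SSAValue."""

    name: str


@dataclass
class FakeStmt:
    """Hypothetical stand-in for ir.Statement with a list of results."""

    name: str
    results: List[FakeValue] = field(default_factory=list)


def return_args(value_or_stmt: Union[FakeValue, FakeStmt, None] = None) -> List[FakeValue]:
    """Mirror of the argument normalization in Return.__init__."""
    if isinstance(value_or_stmt, FakeValue):
        return [value_or_stmt]
    if isinstance(value_or_stmt, FakeStmt):
        # A statement is accepted only if it produces exactly one result.
        if len(value_or_stmt.results) == 1:
            return [value_or_stmt.results[0]]
        raise ValueError(
            f"expected a single result, got {len(value_or_stmt.results)} results from {value_or_stmt.name}"
        )
    if value_or_stmt is None:
        return []
    raise ValueError(f"expected SSAValue or Statement, got {value_or_stmt}")


add = FakeStmt("add", [FakeValue("%y")])
print(return_args(add))  # [FakeValue(name='%y')]
```

Note that `None` is accepted at construction time but produces an empty argument list, which `Return.check` then rejects ("return statement must have at least one value").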

name class-attribute instance-attribute

name = 'return'

traits class-attribute instance-attribute

traits = frozenset({IsTerminator(), HasParent((Function,))})

value class-attribute instance-attribute

value: SSAValue = argument()

check

check() -> None

Check the statement's structure and raise an exception if it is malformed. This method is called by the verify method, which also reports where in the IR the error originates. Always use verify to verify the IR.

The difference between check and check_type is that check can run at any time (via verify) to validate the structure of the IR, whereas check_type runs after type inference to validate the types in the IR.

Source code in src/kirin/dialects/func/stmts.py
def check(self) -> None:
    assert self.args, "return statement must have at least one value"
    assert len(self.args) <= 1, (
        "return statement must have at most one value"
        ", wrap multiple values in a tuple"
    )
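`check` insists on exactly one returned value because Python itself has no multi-value return: `return a, b` builds a single tuple object. The sketch mirrors the two assertions on a plain list of placeholder names and shows the tuple packing in ordinary Python. `check_return_args` and `min_max` are illustrative only; how Kirin's frontend constructs the tuple in the IR is not shown here.

```python
from typing import Sequence


def check_return_args(args: Sequence[str]) -> None:
    """Mirror of Return.check on plain placeholder names."""
    assert args, "return statement must have at least one value"
    assert len(args) <= 1, (
        "return statement must have at most one value"
        ", wrap multiple values in a tuple"
    )


def min_max(xs):
    # `return a, b` builds ONE tuple object, so a single returned
    # value is enough to model "multiple" results.
    return min(xs), max(xs)


result = min_max([3, 1, 4])
print(type(result).__name__, result)  # tuple (1, 4)
check_return_args(["%tuple"])  # one tuple-valued SSA value: accepted
```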

print_impl

print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
def print_impl(self, printer: Printer) -> None:
    with printer.rich(style="keyword"):
        printer.print_name(self)

    if self.args:
        printer.plain_print(" ")
        printer.print_seq(self.args, delim=", ")