Warning
This page is under construction. The content may be incomplete or incorrect. Submit an issue on GitHub if you need help or want to contribute.
The Function Dialect
The function dialect provides a set of statements that model the semantics of Python-like functions, including:
- def <name>(<args>*, <kwargs>*)-like function declarations
- nested functions (namely closures)
- higher-order functions (functions can be used as arguments)
- dynamically/statically calling a function or closure
func.Return
This is a simple statement that models the return statement in a function declaration. Although simple, it is worth noting that it accepts only one argument of type ir.SSAValue: in Python (and most other languages) a function always has a single return value, and multiple return values are represented by returning a tuple.
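The single-operand design follows directly from Python semantics: a function that appears to return several values actually returns one tuple. A minimal plain-Python illustration (the function name is made up for the example):

```python
def div_and_mod(a: int, b: int) -> tuple[int, int]:
    # Looks like two return values, but Python packs them into ONE tuple
    return a // b, a % b


result = div_and_mod(17, 5)
quotient, remainder = result  # unpacking happens at the call site, not in the return
print(type(result).__name__, result)  # prints: tuple (3, 2)
```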
func.Function
This is the most fundamental statement that models a Python function.
Definition
The func.Function statement takes no arguments. It carries a str attribute (sym_name, stored as a PyAttr) that serves as a symbolic reference within a symbol table, and a func.Signature attribute that stores the signature of the corresponding function declaration. Finally, it contains an ir.Region representing the function body. The region follows the SSACFG convention: its blocks form a control flow graph.
Differences with MLIR
Because Kirin's priority is writing eDSLs as kernel functions in Python, it needs to support higher-order functions. To do so, the first argument of the entry block is always self, of type types.MethodType (kirin.types.MethodType). This design is inspired by Julia's IR.
As an example, the following Python function
from kirin.prelude import basic_no_opt

@basic_no_opt
def main(x):
    return x
is lowered into the following IR, where %main_self refers to the function itself:
func.func main(!Any) -> !Any {
^0(%main_self, %x):
│ func.return %x
} // func.func main
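To see why binding self as the first entry-block argument helps, here is a plain-Python toy model. ToyMethod, factorial_body and apply_twice are hypothetical names, not Kirin API; the sketch only mirrors the convention that a body receives its own method object first (like %main_self above), so it can recurse or be passed to other functions without looking itself up by name.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ToyMethod:
    """Hypothetical stand-in for a method: a name plus a body whose
    first parameter is the method object itself (like %main_self)."""
    sym_name: str
    body: Callable[..., Any]

    def __call__(self, *args: Any) -> Any:
        # Mirror the entry-block convention: bind self as the first argument
        return self.body(self, *args)


def factorial_body(self: ToyMethod, n: int) -> int:
    # Recursion goes through `self`, not through a global symbol lookup
    return 1 if n <= 1 else n * self(n - 1)


def apply_twice(f: Callable[[int], int], x: int) -> int:
    # A higher-order function: the method object is an ordinary value
    return f(f(x))


factorial = ToyMethod("factorial", factorial_body)
inc = ToyMethod("inc", lambda self, x: x + 1)
print(factorial(5))         # prints: 120
print(apply_twice(inc, 3))  # prints: 5
```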
The function body is terminated by a func.Return statement, and every block in the function region must have a terminator. During lowering, if a block is not terminated, a func.Return that returns None is appended to it. Because func.Return accepts exactly one value, a func.Function always has a single return value.
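The lowering rule above can be sketched in plain Python. This is a toy model, not Kirin's lowering code: Stmt, Block, ensure_terminated and the placeholder statement name py.add are hypothetical, and producing the None value with a const.none statement (func.ConstantNone in the API reference below) is an assumption about how that value is materialized.

```python
from dataclasses import dataclass, field


@dataclass
class Stmt:
    name: str
    is_terminator: bool = False


@dataclass
class Block:
    stmts: list[Stmt] = field(default_factory=list)


def ensure_terminated(blocks: list[Block]) -> None:
    """Toy version of the rule: append `return None` to every block
    whose last statement is not a terminator."""
    for block in blocks:
        if not block.stmts or not block.stmts[-1].is_terminator:
            block.stmts.append(Stmt("const.none"))  # materialize None (assumed)
            block.stmts.append(Stmt("func.return", is_terminator=True))


blocks = [
    Block([Stmt("py.add")]),                            # falls off the end
    Block([Stmt("func.return", is_terminator=True)]),   # already terminated
]
ensure_terminated(blocks)
print([[s.name for s in b.stmts] for b in blocks])
# prints: [['py.add', 'const.none', 'func.return'], ['func.return']]
```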
func.Call and func.Invoke
These two statements model the most common calling conventions in Python, with compilation in mind:
- func.Call models dynamic calls, where the callee is unknown at compile time and is therefore an ir.SSAValue.
- func.Invoke models static calls, where the callee is known at compile time and is therefore an ir.Method.
Both take inputs, a tuple of ir.SSAValue, as arguments. Because we assume that every function returns a single value, func.Call and func.Invoke each have a single result.
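The distinction can be illustrated with a toy interpreter in plain Python. ToyCall, ToyInvoke and run are hypothetical and only model the idea: an invoke-style statement stores the callee itself, while a call-style statement stores the name of a value that holds the callee and must be resolved at run time. In both cases there is exactly one result.

```python
from dataclasses import dataclass
from typing import Any, Callable, Union


@dataclass
class ToyCall:
    callee: str              # name of an SSA value that holds the callee
    inputs: tuple[str, ...]  # SSA value names for the arguments


@dataclass
class ToyInvoke:
    callee: Callable[..., Any]  # callee fixed when the IR is built
    inputs: tuple[str, ...]


def run(stmt: Union[ToyCall, ToyInvoke], env: dict[str, Any]) -> Any:
    args = [env[name] for name in stmt.inputs]
    if isinstance(stmt, ToyInvoke):
        fn = stmt.callee       # static: nothing left to resolve
    else:
        fn = env[stmt.callee]  # dynamic: only known from the environment
    return fn(*args)           # a single result either way


env = {"f": lambda v: v * 10, "x": -4}
print(run(ToyCall("f", ("x",)), env))    # prints: -40
print(run(ToyInvoke(abs, ("x",)), env))  # prints: 4
```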
func.Lambda
This statement models nested functions (a.k.a. closures). Most of its definition mirrors func.Function; the key difference is that func.Lambda takes a tuple of ir.SSAValue arguments, captured, which models the variables captured by the nested function. For example, the following Python function contains a closure that captures the variable x:
from kirin.prelude import basic_no_opt

@basic_no_opt
def main(x):
    def closure():
        return x
    return closure
is lowered into:
func.func main(!Any) -> !Any {
^0(%main_self, %x):
│ %closure = func.lambda closure(%x) -> !Any {
│ │ ^1(%closure_self):
│ │ │ %x_1 = func.getfield(%closure_self, 0) : !Any
│ │ │ func.return %x_1
│ } // func.lambda closure
│ func.return %closure
} // func.func main
This statement has a result value (%closure above) that points to the closure itself. Inside the closure body, func.GetField (kirin.dialects.func.GetField) statements are inserted to unpack the captured variables.
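The runtime behaviour of the lowered IR above can be mimicked in plain Python. ToyClosure and getfield are hypothetical stand-ins, not Kirin API: the closure object stores its captured values in a tuple, its body receives the closure itself first (like %closure_self), and getfield(self, 0) plays the role of func.getfield(%closure_self, 0).

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ToyClosure:
    captured: tuple[Any, ...]  # values captured when the lambda is created
    body: Callable[..., Any]   # body takes the closure itself as first argument

    def __call__(self, *args: Any) -> Any:
        return self.body(self, *args)


def getfield(obj: ToyClosure, field: int) -> Any:
    # Toy counterpart of func.getfield: read the field-th captured value
    return obj.captured[field]


def main(x: Any) -> ToyClosure:
    # %closure = func.lambda closure(%x) { %x_1 = func.getfield(%closure_self, 0) ... }
    return ToyClosure(captured=(x,), body=lambda closure_self: getfield(closure_self, 0))


closure = main(42)
print(closure())  # prints: 42
```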
API Reference
Call dataclass
Call(
*,
args: Sequence[SSAValue] = (),
regions: Sequence[Region] = (),
successors: Sequence[Block] = (),
attributes: Mapping[str, Attribute] = {},
results: Sequence[ResultValue] = (),
result_types: Sequence[TypeAttribute] = (),
args_slice: Mapping[str, int | slice] = {},
source: SourceInfo | None = None
)
Bases: Statement
Source code in src/kirin/ir/nodes/stmt.py
callee class-attribute instance-attribute
callee: SSAValue = argument()
inputs class-attribute instance-attribute
inputs: tuple[SSAValue, ...] = argument()
keys class-attribute instance-attribute
keys: tuple[str, ...] = attribute(default=())
kwargs class-attribute instance-attribute
kwargs: tuple[SSAValue, ...] = argument()
name class-attribute instance-attribute
name = 'call'
purity class-attribute instance-attribute
purity: bool = attribute(default=False)
result class-attribute instance-attribute
result: ResultValue = result()
traits class-attribute instance-attribute
traits = frozenset({MaybePure()})
check_type
check_type() -> None
Check the types of the statement. Raises an Exception if the types are not correct. This method is called by the verify_type method, which detects the source of the error in the IR. Always call the verify_type method to verify the types of the IR.
Note
This method is generated by the @statement decorator but can be overridden if needed.
Source code in src/kirin/dialects/func/stmts.py
print_impl
print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
ConstantNone dataclass
ConstantNone(
*,
args: Sequence[SSAValue] = (),
regions: Sequence[Region] = (),
successors: Sequence[Block] = (),
attributes: Mapping[str, Attribute] = {},
results: Sequence[ResultValue] = (),
result_types: Sequence[TypeAttribute] = (),
args_slice: Mapping[str, int | slice] = {},
source: SourceInfo | None = None
)
Bases: Statement
A constant None value.
This is mainly used to represent the None return value of a function to match Python semantics.
Source code in src/kirin/ir/nodes/stmt.py
name class-attribute instance-attribute
name = 'const.none'
result class-attribute instance-attribute
result: ResultValue = result(NoneType)
traits class-attribute instance-attribute
traits = frozenset({Pure(), ConstantLike()})
FuncOpCallableInterface dataclass
FuncOpCallableInterface()
Bases: CallableStmtInterface['Function']
ValueType class-attribute instance-attribute
ValueType = TypeVar('ValueType')
align_input_args classmethod
align_input_args(
stmt: Function, *args: ValueType, **kwargs: ValueType
) -> tuple[ValueType, ...]
Permute the arguments and keyword arguments of the statement to match the execution order of the callable region input.
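A minimal sketch of this kind of permutation, assuming the callable's argument names are available in order (as in the slots attribute of func.Function, excluding self). This is not Kirin's implementation: the standalone align_input_args below is hypothetical and simply places positional arguments first, then fills the remaining slots from keyword arguments by name.

```python
from typing import Any


def align_input_args(slots: tuple[str, ...], *args: Any, **kwargs: Any) -> tuple[Any, ...]:
    """Hypothetical sketch: order positional + keyword arguments by slot order."""
    if len(args) > len(slots):
        raise TypeError("too many positional arguments")
    remaining = slots[len(args):]  # slots not covered by positional args
    missing = [name for name in remaining if name not in kwargs]
    if missing:
        raise TypeError(f"missing arguments: {missing}")
    extra = set(kwargs) - set(remaining)
    if extra:
        raise TypeError(f"unexpected keyword arguments: {sorted(extra)}")
    # Keyword values are emitted in slot order, not in call-site order
    return (*args, *(kwargs[name] for name in remaining))


print(align_input_args(("x", "y", "z"), 1, z=3, y=2))  # prints: (1, 2, 3)
```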
Source code in src/kirin/dialects/func/stmts.py
get_callable_region classmethod
get_callable_region(stmt: 'Function') -> ir.Region
Returns the body of the callable region
Source code in src/kirin/dialects/func/stmts.py
Function dataclass
Function(
*,
args: Sequence[SSAValue] = (),
regions: Sequence[Region] = (),
successors: Sequence[Block] = (),
attributes: Mapping[str, Attribute] = {},
results: Sequence[ResultValue] = (),
result_types: Sequence[TypeAttribute] = (),
args_slice: Mapping[str, int | slice] = {},
source: SourceInfo | None = None
)
Bases: Statement
Source code in src/kirin/ir/nodes/stmt.py
body class-attribute instance-attribute
body: Region = region(multi=True)
The body of the function.
name class-attribute instance-attribute
name = 'func'
result class-attribute instance-attribute
result: ResultValue = result(MethodType)
The result of the function.
signature class-attribute instance-attribute
signature: Signature = attribute()
The signature of the function at declaration.
slots class-attribute instance-attribute
slots: tuple[str, ...] = attribute(default=())
The argument names of the function.
sym_name class-attribute instance-attribute
sym_name: str = attribute()
The symbol name of the function.
traits class-attribute instance-attribute
traits = frozenset(
{
IsolatedFromAbove(),
SymbolOpInterface(),
HasSignature(),
FuncOpCallableInterface(),
HasCFG(),
SSACFG(),
}
)
print_impl
print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
GetField dataclass
GetField(
*,
args: Sequence[SSAValue] = (),
regions: Sequence[Region] = (),
successors: Sequence[Block] = (),
attributes: Mapping[str, Attribute] = {},
results: Sequence[ResultValue] = (),
result_types: Sequence[TypeAttribute] = (),
args_slice: Mapping[str, int | slice] = {},
source: SourceInfo | None = None
)
Bases: Statement
Source code in src/kirin/ir/nodes/stmt.py
field class-attribute instance-attribute
field: int = attribute()
name class-attribute instance-attribute
name = 'getfield'
obj class-attribute instance-attribute
obj: SSAValue = argument(MethodType)
result class-attribute instance-attribute
result: ResultValue = result(init=False)
traits class-attribute instance-attribute
traits = frozenset({Pure()})
print_impl
print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
Invoke dataclass
Invoke(
*,
args: Sequence[SSAValue] = (),
regions: Sequence[Region] = (),
successors: Sequence[Block] = (),
attributes: Mapping[str, Attribute] = {},
results: Sequence[ResultValue] = (),
result_types: Sequence[TypeAttribute] = (),
args_slice: Mapping[str, int | slice] = {},
source: SourceInfo | None = None
)
Bases: Statement
Source code in src/kirin/ir/nodes/stmt.py
callee class-attribute instance-attribute
callee: Method = attribute()
inputs class-attribute instance-attribute
inputs: tuple[SSAValue, ...] = argument()
name class-attribute instance-attribute
name = 'invoke'
purity class-attribute instance-attribute
purity: bool = attribute(default=False)
result class-attribute instance-attribute
result: ResultValue = result()
traits class-attribute instance-attribute
traits = frozenset({MaybePure(), InvokeCall()})
check
check() -> None
Check the statement. Raises an Exception if the statement is not correct. This method is called by the verify method, which detects the source of the error in the IR. Always call the verify method to verify the IR.
The difference between check and check_type is that verify may call check at any time to check the structure of the IR, while check_type is called after type inference to check the types of the IR.
Source code in src/kirin/dialects/func/stmts.py
print_impl
print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
InvokeCall dataclass
InvokeCall()
Bases: StaticCall['Invoke']
get_callee classmethod
get_callee(stmt: Invoke) -> ir.Method
Returns the callee of the static call statement.
Source code in src/kirin/dialects/func/stmts.py
Lambda dataclass
Lambda(
*,
args: Sequence[SSAValue] = (),
regions: Sequence[Region] = (),
successors: Sequence[Block] = (),
attributes: Mapping[str, Attribute] = {},
results: Sequence[ResultValue] = (),
result_types: Sequence[TypeAttribute] = (),
args_slice: Mapping[str, int | slice] = {},
source: SourceInfo | None = None
)
Bases: Statement
Source code in src/kirin/ir/nodes/stmt.py
body class-attribute instance-attribute
body: Region = region(multi=True)
captured class-attribute instance-attribute
captured: tuple[SSAValue, ...] = argument()
name class-attribute instance-attribute
name = 'lambda'
result class-attribute instance-attribute
result: ResultValue = result(MethodType)
signature class-attribute instance-attribute
signature: Signature = attribute()
The signature of the function at declaration.
slots class-attribute instance-attribute
slots: tuple[str, ...] = attribute(default=())
The argument names of the function.
sym_name class-attribute instance-attribute
sym_name: str = attribute()
traits class-attribute instance-attribute
traits = frozenset(
{
Pure(),
HasSignature(),
SymbolOpInterface(),
FuncOpCallableInterface(),
HasCFG(),
SSACFG(),
}
)
check
check() -> None
Check the statement. Raises an Exception if the statement is not correct. This method is called by the verify method, which detects the source of the error in the IR. Always call the verify method to verify the IR.
The difference between check and check_type is that verify may call check at any time to check the structure of the IR, while check_type is called after type inference to check the types of the IR.
Source code in src/kirin/dialects/func/stmts.py
print_impl
print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py
Return
Return(value_or_stmt: SSAValue | Statement | None = None)
Bases: Statement
Source code in src/kirin/dialects/func/stmts.py
name class-attribute instance-attribute
name = 'return'
traits class-attribute instance-attribute
traits = frozenset({IsTerminator(), HasParent((Function,))})
value class-attribute instance-attribute
value: SSAValue = argument()
check
check() -> None
Check the statement. Raises an Exception if the statement is not correct. This method is called by the verify method, which detects the source of the error in the IR. Always call the verify method to verify the IR.
The difference between check and check_type is that verify may call check at any time to check the structure of the IR, while check_type is called after type inference to check the types of the IR.
Source code in src/kirin/dialects/func/stmts.py
print_impl
print_impl(printer: Printer) -> None
Source code in src/kirin/dialects/func/stmts.py