Skip to content

Dialect Group

DialectGroup dataclass

DialectGroup(
    dialects: Iterable[Union[Dialect, ModuleType]],
    run_pass: RunPassGen[PassParams] | None = None,
)

Bases: Generic[PassParams]

Source code in src/kirin/ir/group.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def __init__(
    self,
    dialects: Iterable[Union["Dialect", ModuleType]],
    run_pass: RunPassGen[PassParams] | None = None,
):
    """Create a dialect group.

    Args:
        dialects (Iterable[Union[Dialect, ModuleType]]): the dialects, or
            modules exposing a `dialect` attribute, to include in the group.
        run_pass (RunPassGen[PassParams] | None): optional generator that is
            called with this group and returns the function used to run
            passes on methods created by the group.
    """
    # Resolve modules to their dialect objects and freeze the membership.
    self.data = frozenset(self.map_module(dialect) for dialect in dialects)
    if run_pass is None:
        self.run_pass_gen = None
        self.run_pass = None
    else:
        # Keep the generator itself so derived groups (`union`, `discard`)
        # can build a `run_pass` bound to the new group.
        self.run_pass_gen = run_pass
        self.run_pass = run_pass(self)

data instance-attribute

data: frozenset[Dialect] = frozenset(
    map_module(dialect) for dialect in dialects
)

The set of dialects in the group.

registry property

registry: Registry

Return the registry for the dialect group. This returns a proxy object that can be used to select the lowering interpreters, interpreters, and codegen for the dialects in the group.

Returns:

Name Type Description
Registry Registry

the registry object.

run_pass class-attribute instance-attribute

run_pass: RunPass[PassParams] | None = None

the function that runs the passes on the method.

run_pass_gen class-attribute instance-attribute

run_pass_gen: RunPassGen[PassParams] | None = None

the function that generates the run_pass function.

This is used to create new dialect groups from existing ones, while keeping the same run_pass function.

__call__

__call__(
    py_func: Callable[Param, RetType],
    *args: PassParams.args,
    **options: PassParams.kwargs
) -> Method[Param, RetType]
__call__(
    py_func: None = None,
    *args: PassParams.args,
    **options: PassParams.kwargs
) -> MethodTransform[Param, RetType]
__call__(
    py_func: Callable[Param, RetType] | None = None,
    *args: PassParams.args,
    **options: PassParams.kwargs
) -> (
    Method[Param, RetType] | MethodTransform[Param, RetType]
)

Create a method from the Python function.

Parameters:

Name Type Description Default
py_func Callable

the python function to create the method from.

None
args args

the arguments to pass to the run_pass function.

()
options kwargs

the keyword arguments to pass to the run_pass function.

{}

Returns:

Name Type Description
Method Method[Param, RetType] | MethodTransform[Param, RetType]

the method created from the python function.

Source code in src/kirin/ir/group.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def __call__(
    self,
    py_func: Callable[Param, RetType] | None = None,
    *args: PassParams.args,
    **options: PassParams.kwargs,
) -> Method[Param, RetType] | MethodTransform[Param, RetType]:
    """create a method from the python function.

    The group can be applied directly as a decorator (`@group`), or called
    with pass arguments first (`@group(*args, **options)`), in which case a
    decorator is returned.

    Args:
        py_func (Callable): the python function to create the method from.
        args (PassParams.args): the arguments to pass to the run_pass function.
        options (PassParams.kwargs): the keyword arguments to pass to the run_pass function.

    Returns:
        Method: the method created from the python function, or a decorator
            that creates it when `py_func` is `None`.

    Raises:
        ValueError: if `py_func` is a lambda.
        CompilerError: if the decoration site already has a local variable
            with the same name as `py_func`.
    """
    from kirin.lowering import Lowering

    emit_ir = Lowering(self)

    def compile_method(py_func: Callable, call_site_frame) -> Method:
        # Lower `py_func` into a Method. `call_site_frame` is the frame in
        # which the decorator was applied, or None if it is unknown.
        if py_func.__name__ == "<lambda>":
            raise ValueError("Cannot compile lambda functions")

        lineno_offset, file = 0, ""
        if call_site_frame is not None:
            if py_func.__name__ in call_site_frame.f_locals:
                raise CompilerError(
                    f"overwriting function definition of `{py_func.__name__}`"
                )

            lineno_offset = call_site_frame.f_lineno - 1
            file = call_site_frame.f_code.co_filename

        code = emit_ir.run(py_func, lineno_offset=lineno_offset)
        mt = Method(
            mod=inspect.getmodule(py_func),
            py_func=py_func,
            sym_name=py_func.__name__,
            arg_names=["#self#"] + inspect.getfullargspec(py_func).args,
            dialects=self,
            code=code,
            file=file,
        )
        if doc := inspect.getdoc(py_func):
            mt.__doc__ = doc

        if self.run_pass is not None:
            self.run_pass(mt, *args, **options)
        return mt

    def wrapper(py_func: Callable) -> Method:
        # Returned for the `@group(...)` form: Python calls this directly at
        # the decoration site, so the call site is our immediate caller.
        frame = inspect.currentframe()
        try:
            call_site = frame.f_back if frame is not None else None
            return compile_method(py_func, call_site)
        finally:
            # avoid keeping a reference cycle through our own frame
            del frame

    if py_func is None:
        return wrapper

    # Bare `@group` form: the decoration site is the caller of `__call__`.
    frame = inspect.currentframe()
    try:
        call_site = frame.f_back if frame is not None else None
        return compile_method(py_func, call_site)
    finally:
        del frame

__contains__

__contains__(dialect) -> bool

check if the dialect is in the group.

Parameters:

Name Type Description Default
dialect Union[Dialect, ModuleType]

the dialect to check.

required

Returns:

Name Type Description
bool bool

True if the dialect is in the group, False otherwise.

Source code in src/kirin/ir/group.py
126
127
128
129
130
131
132
133
134
135
def __contains__(self, dialect) -> bool:
    """Report whether a dialect is a member of this group.

    Args:
        dialect (Union[Dialect, ModuleType]): the dialect, or a module
            exposing a `dialect` attribute, to look up.

    Returns:
        bool: True if the (module-resolved) dialect is in the group,
            False otherwise.
    """
    resolved = self.map_module(dialect)
    members = self.data
    return resolved in members

add

add(dialect: Union[Dialect, ModuleType]) -> DialectGroup

add a dialect to the group.

Parameters:

Name Type Description Default
dialect Union[Dialect, ModuleType]

the dialect to add

required

Returns:

Name Type Description
DialectGroup DialectGroup

the new dialect group with the added dialect.

Source code in src/kirin/ir/group.py
81
82
83
84
85
86
87
88
89
90
def add(self, dialect: Union["Dialect", ModuleType]) -> "DialectGroup":
    """Return a new group that also contains `dialect`.

    This is `union` applied to a single dialect; the current group is
    left unchanged.

    Args:
        dialect (Union[Dialect, ModuleType]): the dialect, or a module
            exposing a `dialect` attribute, to add.

    Returns:
        DialectGroup: the new dialect group with the added dialect.
    """
    singleton = [dialect]
    return self.union(singleton)

discard

discard(
    dialect: Union[Dialect, ModuleType]
) -> DialectGroup

discard a dialect from the group.

Note

This does not raise an error if the dialect is not in the group.

Parameters:

Name Type Description Default
dialect Union[Dialect, ModuleType]

the dialect to discard

required

Returns:

Name Type Description
DialectGroup DialectGroup

the new dialect group with the discarded dialect.

Source code in src/kirin/ir/group.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def discard(self, dialect: Union["Dialect", ModuleType]) -> "DialectGroup":
    """Return a new group without the given dialect.

    !!! note
        Members are matched by their `name`. Discarding a dialect that is
        not in the group does not raise an error.

    Args:
        dialect (Union[Dialect, ModuleType]): the dialect, or a module
            exposing a `dialect` attribute, to discard.

    Returns:
        DialectGroup: the new dialect group with the discarded dialect.
    """
    target = self.map_module(dialect)
    remaining = frozenset(
        member for member in self.data if member.name != target.name
    )
    # hand over the pass generator so the new group builds its own run_pass
    return DialectGroup(dialects=remaining, run_pass=self.run_pass_gen)

map_module staticmethod

map_module(dialect: Union[Dialect, ModuleType]) -> Dialect

Map the module to its dialect if it is a module. It assumes that the module has a `dialect` attribute that is an instance of `Dialect`.

Source code in src/kirin/ir/group.py
71
72
73
74
75
76
77
78
79
@staticmethod
def map_module(dialect: Union["Dialect", ModuleType]) -> "Dialect":
    """Resolve a module to the dialect it exposes.

    Non-module arguments are returned unchanged. A module argument is
    assumed to carry a `dialect` attribute that is an instance of
    [`Dialect`][kirin.ir.Dialect].
    """
    if not isinstance(dialect, ModuleType):
        return dialect
    return dialect.dialect

union

union(
    dialect: Iterable[Union[Dialect, ModuleType]]
) -> DialectGroup

union a set of dialects to the group.

Parameters:

Name Type Description Default
dialect Iterable[Union[Dialect, ModuleType]]

the dialects to union

required

Returns:

Name Type Description
DialectGroup DialectGroup

the new dialect group with the union.

Source code in src/kirin/ir/group.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def union(self, dialect: Iterable[Union["Dialect", ModuleType]]) -> "DialectGroup":
    """Return a new group containing this group's dialects plus the given ones.

    Args:
        dialect (Iterable[Union[Dialect, ModuleType]]): the dialects, or
            modules exposing a `dialect` attribute, to merge in.

    Returns:
        DialectGroup: the new dialect group with the union.
    """
    extra = frozenset(self.map_module(d) for d in dialect)
    merged = self.data | extra
    # hand over the pass generator so the new group builds its own run_pass
    return DialectGroup(dialects=merged, run_pass=self.run_pass_gen)

dialect_group

dialect_group(
    dialects: Iterable[Union[Dialect, ModuleType]]
) -> Callable[
    [RunPassGen[PassParams]], DialectGroup[PassParams]
]

Create a dialect group from the given dialects based on the definition of run_pass function.

Parameters:

Name Type Description Default
dialects Iterable[Union[Dialect, ModuleType]]

the dialects to include in the group.

required

Returns:

Type Description
Callable[[RunPassGen[PassParams]], DialectGroup[PassParams]]

Callable[[RunPassGen[PassParams]], DialectGroup[PassParams]]: the dialect group.

Example
from kirin.dialects import cf, fcf, func, math

@dialect_group([cf, fcf, func, math])
def basic_no_opt(self):
    # initializations
    def run_pass(mt: Method) -> None:
        # how passes are applied to the method
        pass

    return run_pass
Source code in src/kirin/ir/group.py
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
def dialect_group(
    dialects: Iterable[Union["Dialect", ModuleType]]
) -> Callable[[RunPassGen[PassParams]], DialectGroup[PassParams]]:
    """Build a decorator that turns a `run_pass` generator into a dialect group.

    The decorated function receives the new group and must return the
    `run_pass` function; the resulting `DialectGroup` takes over the
    decorated function's metadata (name, docstring, ...).

    Args:
        dialects (Iterable[Union[Dialect, ModuleType]]): the dialects to include in the group.

    Returns:
        Callable[[RunPassGen[PassParams]], DialectGroup[PassParams]]: the dialect group.

    Example:
        ```python
        from kirin.dialects import cf, fcf, func, math

        @dialect_group([cf, fcf, func, math])
        def basic_no_opt(self):
            # initializations
            def run_pass(mt: Method) -> None:
                # how passes are applied to the method
                pass

            return run_pass
        ```
    """

    # NOTE: do not alias the annotation below
    def decorate(
        transform: RunPassGen[PassParams],
    ) -> DialectGroup[PassParams]:
        # update_wrapper returns its first argument, i.e. the new group
        return update_wrapper(DialectGroup(dialects, run_pass=transform), transform)

    return decorate