
jarp

PyTree-oriented helpers for JAX, attrs, and Warp.

Modules:

  • lax

    Control-flow wrappers that mirror JAX APIs with small ergonomic additions.

  • tree

    Helpers for defining, flattening, and transforming JAX PyTrees.

  • warp

    Interop helpers between JAX arrays and Warp kernels or callables.

Classes:

  • Partial

    Store a partially applied callable as a PyTree-aware proxy.

  • PyTreeProxy

    Wrap an arbitrary object and flatten the wrapped value as a PyTree.

  • Structure

    Record how to flatten and rebuild a PyTree's dynamic leaves.

Functions:

  • array
  • auto
  • define

    Define an attrs class and optionally register it as a PyTree.

  • field
  • frozen

    Define a frozen attrs class and register it as a data PyTree.

  • frozen_static

    Define a frozen attrs class and register it as a static PyTree.

  • jax_callable

    Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.

  • jax_kernel

    Wrap warp.jax_experimental.jax_kernel with optional overload lookup.

  • jit

    Compile a callable with JAX, optionally preserving static PyTree leaves.

  • partial

    Partially apply a callable and keep the result compatible with JAX trees.

  • ravel

    Flatten a PyTree's dynamic leaves into one vector.

  • register_pytree_prelude
  • static
  • to_warp

    Convert a supported array object into a warp.array.

  • while_loop

    Run a loop with either jax.lax.while_loop or Python control flow.

Attributes:

__version__ module-attribute

__version__: str = '0.1.10.dev3+g4d8037a3c'

__version_tuple__ module-attribute

__version_tuple__: tuple[int | str, ...] = (
    0,
    1,
    10,
    "dev3",
    "g4d8037a3c",
)

Partial

Partial(
    func: Callable[..., T], /, *args: Any, **kwargs: Any
)

Bases: PartialCallableObjectProxy

Store a partially applied callable as a PyTree-aware proxy.

Source code in src/jarp/tree/prelude/_partial.py
def __init__(self, func: Callable[..., T], /, *args: Any, **kwargs: Any) -> None:
    """Create a proxy that records bound arguments for PyTree flattening."""
    super().__init__(func, *args, **kwargs)
    self._self_args = args
    self._self_kwargs = kwargs

__wrapped__ instance-attribute

__wrapped__: Callable[..., T]

__call__

__call__(*args: P.args, **kwargs: P.kwargs) -> T
Source code in src/jarp/tree/prelude/_partial.py
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

PyTreeProxy

Bases: BaseObjectProxy

Wrap an arbitrary object and flatten the wrapped value as a PyTree.

Attributes:

__wrapped__ instance-attribute

__wrapped__: T

Structure

Record how to flatten and rebuild a PyTree's dynamic leaves.

Instances are returned by ravel and capture the original tree definition, the static leaves that were removed from the flat vector, the offsets and shapes needed to reconstruct each dynamic leaf, and the dtype of the flat vector.

Methods:

  • ravel

    Flatten a compatible tree, or ravel an array that is passed directly.

  • unravel

    Rebuild the original tree shape from a flat vector.

Attributes:

dtype instance-attribute

dtype: DTypeLike

is_leaf property

is_leaf: bool

Return whether the original tree was a single leaf.

meta_leaves instance-attribute

meta_leaves: tuple[Any, ...]

offsets instance-attribute

offsets: tuple[int, ...]

shapes instance-attribute

shapes: tuple[Shape | None, ...]

treedef instance-attribute

treedef: PyTreeDef

ravel

ravel(tree: T | Array) -> Vector

Flatten a compatible tree, or ravel an array that is passed directly.

Parameters:

  • tree

    (T | Array) –

    A tree with the same structure used to build this Structure, or an already-flat array.

Returns:

  • Vector

    A one-dimensional array containing the dynamic leaves.

Source code in src/jarp/tree/_ravel.py
def ravel(self, tree: T | Array) -> Vector:
    """Flatten a compatible tree or flatten an array in-place.

    Args:
        tree: A tree with the same structure used to build this
            :class:`Structure`, or an already-flat array.

    Returns:
        A one-dimensional array containing the dynamic leaves.
    """
    if isinstance(tree, Array):
        # do not flatten if already flat
        return jnp.ravel(tree)
    leaves, treedef = jax.tree.flatten(tree)
    assert treedef == self.treedef
    data_leaves, meta_leaves = partition_leaves(leaves)
    assert tuple(meta_leaves) == self.meta_leaves
    return _ravel(data_leaves)

unravel

unravel(
    flat: T | Array, dtype: DTypeLike | None = None
) -> T

Rebuild the original tree shape from a flat vector.

Parameters:

  • flat

    (T | Array) –

    One-dimensional data produced by ravel, or a tree that already matches the recorded structure.

  • dtype

    (DTypeLike | None, default: None ) –

    Optional dtype override applied to the flat array before it is split and reshaped.

Returns:

  • T

    A tree with the same structure and static metadata as the original input to ravel.

Source code in src/jarp/tree/_ravel.py
def unravel(self, flat: T | Array, dtype: DTypeLike | None = None) -> T:
    """Rebuild the original tree shape from a flat vector.

    Args:
        flat: One-dimensional data produced by :meth:`ravel`, or a tree that
            already matches the recorded structure.
        dtype: Optional dtype override applied to the flat array before it
            is split and reshaped.

    Returns:
        A tree with the same structure and static metadata as the original
        input to :func:`ravel`.
    """
    if not isinstance(flat, Array):
        # do not unravel if already a pytree
        assert jax.tree.structure(flat) == self.treedef
        return cast("T", flat)
    flat: Array = jnp.asarray(flat, self.dtype if dtype is None else dtype)
    if self.is_leaf:
        return cast("T", jnp.reshape(flat, self.shapes[0]))
    data_leaves: list[Array | None] = _unravel(flat, self.offsets, self.shapes)
    leaves: list[Any] = combine_leaves(data_leaves, self.meta_leaves)
    return jax.tree.unflatten(self.treedef, leaves)

array

array(
    *,
    default: T = ...,
    validator: _ValidatorArgType[T] | None = ...,
    repr: _ReprArgType = ...,
    hash: bool | None = ...,
    init: bool = ...,
    metadata: Mapping[Any, Any] | None = ...,
    converter: _ConverterType
    | list[_ConverterType]
    | tuple[_ConverterType, ...]
    | None = ...,
    factory: Callable[[], T] | None = ...,
    kw_only: bool | None = ...,
    eq: _EqOrderType | None = ...,
    order: _EqOrderType | None = ...,
    on_setattr: _OnSetAttrArgType | None = ...,
    alias: str | None = ...,
    type: type | None = ...,
    static: FieldType | bool | None = ...,
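) -> Array

Declare an attrs field that holds a JAX array. If default is given without factory, and it is neither None nor an attrs.Factory, it is converted once with jnp.asarray and supplied as a factory instead. All options, including static, are then passed to field.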

Parameters:

  • default

    (T, default: ... ) –
  • validator

    (_ValidatorArgType[T] | None, default: ... ) –
  • repr

    (_ReprArgType, default: ... ) –
  • hash

    (bool | None, default: ... ) –
  • init

    (bool, default: ... ) –
  • metadata

    (Mapping[Any, Any] | None, default: ... ) –
  • converter

    (_ConverterType | list[_ConverterType] | tuple[_ConverterType, ...] | None, default: ... ) –
  • factory

    (Callable[[], T] | None, default: ... ) –
  • kw_only

    (bool | None, default: ... ) –
  • eq

    (_EqOrderType | None, default: ... ) –
  • order

    (_EqOrderType | None, default: ... ) –
  • on_setattr

    (_OnSetAttrArgType | None, default: ... ) –
  • alias

    (str | None, default: ... ) –
  • type

    (type | None, default: ... ) –
  • static

    (FieldType | bool | None, default: ... ) –
Source code in src/jarp/tree/attrs/_field_specifiers.py
def array(**kwargs: Unpack[FieldOptions[ArrayLike | None]]) -> Array:
    if "default" in kwargs and "factory" not in kwargs:
        default: ArrayLike | None = kwargs["default"]
        if not (default is None or isinstance(default, attrs.Factory)):  # ty:ignore[invalid-argument-type]
            default: Array = jnp.asarray(default)
            kwargs.pop("default")
            kwargs["factory"] = lambda: default
    return field(**kwargs)
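The default-handling branch of array can be traced without attrs or JAX. The sketch below is hypothetical: Factory stands in for attrs.Factory and to_array stands in for jnp.asarray (it only builds a tuple). The control flow is the part that mirrors the source above: a default that is neither None nor a Factory is converted once and then returned by a factory.

```python
from typing import Any, Callable


class Factory:
    """Hypothetical stand-in for attrs.Factory."""

    def __init__(self, factory: Callable[[], Any]) -> None:
        self.factory = factory


def to_array(value: Any) -> tuple:
    """Hypothetical stand-in for jnp.asarray: convert once to an immutable value."""
    return tuple(value) if isinstance(value, (list, tuple)) else (value,)


def array_field_kwargs(**kwargs: Any) -> dict[str, Any]:
    """Mirror the default -> factory rewrite performed by jarp.array."""
    if "default" in kwargs and "factory" not in kwargs:
        default = kwargs["default"]
        if not (default is None or isinstance(default, Factory)):
            converted = to_array(default)  # converted exactly once
            kwargs.pop("default")
            kwargs["factory"] = lambda: converted
    return kwargs


opts = array_field_kwargs(default=[1.0, 2.0])
print(sorted(opts))         # ['factory']
print(opts["factory"]())    # (1.0, 2.0)
print(array_field_kwargs(default=None))  # {'default': None}
```

Every instance therefore receives the same converted object. That is safe for JAX arrays because they are immutable.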

auto

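auto(**kwargs) -> Any

Declare an attrs field whose static option defaults to FieldType.AUTO. All options are then passed to field.
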
Source code in src/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def auto(**kwargs) -> Any:
    kwargs.setdefault("static", FieldType.AUTO)
    return field(**kwargs)

define

define[T: type](
    cls: T,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> T
define[T: type](
    cls: None = None,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> Callable[[T], T]

Define an attrs class and optionally register it as a PyTree.

Parameters:

  • maybe_cls

    (T | None, default: None ) –

    Class being decorated. When omitted, return a configured decorator.

  • **kwargs

    (Any, default: {} ) –

    Options forwarded to attrs.define, plus pytree to control JAX registration. pytree="data" registers fields with fieldz semantics, "static" registers the whole instance as a static value, and "none" leaves the class unregistered.

Returns:

  • Any

    The decorated class or a class decorator.

Source code in src/jarp/tree/attrs/_define.py
def define[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define an ``attrs`` class and optionally register it as a PyTree.

    Args:
        maybe_cls: Class being decorated. When omitted, return a configured
            decorator.
        **kwargs: Options forwarded to :func:`attrs.define`, plus ``pytree`` to
            control JAX registration. ``pytree="data"`` registers fields with
            ``fieldz`` semantics, ``"static"`` registers the whole instance as a
            static value, and ``"none"`` leaves the class unregistered.

    Returns:
        The decorated class or a class decorator.
    """
    if maybe_cls is None:
        return functools.partial(define, **kwargs)
    pytree: PyTreeType = PyTreeType(kwargs.pop("pytree", None))
    frozen: bool = kwargs.get("frozen", False)
    if pytree is PyTreeType.STATIC and not frozen:
        warnings.warn(
            "Defining a static class that is not frozen may lead to unexpected behavior.",
            stacklevel=2,
        )
    cls: T = attrs.define(maybe_cls, **kwargs)  # ty:ignore[invalid-assignment]
    match pytree:
        case PyTreeType.DATA:
            register_fieldz(cls)
        case PyTreeType.STATIC:
            jtu.register_static(cls)
    return cls

field

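field(**kwargs) -> Any

Wrap attrs.field. If a static option is given, it is moved into the field's metadata under the "static" key and merged with any metadata passed in. Every other option goes to attrs.field unchanged.
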
Source code in src/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def field(**kwargs) -> Any:
    if "static" in kwargs:
        kwargs["metadata"] = {
            "static": kwargs.pop("static"),
            **(kwargs.get("metadata") or {}),
        }
    return attrs.field(**kwargs)
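The keyword rewriting done by field, and by static through kwargs.setdefault("static", True), can be checked without attrs. This sketch reproduces only that dictionary manipulation. field_kwargs and static_kwargs are hypothetical names that return the options that would reach attrs.field. auto works the same way as static, but its default is FieldType.AUTO.

```python
from typing import Any


def field_kwargs(**kwargs: Any) -> dict[str, Any]:
    """Mirror how jarp.field moves `static` into the attrs metadata mapping."""
    if "static" in kwargs:
        kwargs["metadata"] = {
            "static": kwargs.pop("static"),
            **(kwargs.get("metadata") or {}),  # user metadata is unpacked last
        }
    return kwargs


def static_kwargs(**kwargs: Any) -> dict[str, Any]:
    """Mirror jarp.static: default `static` to True, then defer to field."""
    kwargs.setdefault("static", True)
    return field_kwargs(**kwargs)


print(field_kwargs(static=True, metadata={"units": "m"}))
# {'metadata': {'static': True, 'units': 'm'}}
print(static_kwargs(default=0))
# {'default': 0, 'metadata': {'static': True}}
```

The existing metadata is unpacked after the "static" entry, so a "static" key already present in metadata takes precedence over the static argument.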

frozen

frozen[T: type](
    cls: T,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> T
frozen[T: type](
    cls: None = None,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> Callable[[T], T]

Define a frozen attrs class and register it as a data PyTree.

Source code in src/jarp/tree/attrs/_define.py
def frozen[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define a frozen ``attrs`` class and register it as a data PyTree."""
    _warnings_hide = True
    if maybe_cls is None:
        return functools.partial(frozen, **kwargs)
    kwargs.setdefault("frozen", True)
    return define(maybe_cls, **kwargs)

frozen_static

frozen_static[T: type](
    cls: T,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> T
frozen_static[T: type](
    cls: None = None,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> Callable[[T], T]

Define a frozen attrs class and register it as a static PyTree.

Source code in src/jarp/tree/attrs/_define.py
def frozen_static[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define a frozen ``attrs`` class and register it as a static PyTree."""
    _warnings_hide = True
    if maybe_cls is None:
        return functools.partial(frozen_static, **kwargs)
    kwargs.setdefault("frozen", True)
    kwargs.setdefault("pytree", PyTreeType.STATIC)
    return define(maybe_cls, **kwargs)
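frozen and frozen_static only fill in defaults with dict.setdefault before delegating to define, so arguments passed explicitly always win. This stand-alone sketch shows that option resolution. For brevity it uses plain strings instead of PyTreeType members, and the function names are hypothetical.

```python
from typing import Any


def frozen_options(**kwargs: Any) -> dict[str, Any]:
    """Option resolution of jarp.frozen before it calls define."""
    kwargs.setdefault("frozen", True)
    return kwargs


def frozen_static_options(**kwargs: Any) -> dict[str, Any]:
    """Option resolution of jarp.frozen_static before it calls define."""
    kwargs.setdefault("frozen", True)
    kwargs.setdefault("pytree", "static")  # jarp passes PyTreeType.STATIC here
    return kwargs


print(frozen_static_options())              # {'frozen': True, 'pytree': 'static'}
print(frozen_static_options(frozen=False))  # {'frozen': False, 'pytree': 'static'}
```

So frozen_static(frozen=False) reaches define as a static PyTree that is not frozen, which is exactly the combination define warns about.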

jax_callable

jax_callable(
    func: _FfiCallableFunction,
    *,
    generic: Literal[False] = False,
    **kwargs: Unpack[JaxCallableOptions],
) -> FfiCallableProtocol
jax_callable(
    *,
    generic: Literal[False] = False,
    **kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFunction], FfiCallableProtocol]
jax_callable(
    func: _FfiCallableFactory,
    *,
    generic: Literal[True],
    **kwargs: Unpack[JaxCallableOptions],
) -> _FfiCallable
jax_callable(
    *,
    generic: Literal[True],
    **kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFactory], _FfiCallable]

Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.

Parameters:

  • func

    (Callable | None, default: None ) –

    Warp callable function or factory. When omitted, return a decorator.

  • generic

    (bool, default: False ) –

    When true, func is treated as a factory that receives Warp scalar dtypes inferred from the runtime JAX arguments and returns a concrete Warp callable implementation.

  • num_outputs

    (int, default: ... ) –
  • graph_mode

    (GraphMode, default: ... ) –
  • vmap_method

    (str | None, default: ... ) –
  • output_dims

    (dict[str, ShapeLike] | None, default: ... ) –
  • in_out_argnames

    (Iterable[str], default: ... ) –
  • stage_in_argnames

    (Iterable[str], default: ... ) –
  • stage_out_argnames

    (Iterable[str], default: ... ) –
  • graph_cache_max

    (int | None, default: ... ) –
  • module_preload_mode

    (ModulePreloadMode, default: ... ) –

Returns:

  • Any

    A callable compatible with JAX tracing, or a decorator producing one.

Source code in src/jarp/warp/_jax_callable.py
def jax_callable(
    func: Callable | None = None,
    *,
    generic: bool = False,
    **kwargs: Unpack[JaxCallableOptions],
) -> Any:
    """Wrap ``warp.jax_experimental.jax_callable`` with optional dtype dispatch.

    Args:
        func: Warp callable function or factory. When omitted, return a
            decorator.
        generic: When true, ``func`` is treated as a factory that receives Warp
            scalar dtypes inferred from the runtime JAX arguments and returns a
            concrete Warp callable implementation.
        **kwargs: Options forwarded to Warp's JAX callable adapter.

    Returns:
        A callable compatible with JAX tracing, or a decorator producing one.
    """
    if func is None:
        return functools.partial(jax_callable, generic=generic, **kwargs)
    if not generic:
        return warp.jax_experimental.jax_callable(func, **kwargs)
    factory: _FfiCallableFactory = functools.lru_cache(func)
    return _FfiCallable(factory=factory, options=kwargs)  # ty:ignore[invalid-argument-type]
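With generic=True, jax_callable wraps the factory in functools.lru_cache, so each distinct dtype argument builds a Warp implementation only once. The sketch below shows just that caching pattern. make_scale, infer_dtype, and the string dtypes are hypothetical stand-ins. How _FfiCallable actually infers dtypes from JAX arguments is not shown in this reference.

```python
import functools
from typing import Callable

build_count = 0  # how many concrete implementations were built


@functools.lru_cache
def make_scale(dtype: str) -> Callable[[list], list]:
    """Hypothetical generic factory: returns one implementation per dtype."""
    global build_count
    build_count += 1
    factor = 2.0 if dtype == "float32" else 2  # pretend dtype-specific code
    return lambda xs: [x * factor for x in xs]


def infer_dtype(xs: list) -> str:
    """Hypothetical stand-in for reading the scalar dtype of a JAX argument."""
    return "float32" if any(isinstance(x, float) for x in xs) else "int32"


def scale(xs: list) -> list:
    # Dispatch: infer the dtype, then fetch (or build once) the implementation.
    return make_scale(infer_dtype(xs))(xs)


print(scale([1.0, 2.0]), scale([3.0]), scale([1, 2]))  # [2.0, 4.0] [6.0] [2, 4]
print(build_count)  # 2
```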

jax_kernel

jax_kernel(
    *,
    arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
    | None = None,
    **kwargs: Unpack[JaxKernelOptions],
) -> Callable[[Callable], FfiKernelProtocol]
jax_kernel(
    kernel: Callable,
    *,
    arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
    | None = None,
    **kwargs: Unpack[JaxKernelOptions],
) -> FfiKernelProtocol

Wrap warp.jax_experimental.jax_kernel with optional overload lookup.

Parameters:

  • kernel

    (Callable | None, default: None ) –

    Warp kernel to expose to JAX. When omitted, return a decorator.

  • arg_types_factory

    (Callable[[WarpScalarDType], ArgTypes] | None, default: None ) –

    Optional callback that maps runtime Warp scalar dtypes to the overloaded kernel argument types expected by warp.overload.

  • num_outputs

    (int, default: ... ) –
  • vmap_method

    (VmapMethod, default: ... ) –
  • launch_dims

    (ShapeLike | None, default: ... ) –
  • output_dims

    (ShapeLike | dict[str, ShapeLike] | None, default: ... ) –
  • in_out_argnames

    (Iterable[str], default: ... ) –
  • module_preload_mode

    (ModulePreloadMode, default: ... ) –
  • enable_backward

    (bool, default: ... ) –

Returns:

  • Any

    A callable compatible with JAX tracing, or a decorator producing one.

Source code in src/jarp/warp/_jax_kernel.py
def jax_kernel(
    kernel: Callable | None = None,
    *,
    arg_types_factory: Callable[[WarpScalarDType], ArgTypes] | None = None,
    **kwargs: Unpack[JaxKernelOptions],
) -> Any:
    """Wrap ``warp.jax_experimental.jax_kernel`` with optional overload lookup.

    Args:
        kernel: Warp kernel to expose to JAX. When omitted, return a decorator.
        arg_types_factory: Optional callback that maps runtime Warp scalar dtypes
            to the overloaded kernel argument types expected by
            :func:`warp.overload`.
        **kwargs: Options forwarded to Warp's JAX kernel adapter.

    Returns:
        A callable compatible with JAX tracing, or a decorator producing one.
    """
    if kernel is None:
        return functools.partial(
            jax_kernel, arg_types_factory=arg_types_factory, **kwargs
        )
    if arg_types_factory is None:
        return warp.jax_experimental.jax_kernel(kernel, **kwargs)
    return _FfiKernel(
        kernel=cast("wp.Kernel", kernel),
        options=kwargs,  # ty:ignore[invalid-argument-type]
        arg_types_factory=arg_types_factory,
    )
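arg_types_factory maps a runtime Warp scalar dtype to the argument types of one kernel overload. This is a minimal sketch of a per-dtype lookup table. It assumes, without confirmation from this reference, that _FfiKernel resolves the argument types once per dtype and then reuses them. saxpy_arg_types, OverloadTable, and the string type names are hypothetical placeholders for Warp types.

```python
from typing import Callable

# Hypothetical: argument name -> type name; real code would use Warp types.
ArgTypes = dict[str, str]


def saxpy_arg_types(scalar: str) -> ArgTypes:
    """Hypothetical arg_types_factory for a kernel computing y = a * x + y."""
    return {"a": scalar, "x": f"array({scalar})", "y": f"array({scalar})"}


class OverloadTable:
    """Resolve argument types once per runtime scalar dtype and reuse them."""

    def __init__(self, arg_types_factory: Callable[[str], ArgTypes]) -> None:
        self.arg_types_factory = arg_types_factory
        self.resolved: dict[str, ArgTypes] = {}

    def lookup(self, scalar: str) -> ArgTypes:
        if scalar not in self.resolved:
            # With Warp, this is where warp.overload(kernel, arg_types) would
            # create the concrete overload (an assumption about jarp internals).
            self.resolved[scalar] = self.arg_types_factory(scalar)
        return self.resolved[scalar]


table = OverloadTable(saxpy_arg_types)
print(table.lookup("float32")["x"])  # array(float32)
table.lookup("float64")
table.lookup("float32")
print(sorted(table.resolved))  # ['float32', 'float64']
```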

jit

jit[**P, T](
    fun: Callable[P, T],
    /,
    *,
    filter: Literal[False] = False,
    **kwargs: Unpack[JitOptions],
) -> Callable[P, T]
jit[**P, T](
    *,
    filter: Literal[False] = False,
    **kwargs: Unpack[JitOptions],
) -> Callable[[Callable[P, T]], Callable[P, T]]
jit[**P, T](
    fun: Callable[P, T],
    /,
    *,
    filter: Literal[True],
    **kwargs: Unpack[FilterJitOptions],
) -> Callable[P, T]
jit[**P, T](
    *,
    filter: Literal[True],
    **kwargs: Unpack[FilterJitOptions],
) -> Callable[[Callable[P, T]], Callable[P, T]]

Compile a callable with JAX, optionally preserving static PyTree leaves.

When filter=False (the default) this is a thin wrapper around jax.jit. When filter=True the function and its inputs are partitioned into dynamic array leaves and static metadata, so mixed PyTrees can cross the JIT boundary without manual static_argnums wiring.

Parameters:

  • fun

    (Callable[P, T] | None, default: None ) –

    Callable to compile. When omitted, return a decorator.

  • **kwargs

    (Any, default: {} ) –

    Options forwarded to jax.jit. With filter=True, only the subset in FilterJitOptions is supported because static argument handling is managed internally.

Returns:

  • Callable

    A compiled callable or decorator with the same public call signature as fun.

Source code in src/jarp/_jit.py
def jit[**P, T](fun: Callable[P, T] | None = None, **kwargs: Any) -> Callable:
    """Compile a callable with JAX, optionally preserving static PyTree leaves.

    When ``filter=False`` this is a thin wrapper around :func:`jax.jit`.
    When ``filter=True`` the function and its inputs are partitioned into
    dynamic array leaves and static metadata so mixed PyTrees can cross the JIT
    boundary without requiring manual ``static_argnums`` wiring.

    Args:
        fun: Callable to compile. When omitted, return a decorator.
        **kwargs: Options forwarded to :func:`jax.jit`. With ``filter=True``,
            only the subset in :class:`FilterJitOptions` is supported because
            static argument handling is managed internally.

    Returns:
        A compiled callable or decorator with the same public call signature as
        ``fun``.
    """
    if fun is None:
        return functools.partial(jit, **kwargs)
    filter_: bool = kwargs.pop("filter", False)
    if not filter_:
        return jax.jit(fun, **kwargs)

    fun_data: _Data
    fun_meta: _Meta[Callable[P, T]]
    fun_data, fun_meta = tree.partition(fun)
    inner: _Inner[T] = _Inner(fun_meta)
    inner_jit: _InnerProtocol[T] = jax.jit(inner, static_argnums=(2,), **kwargs)
    outer: _Outer[P, T] = _Outer(fun_data=fun_data, inner=inner_jit)
    functools.update_wrapper(outer, fun)
    return outer

partial

partial[T](
    func: Callable[..., T], /, *args: Any, **kwargs: Any
) -> Partial[..., T]

Partially apply a callable and keep the result compatible with JAX trees.

Source code in src/jarp/tree/prelude/_partial.py
def partial[T](func: Callable[..., T], /, *args: Any, **kwargs: Any) -> Partial[..., T]:
    """Partially apply a callable and keep the result compatible with JAX trees."""
    return Partial(func, *args, **kwargs)
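Partial records its bound positional and keyword arguments (_self_args and _self_kwargs), so that they can be exposed when the object is flattened. The sketch below illustrates that idea with hypothetical flatten and unflatten methods. The split it uses is an assumption, not jarp's confirmed registration: bound values become children, while the function and keyword names become auxiliary data.

```python
import functools
from typing import Any, Callable


class RecordingPartial:
    """Hypothetical partial that keeps its bound arguments visible for flattening."""

    def __init__(self, func: Callable[..., Any], /, *args: Any, **kwargs: Any) -> None:
        self.func, self.args, self.kwargs = func, args, kwargs
        self._call = functools.partial(func, *args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self._call(*args, **kwargs)

    def flatten(self) -> tuple[list[Any], tuple[Any, ...]]:
        # Children: the bound values. Aux data: the function and keyword names.
        keys = tuple(sorted(self.kwargs))
        children = [*self.args, *(self.kwargs[k] for k in keys)]
        return children, (self.func, len(self.args), keys)

    @classmethod
    def unflatten(cls, aux: tuple[Any, ...], children: list[Any]) -> "RecordingPartial":
        func, n_args, keys = aux
        return cls(func, *children[:n_args], **dict(zip(keys, children[n_args:])))


def affine(a: float, x: float, *, b: float) -> float:
    return a * x + b


p = RecordingPartial(affine, 3.0, b=1.0)
children, aux = p.flatten()
q = RecordingPartial.unflatten(aux, [c * 10 for c in children])  # map over children
print(p(2.0), children, q(2.0))  # 7.0 [3.0, 1.0] 70.0
```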

ravel

ravel[T](tree: T) -> tuple[Array, Structure[T]]

Flatten a PyTree's dynamic leaves into one vector.

Non-array leaves are treated as static metadata and preserved in the returned Structure instead of being concatenated into the flat array.

Parameters:

  • tree

    (T) –

    PyTree to flatten.

Returns:

  • tuple[Array, Structure[T]]

    A tuple (flat, structure) where flat is a one-dimensional JAX array and structure can rebuild compatible trees later.

Source code in src/jarp/tree/_ravel.py
def ravel[T](tree: T) -> tuple[Array, Structure[T]]:
    """Flatten a PyTree's dynamic leaves into one vector.

    Non-array leaves are treated as static metadata and preserved in the
    returned :class:`Structure` instead of being concatenated into the flat
    array.

    Args:
        tree: PyTree to flatten.

    Returns:
        A tuple of ``(flat, structure)`` where ``flat`` is a one-dimensional
        JAX array and ``structure`` can rebuild compatible trees later.
    """
    leaves, treedef = jax.tree.flatten(tree)
    dynamic_leaves, static_leaves = partition_leaves(leaves)
    flat: Array = _ravel(dynamic_leaves)
    structure: Structure[T] = Structure(
        offsets=_offsets_from_leaves(dynamic_leaves),
        shapes=_shapes_from_leaves(dynamic_leaves),
        meta_leaves=tuple(static_leaves),
        treedef=treedef,
        dtype=flat.dtype,
    )
    return flat, structure
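The bookkeeping behind ravel and Structure.unravel can be followed in plain Python. This sketch is not jarp's implementation. Floats play the role of 0-d arrays, lists of floats play 1-d arrays, and every other leaf is static. Two conventions are assumptions inferred from the Shape | None and Array | None annotations above: static positions are marked with None, and offsets are stored as end positions. jarp's private _offsets_from_leaves and _unravel are not shown here. In the real code, rebuilding the containers is left to jax.tree.unflatten, so the sketch stops at the leaf list.

```python
from typing import Any


def is_dynamic(leaf: Any) -> bool:
    # Assumption: floats act as 0-d arrays and lists of floats as 1-d arrays.
    return isinstance(leaf, (float, list))


def ravel_leaves(leaves: list[Any]):
    """Concatenate dynamic leaves; set static leaves aside, position by position."""
    flat: list[float] = []
    offsets: list[int] = []  # assumption: end offset of each position in `flat`
    shapes: list = []  # None marks a static position
    meta: list[Any] = []  # None marks a dynamic position
    for leaf in leaves:
        if is_dynamic(leaf):
            values = leaf if isinstance(leaf, list) else [leaf]
            flat.extend(values)
            shapes.append((len(values),) if isinstance(leaf, list) else ())
            meta.append(None)
        else:
            shapes.append(None)
            meta.append(leaf)
        offsets.append(len(flat))
    return flat, tuple(offsets), tuple(shapes), tuple(meta)


def unravel_leaves(flat, offsets, shapes, meta) -> list[Any]:
    """Slice `flat` back into leaves and reinsert the static ones."""
    leaves: list[Any] = []
    start = 0
    for end, shape, static in zip(offsets, shapes, meta):
        if shape is None:
            leaves.append(static)
        else:
            chunk = flat[start:end]
            leaves.append(chunk[0] if shape == () else list(chunk))
        start = end
    return leaves


flat, offsets, shapes, meta = ravel_leaves([[1.0, 2.0], "relu", 3.0])
print(flat, offsets, shapes)  # [1.0, 2.0, 3.0] (2, 2, 3) ((2,), None, ())
print(unravel_leaves([2 * x for x in flat], offsets, shapes, meta))
# [[2.0, 4.0], 'relu', 6.0]
```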

register_pytree_prelude cached

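register_pytree_prelude() -> None

Register jarp's built-in PyTree handlers by calling register_pytree_method() and register_warp_array(). The function is wrapped in functools.cache, so only the first call has any effect.
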
Source code in src/jarp/tree/prelude/_prelude.py
@functools.cache  # run only once
def register_pytree_prelude() -> None:
    register_pytree_method()
    register_warp_array()
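register_pytree_prelude relies on functools.cache to make registration idempotent. This matters because JAX raises an error when the same type is registered as a PyTree node twice. The same idiom in isolation, with a hypothetical list standing in for the registrations:

```python
import functools

REGISTRATIONS: list[str] = []  # hypothetical record of registration calls


@functools.cache  # run only once, as in register_pytree_prelude
def register_prelude_sketch() -> None:
    REGISTRATIONS.append("methods")
    REGISTRATIONS.append("warp_array")


register_prelude_sketch()
register_prelude_sketch()  # cached: the body does not run again
print(REGISTRATIONS)  # ['methods', 'warp_array']
```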

static

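static(**kwargs) -> Any

Declare an attrs field whose static option defaults to True, marking it as static metadata for consistency with jax.tree_util.register_dataclass. All options are then passed to field.
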
Source code in src/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def static(**kwargs) -> Any:
    # for consistency with `jax.tree_util.register_dataclass`
    kwargs.setdefault("static", True)
    return field(**kwargs)

to_warp

to_warp(
    arr: array | ndarray | Array, *_args, **_kwargs
) -> array

Convert a supported array object into a warp.array.

The generic dispatcher currently supports NumPy arrays and JAX arrays. A dtype hint may be a concrete Warp dtype or a tuple that describes a vector or matrix dtype inferred from the trailing dimensions of arr.

Source code in src/jarp/warp/_to_warp.py
@functools.singledispatch
def to_warp(arr: Any, *_args, **_kwargs) -> wp.array:
    """Convert a supported array object into a :class:`warp.array`.

    The generic dispatcher currently supports NumPy arrays and JAX arrays. A
    ``dtype`` hint may be a concrete Warp dtype or a tuple that describes a
    vector or matrix dtype inferred from the trailing dimensions of ``arr``.
    """
    raise TypeError(arr)
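to_warp is a functools.singledispatch function. The generic version shown above only raises TypeError, and each supported input type gets its own registered handler. The sketch reproduces that structure with hypothetical handlers for list and tuple. jarp instead registers handlers for NumPy and JAX arrays, which are not shown here.

```python
import functools
from typing import Any


@functools.singledispatch
def to_target(arr: Any, *_args: Any, **_kwargs: Any) -> tuple:
    """Generic fallback: unsupported input types raise TypeError."""
    raise TypeError(arr)


@to_target.register
def _(arr: list, *_args: Any, **_kwargs: Any) -> tuple:
    # Hypothetical handler; jarp registers NumPy and JAX array handlers instead.
    return tuple(arr)


@to_target.register
def _(arr: tuple, *_args: Any, **_kwargs: Any) -> tuple:
    return arr


print(to_target([1, 2]))  # (1, 2)
```

Because dispatch uses the type of the first argument, further input types can be supported by registering more handlers on the same function.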

while_loop

while_loop[T](
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T,
    *,
    jit: bool = True,
) -> T

Run a loop with either jax.lax.while_loop or Python control flow.

Parameters:

  • cond_fun

    (Callable[[T], BooleanNumeric]) –

    Predicate evaluated on the loop state.

  • body_fun

    (Callable[[T], T]) –

    Function that produces the next loop state.

  • init_val

    (T) –

    Initial loop state.

  • jit

    (bool, default: True ) –

    When true, dispatch to jax.lax.while_loop. When false, run an eager Python while loop with the same callbacks.

Returns:

  • T

    The final loop state.

Source code in src/jarp/lax/_while_loop.py
def while_loop[T](
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T,
    *,
    jit: bool = True,
) -> T:
    """Run a loop with either ``jax.lax.while_loop`` or Python control flow.

    Args:
        cond_fun: Predicate evaluated on the loop state.
        body_fun: Function that produces the next loop state.
        init_val: Initial loop state.
        jit: When true, dispatch to :func:`jax.lax.while_loop`. When false, run
            an eager Python ``while`` loop with the same callbacks.

    Returns:
        The final loop state.
    """
    if jit:
        return jax.lax.while_loop(cond_fun, body_fun, init_val)
    val: T = init_val
    while cond_fun(val):
        val: T = body_fun(val)
    return val
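With jit=False, while_loop runs the plain Python branch shown above. This makes it easy to debug a cond_fun/body_fun pair, for example with print statements, before switching to jax.lax.while_loop. The sketch isolates that branch and runs it on a hypothetical Collatz loop whose state is a (steps, value) tuple.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def eager_while_loop(cond_fun: Callable[[T], bool], body_fun: Callable[[T], T], init_val: T) -> T:
    """The jit=False branch of jarp.while_loop, isolated."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def not_done(state: tuple[int, int]) -> bool:
    return state[1] != 1


def collatz_step(state: tuple[int, int]) -> tuple[int, int]:
    steps, n = state
    return steps + 1, (n // 2 if n % 2 == 0 else 3 * n + 1)


print(eager_while_loop(not_done, collatz_step, (0, 6)))  # (8, 1)
```

With jit=True, the same callbacks must also meet the requirements of jax.lax.while_loop. body_fun must return a state with the same tree structure, shapes, and dtypes as init_val, and cond_fun must return a boolean scalar.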

lax

Control-flow wrappers that mirror JAX APIs with small ergonomic additions.

Functions:

  • while_loop

    Run a loop with either jax.lax.while_loop or Python control flow.

while_loop

The same function as jarp.while_loop (source in src/jarp/lax/_while_loop.py). Its signature, parameters, and source are documented in the top-level while_loop entry above.

tree

Helpers for defining, flattening, and transforming JAX PyTrees.

Modules:

  • attrs

    attrs-based decorators and field helpers for JAX PyTree classes.

  • codegen

    Code-generation helpers for custom PyTree registrations.

  • prelude

    PyTree-aware wrappers for callables and object proxies.

Classes:

  • AuxData

    Carry the static part of a partitioned PyTree.

  • FieldType
  • Partial

    Store a partially applied callable as a PyTree-aware proxy.

  • PyTreeProxy

    Wrap an arbitrary object and flatten the wrapped value as a PyTree.

  • PyTreeType

    Choose how a class should participate in JAX PyTree flattening.

  • Structure

    Record how to flatten and rebuild a PyTree's dynamic leaves.

Functions:

  • array
  • auto
  • codegen_pytree_functions
  • combine

    Rebuild a PyTree from dynamic leaves and recorded auxiliary data.

  • combine_leaves

    Zip dynamic leaves back together with their static counterparts.

  • define

    Define an attrs class and optionally register it as a PyTree.

  • field
  • frozen

    Define a frozen attrs class and register it as a data PyTree.

  • frozen_static

    Define a frozen attrs class and register it as a static PyTree.

  • is_data

    Return whether an object should stay on the dynamic side of a partition.

  • is_leaf

    Return whether a leaf contributes data to a flattened vector.

  • partial

    Partially apply a callable and keep the result compatible with JAX trees.

  • partition

    Split a PyTree into dynamic leaves and static metadata.

  • partition_leaves

    Separate raw tree leaves into data leaves and metadata leaves.

  • ravel

    Flatten a PyTree's dynamic leaves into one vector.

  • register_fieldz
  • register_generic
  • register_pytree_prelude
  • static

AuxData

Carry the static part of a partitioned PyTree.

Attributes:

meta_leaves instance-attribute

meta_leaves: tuple[Any, ...]

treedef instance-attribute

treedef: Any
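A partition like this splits raw leaves into a data side and a metadata side, and the combine step zips them back together. The sketch below mimics partition_leaves and combine_leaves under hypothetical names. It assumes, as the list[Array | None] annotations in Structure suggest, that each side keeps the full length and uses None as a placeholder. Its is_data treats floats as arrays, which stands in for jarp's real is_data. None is a safe placeholder because jax.tree.flatten treats None as an empty subtree rather than a leaf, so None never appears among the leaves.

```python
from typing import Any


def is_data(x: Any) -> bool:
    # Stand-in for jarp's is_data: treat floats as arrays (assumption).
    return isinstance(x, float)


def split_leaves(leaves: list[Any]) -> tuple[list[Any], list[Any]]:
    """Split leaves into two aligned lists, with None marking the other side."""
    data = [leaf if is_data(leaf) else None for leaf in leaves]
    meta = [None if is_data(leaf) else leaf for leaf in leaves]
    return data, meta


def merge_leaves(data: list[Any], meta: list[Any]) -> list[Any]:
    """Zip both sides back together, taking the data entry where present."""
    return [m if d is None else d for d, m in zip(data, meta)]


data, meta = split_leaves([1.0, "tanh", 2.5, 3])
print(data)  # [1.0, None, 2.5, None]
print(meta)  # [None, 'tanh', None, 3]
scaled = [None if d is None else 2 * d for d in data]
print(merge_leaves(scaled, meta))  # [2.0, 'tanh', 5.0, 3]
```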

FieldType

Bases: StrEnum

Values for the static option of the field helpers (field, auto, static, array). In a boolean context, META is true while AUTO and DATA are false, for consistency with jax.tree_util.register_dataclass.

Attributes:

AUTO class-attribute instance-attribute

AUTO = auto()

DATA class-attribute instance-attribute

DATA = auto()

META class-attribute instance-attribute

META = auto()

__bool__

__bool__() -> bool
Source code in src/jarp/tree/attrs/_field_specifiers.py
def __bool__(self) -> bool:
    match self:
        case FieldType.META:
            return True
        case FieldType.AUTO | FieldType.DATA:
            # for consistency with `jax.tree_util.register_dataclass`
            return False

Partial

The same class as jarp.Partial (source in src/jarp/tree/prelude/_partial.py); see the top-level Partial entry above.

PyTreeProxy

Documented above as jarp.PyTreeProxy; the jarp.tree entry is identical.

PyTreeType

Bases: StrEnum

Choose how a class should participate in JAX PyTree flattening.

Attributes:

DATA class-attribute instance-attribute

DATA = auto()

NONE class-attribute instance-attribute

NONE = auto()

STATIC class-attribute instance-attribute

STATIC = auto()

Structure

Record how to flatten and rebuild a PyTree's dynamic leaves.

Instances are returned by ravel and capture the original tree definition, the static leaves that were removed from the flat vector, the offsets and shapes needed to reconstruct each dynamic leaf, and the dtype of the flat vector.

Methods:

  • ravel

    Flatten a compatible tree, or ravel an array that is passed directly.

  • unravel

    Rebuild the original tree shape from a flat vector.

Attributes:

dtype instance-attribute

dtype: DTypeLike

is_leaf property

is_leaf: bool

Return whether the original tree was a single leaf.

meta_leaves instance-attribute

meta_leaves: tuple[Any, ...]

offsets instance-attribute

offsets: tuple[int, ...]

shapes instance-attribute

shapes: tuple[Shape | None, ...]

treedef instance-attribute

treedef: PyTreeDef

ravel

ravel(tree: T | Array) -> Vector

Flatten a compatible tree; an array input is assumed to be already flat and is only passed through jnp.ravel.

Parameters:

  • tree
    (T | Array) –

    A tree with the same structure as the one used to build this Structure, or an already-flat array.

Returns:

  • Vector

    A one-dimensional array containing the dynamic leaves.

Source code in src/jarp/tree/_ravel.py
def ravel(self, tree: T | Array) -> Vector:
    """Flatten a compatible tree or flatten an array in-place.

    Args:
        tree: A tree with the same structure used to build this
            :class:`Structure`, or an already-flat array.

    Returns:
        A one-dimensional array containing the dynamic leaves.
    """
    if isinstance(tree, Array):
        # do not flatten if already flat
        return jnp.ravel(tree)
    leaves, treedef = jax.tree.flatten(tree)
    assert treedef == self.treedef
    data_leaves, meta_leaves = partition_leaves(leaves)
    assert tuple(meta_leaves) == self.meta_leaves
    return _ravel(data_leaves)

unravel

unravel(
    flat: T | Array, dtype: DTypeLike | None = None
) -> T

Rebuild the original tree shape from a flat vector.

Parameters:

  • flat
    (T | Array) –

    One-dimensional data produced by ravel, or a tree that already matches the recorded structure (such a tree is returned unchanged).

  • dtype
    (DTypeLike | None, default: None ) –

    Optional dtype override applied to the flat array before it is split and reshaped.

Returns:

  • T

    A tree with the same structure and static metadata as the original input to ravel.

Source code in src/jarp/tree/_ravel.py
def unravel(self, flat: T | Array, dtype: DTypeLike | None = None) -> T:
    """Rebuild the original tree shape from a flat vector.

    Args:
        flat: One-dimensional data produced by :meth:`ravel`, or a tree that
            already matches the recorded structure.
        dtype: Optional dtype override applied to the flat array before it
            is split and reshaped.

    Returns:
        A tree with the same structure and static metadata as the original
        input to :func:`ravel`.
    """
    if not isinstance(flat, Array):
        # do not unravel if already a pytree
        assert jax.tree.structure(flat) == self.treedef
        return cast("T", flat)
    flat: Array = jnp.asarray(flat, self.dtype if dtype is None else dtype)
    if self.is_leaf:
        return cast("T", jnp.reshape(flat, self.shapes[0]))
    data_leaves: list[Array | None] = _unravel(flat, self.offsets, self.shapes)
    leaves: list[Any] = combine_leaves(data_leaves, self.meta_leaves)
    return jax.tree.unflatten(self.treedef, leaves)

array

array(
    *,
    default: T = ...,
    validator: _ValidatorArgType[T] | None = ...,
    repr: _ReprArgType = ...,
    hash: bool | None = ...,
    init: bool = ...,
    metadata: Mapping[Any, Any] | None = ...,
    converter: _ConverterType
    | list[_ConverterType]
    | tuple[_ConverterType, ...]
    | None = ...,
    factory: Callable[[], T] | None = ...,
    kw_only: bool | None = ...,
    eq: _EqOrderType | None = ...,
    order: _EqOrderType | None = ...,
    on_setattr: _OnSetAttrArgType | None = ...,
    alias: str | None = ...,
    type: type | None = ...,
    static: FieldType | bool | None = ...,
) -> Array

Parameters:

  • default

    (T, default: ... ) –
  • validator

    (_ValidatorArgType[T] | None, default: ... ) –
  • repr

    (_ReprArgType, default: ... ) –
  • hash

    (bool | None, default: ... ) –
  • init

    (bool, default: ... ) –
  • metadata

    (Mapping[Any, Any] | None, default: ... ) –
  • converter

    (_ConverterType | list[_ConverterType] | tuple[_ConverterType, ...] | None, default: ... ) –
  • factory

    (Callable[[], T] | None, default: ... ) –
  • kw_only

    (bool | None, default: ... ) –
  • eq

    (_EqOrderType | None, default: ... ) –
  • order

    (_EqOrderType | None, default: ... ) –
  • on_setattr

    (_OnSetAttrArgType | None, default: ... ) –
  • alias

    (str | None, default: ... ) –
  • type

    (type | None, default: ... ) –
  • static

    (FieldType | bool | None, default: ... ) –
Source code in src/jarp/tree/attrs/_field_specifiers.py
def array(**kwargs: Unpack[FieldOptions[ArrayLike | None]]) -> Array:
    if "default" in kwargs and "factory" not in kwargs:
        default: ArrayLike | None = kwargs["default"]
        if not (default is None or isinstance(default, attrs.Factory)):  # ty:ignore[invalid-argument-type]
            default: Array = jnp.asarray(default)
            kwargs.pop("default")
            kwargs["factory"] = lambda: default
    return field(**kwargs)
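The default-to-factory rewrite in array can be sketched without attrs or JAX. In this sketch, to_array is a hypothetical stand-in for jnp.asarray and Factory a hypothetical stand-in for attrs.Factory; the point is only that a plain default is converted once and then returned by a zero-argument factory, while None, factory defaults, and an explicit factory= argument pass through unchanged.

```python
from typing import Any, Callable


class Factory:
    """Hypothetical stand-in for attrs.Factory."""

    def __init__(self, func: Callable[[], Any]) -> None:
        self.factory = func


def to_array(value: Any) -> tuple:
    """Hypothetical stand-in for jnp.asarray: freeze the value into a tuple."""
    return tuple(value) if isinstance(value, (list, tuple)) else (value,)


def rewrite_array_kwargs(**kwargs: Any) -> dict[str, Any]:
    # Same branch structure as array(): only a plain default is rewritten.
    if "default" in kwargs and "factory" not in kwargs:
        default = kwargs["default"]
        if not (default is None or isinstance(default, Factory)):
            converted = to_array(default)  # converted once, at definition time
            kwargs.pop("default")
            kwargs["factory"] = lambda: converted
    return kwargs


opts = rewrite_array_kwargs(default=[1.0, 2.0])
print(sorted(opts))       # ['factory']
print(opts["factory"]())  # (1.0, 2.0)
print(rewrite_array_kwargs(default=None))  # {'default': None}
```

Because the factory returns the same converted object every time, all instances share it; that is harmless for immutable JAX arrays.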

auto

auto(**kwargs) -> Any
Source code in src/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def auto(**kwargs) -> Any:
    kwargs.setdefault("static", FieldType.AUTO)
    return field(**kwargs)

codegen_pytree_functions

codegen_pytree_functions(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> PyTreeFunctions
Source code in src/jarp/tree/codegen/_compile.py
def codegen_pytree_functions(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> PyTreeFunctions:
    if bypass_setattr is None:
        bypass_setattr = cls.__setattr__ is not object.__setattr__
    flatten_def: ast.FunctionDef = codegen_flatten(
        data_fields, meta_fields, auto_fields
    )
    flatten_with_keys_def: ast.FunctionDef = codegen_flatten_with_keys(
        data_fields, meta_fields, auto_fields
    )
    unflatten_def: ast.FunctionDef = codegen_unflatten(
        data_fields, meta_fields, auto_fields, bypass_setattr=bypass_setattr
    )
    module: ast.Module = ast.Module(
        body=[flatten_def, flatten_with_keys_def, unflatten_def], type_ignores=[]
    )
    module = ast.fix_missing_locations(module)
    source: str = ast.unparse(module)
    namespace: dict = {
        "_cls": cls,
        "_filter_spec": filter_spec,
        "_object_new": object.__new__,
        "_object_setattr": object.__setattr__,
        **_make_keys((*data_fields, *meta_fields, *auto_fields)),
    }
    filename: str = _make_filename(cls)
    # use unparse source so we have correct source code locations
    code: types.CodeType = compile(source, filename, "exec")
    exec(code, namespace)  # noqa: S102
    _update_linecache(source, filename)
    return PyTreeFunctions(
        _add_dunder(cls, namespace["flatten"]),
        _add_dunder(cls, namespace["unflatten"]),
        _add_dunder(cls, namespace["flatten_with_keys"]),
    )
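codegen_pytree_functions follows a common code-generation recipe: build an AST, unparse it to text, compile that text under a synthetic filename, exec it in a prepared namespace, and register the text with linecache. The sketch below applies the same recipe to a much simpler flatten function (no filter_spec, no auto fields); make_flatten and the filename format are hypothetical, and jarp's _make_filename and _update_linecache may differ in detail.

```python
import ast
import linecache
from typing import Any, Callable


def make_flatten(field_names: tuple[str, ...], filename: str) -> Callable[[Any], tuple]:
    """Generate ``def flatten(obj): return ((obj.<f1>, obj.<f2>, ...), ())``."""
    # Start from a parsed template so the FunctionDef carries every field the
    # running Python version expects, then swap in a generated body.
    module = ast.parse("def flatten(obj):\n    pass")
    children = [
        ast.Attribute(ast.Name("obj", ast.Load()), name, ast.Load())
        for name in field_names
    ]
    module.body[0].body = [
        ast.Return(
            ast.Tuple(
                [ast.Tuple(children, ast.Load()), ast.Tuple([], ast.Load())],
                ast.Load(),
            )
        )
    ]
    module = ast.fix_missing_locations(module)
    # Compile the unparsed text (not the AST) so code line numbers match the
    # source text stored in linecache below.
    source = ast.unparse(module)
    namespace: dict[str, Any] = {}
    exec(compile(source, filename, "exec"), namespace)
    # Register the generated source so tracebacks can display it.
    linecache.cache[filename] = (len(source), None, source.splitlines(True), filename)
    return namespace["flatten"]


class Point:
    def __init__(self, x: float, y: float) -> None:
        self.x, self.y = x, y


flatten = make_flatten(("x", "y"), "<generated flatten for Point>")
print(flatten(Point(1.0, 2.0)))  # ((1.0, 2.0), ())
```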

combine

combine[T](
    data_leaves: Iterable[Array | None], aux: AuxData[T]
) -> T

Rebuild a PyTree from dynamic leaves and recorded auxiliary data.

Source code in src/jarp/tree/_filters.py
def combine[T](data_leaves: Iterable[Array | None], aux: AuxData[T]) -> T:
    """Rebuild a PyTree from dynamic leaves and recorded auxiliary data."""
    leaves: list[Any] = combine_leaves(data_leaves, aux.meta_leaves)
    return jax.tree.unflatten(aux.treedef, leaves)

combine_leaves

combine_leaves(
    data_leaves: Iterable[Array | None],
    meta_leaves: Iterable[Any],
) -> list[Any]

Zip dynamic leaves back together with their static counterparts.

Source code in src/jarp/tree/_filters.py
def combine_leaves(
    data_leaves: Iterable[Array | None], meta_leaves: Iterable[Any]
) -> list[Any]:
    """Zip dynamic leaves back together with their static counterparts."""
    return [
        data_leaf if meta_leaf is None else meta_leaf
        for data_leaf, meta_leaf in zip(data_leaves, meta_leaves, strict=True)
    ]

define

define[T: type](
    cls: T,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> T
define[T: type](
    cls: None = None,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> Callable[[T], T]

Define an attrs class and optionally register it as a PyTree.

Parameters:

  • maybe_cls

    (T | None, default: None ) –

    Class being decorated. When omitted, return a configured decorator.

  • **kwargs

    (Any, default: {} ) –

    Options forwarded to attrs.define, plus pytree to control JAX registration: pytree="data" registers the fields with fieldz semantics, "static" registers the whole instance as a static value, and "none" leaves the class unregistered. When pytree is omitted, the class is registered as a data PyTree.

Returns:

  • Any

    The decorated class or a class decorator.

Source code in src/jarp/tree/attrs/_define.py
def define[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define an ``attrs`` class and optionally register it as a PyTree.

    Args:
        maybe_cls: Class being decorated. When omitted, return a configured
            decorator.
        **kwargs: Options forwarded to :func:`attrs.define`, plus ``pytree`` to
            control JAX registration. ``pytree="data"`` registers fields with
            ``fieldz`` semantics, ``"static"`` registers the whole instance as a
            static value, and ``"none"`` leaves the class unregistered.

    Returns:
        The decorated class or a class decorator.
    """
    if maybe_cls is None:
        return functools.partial(define, **kwargs)
    pytree: PyTreeType = PyTreeType(kwargs.pop("pytree", None))
    frozen: bool = kwargs.get("frozen", False)
    if pytree is PyTreeType.STATIC and not frozen:
        warnings.warn(
            "Defining a static class that is not frozen may lead to unexpected behavior.",
            stacklevel=2,
        )
    cls: T = attrs.define(maybe_cls, **kwargs)  # ty:ignore[invalid-assignment]
    match pytree:
        case PyTreeType.DATA:
            register_fieldz(cls)
        case PyTreeType.STATIC:
            jtu.register_static(cls)
    return cls

field

field(**kwargs) -> Any
Source code in src/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def field(**kwargs) -> Any:
    if "static" in kwargs:
        kwargs["metadata"] = {
            "static": kwargs.pop("static"),
            **(kwargs.get("metadata") or {}),
        }
    return attrs.field(**kwargs)
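field stores the static flag in attrs metadata instead of passing it to attrs.field; static() and auto() reach the same merge after setdefault-ing static to True and FieldType.AUTO respectively. The helper below (merge_static is a hypothetical name) repeats the merge on plain dictionaries so the precedence is visible: the caller's metadata is unpacked after the static entry, so an explicit "static" key inside metadata wins over the static= keyword.

```python
from typing import Any


def merge_static(**kwargs: Any) -> dict[str, Any]:
    # Same merge as field(): the `static` keyword moves into `metadata`.
    if "static" in kwargs:
        kwargs["metadata"] = {
            "static": kwargs.pop("static"),
            **(kwargs.get("metadata") or {}),
        }
    return kwargs


print(merge_static(static=True, metadata={"doc": "step size"}))
# {'metadata': {'static': True, 'doc': 'step size'}}
print(merge_static(static=True, metadata={"static": False}))
# {'metadata': {'static': False}}
```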

frozen

frozen[T: type](
    cls: T,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> T
frozen[T: type](
    cls: None = None,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> Callable[[T], T]

Define a frozen attrs class and register it as a data PyTree.

Source code in src/jarp/tree/attrs/_define.py
def frozen[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define a frozen ``attrs`` class and register it as a data PyTree."""
    _warnings_hide = True
    if maybe_cls is None:
        return functools.partial(frozen, **kwargs)
    kwargs.setdefault("frozen", True)
    return define(maybe_cls, **kwargs)

frozen_static

frozen_static[T: type](
    cls: T,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> T
frozen_static[T: type](
    cls: None = None,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> Callable[[T], T]

Define a frozen attrs class and register it as a static PyTree.

Source code in src/jarp/tree/attrs/_define.py
def frozen_static[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define a frozen ``attrs`` class and register it as a static PyTree."""
    _warnings_hide = True
    if maybe_cls is None:
        return functools.partial(frozen_static, **kwargs)
    kwargs.setdefault("frozen", True)
    kwargs.setdefault("pytree", PyTreeType.STATIC)
    return define(maybe_cls, **kwargs)

is_data

is_data(obj: Any) -> bool

Return whether an object should stay on the dynamic side of a partition.

Source code in src/jarp/tree/_filters.py
def is_data(obj: Any) -> bool:
    """Return whether an object should stay on the dynamic side of a partition."""
    return obj is None or isinstance(obj, Array) or type(obj) in _registry

is_leaf

is_leaf(obj: Any) -> TypeIs[Array | None]

Return whether a leaf contributes data to a flattened vector.

Source code in src/jarp/tree/_filters.py
def is_leaf(obj: Any) -> TypeIs[Array | None]:
    """Return whether a leaf contributes data to a flattened vector."""
    return obj is None or isinstance(obj, Array)

partial

partial[T](
    func: Callable[..., T], /, *args: Any, **kwargs: Any
) -> Partial[..., T]

Partially apply a callable and keep the result compatible with JAX trees.

Source code in src/jarp/tree/prelude/_partial.py
def partial[T](func: Callable[..., T], /, *args: Any, **kwargs: Any) -> Partial[..., T]:
    """Partially apply a callable and keep the result compatible with JAX trees."""
    return Partial(func, *args, **kwargs)

partition

partition[T](
    obj: T,
) -> tuple[list[Array | None], AuxData[T]]

Split a PyTree into dynamic leaves and static metadata.

Source code in src/jarp/tree/_filters.py
def partition[T](obj: T) -> tuple[list[Array | None], AuxData[T]]:
    """Split a PyTree into dynamic leaves and static metadata."""
    leaves: list[Any]
    treedef: Any
    leaves, treedef = jax.tree.flatten(obj)
    data_leaves: list[Array | None]
    meta_leaves: list[Any]
    data_leaves, meta_leaves = partition_leaves(leaves)
    return data_leaves, AuxData(tuple(meta_leaves), treedef)

partition_leaves

partition_leaves(
    leaves: list[Any],
) -> tuple[list[Array | None], list[Any]]

Separate raw tree leaves into data leaves and metadata leaves.

Source code in src/jarp/tree/_filters.py
def partition_leaves(leaves: list[Any]) -> tuple[list[Array | None], list[Any]]:
    """Separate raw tree leaves into data leaves and metadata leaves."""
    data_leaves: list[Array | None] = []
    meta_leaves: list[Any] = []
    for leaf in leaves:
        if is_leaf(leaf):
            data_leaves.append(leaf)
            meta_leaves.append(None)
        else:
            data_leaves.append(None)
            meta_leaves.append(leaf)
    return data_leaves, meta_leaves

ravel

ravel[T](tree: T) -> tuple[Array, Structure[T]]

Flatten a PyTree's dynamic leaves into one vector.

Non-array leaves are treated as static metadata and preserved in the returned Structure instead of being concatenated into the flat array.

Parameters:

  • tree

    (T) –

    PyTree to flatten.

Returns:

  • Array

    flat: a one-dimensional JAX array containing the dynamic leaves.

  • Structure[T]

    structure: a Structure that can rebuild compatible trees later.

Source code in src/jarp/tree/_ravel.py
def ravel[T](tree: T) -> tuple[Array, Structure[T]]:
    """Flatten a PyTree's dynamic leaves into one vector.

    Non-array leaves are treated as static metadata and preserved in the
    returned :class:`Structure` instead of being concatenated into the flat
    array.

    Args:
        tree: PyTree to flatten.

    Returns:
        A tuple of ``(flat, structure)`` where ``flat`` is a one-dimensional
        JAX array and ``structure`` can rebuild compatible trees later.
    """
    leaves, treedef = jax.tree.flatten(tree)
    dynamic_leaves, static_leaves = partition_leaves(leaves)
    flat: Array = _ravel(dynamic_leaves)
    structure: Structure[T] = Structure(
        offsets=_offsets_from_leaves(dynamic_leaves),
        shapes=_shapes_from_leaves(dynamic_leaves),
        meta_leaves=tuple(static_leaves),
        treedef=treedef,
        dtype=flat.dtype,
    )
    return flat, structure
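The bookkeeping behind ravel and Structure.unravel amounts to concatenating the dynamic leaves and remembering where each one ends. The pure-Python sketch below uses 1-D lists in place of JAX arrays and assumes that offsets are the interior split points between consecutive dynamic leaves (the convention numpy.split uses); the real _offsets_from_leaves and _unravel helpers are not shown in this reference and may differ. None marks a static slot, presumably mirroring the None entries allowed in Structure.shapes.

```python
import itertools
from typing import Optional

Leaf = Optional[list[float]]  # a 1-D stand-in for an array, or None for a static slot


def ravel_leaves(
    leaves: list[Leaf],
) -> tuple[list[float], tuple[int, ...], tuple[Optional[int], ...]]:
    """Concatenate the dynamic leaves and record how to split them again."""
    sizes = tuple(None if leaf is None else len(leaf) for leaf in leaves)
    flat = [x for leaf in leaves if leaf is not None for x in leaf]
    # Running end position of each dynamic leaf; the last one equals len(flat),
    # so it is dropped to keep only the interior split points.
    ends = list(itertools.accumulate(s for s in sizes if s is not None))
    return flat, tuple(ends[:-1]), sizes


def unravel_leaves(
    flat: list[float], offsets: tuple[int, ...], sizes: tuple[Optional[int], ...]
) -> list[Leaf]:
    """Split ``flat`` at ``offsets`` and put None back into the static slots."""
    bounds = (0, *offsets, len(flat))
    chunks = iter([flat[a:b] for a, b in zip(bounds, bounds[1:])])
    return [None if size is None else next(chunks) for size in sizes]


leaves = [[1.0, 2.0], None, [3.0], [4.0, 5.0, 6.0]]
flat, offsets, sizes = ravel_leaves(leaves)
print(flat)     # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(offsets)  # (2, 3)
print(unravel_leaves(flat, offsets, sizes) == leaves)  # True
```

In jarp the None placeholders would then be filled from Structure.meta_leaves by combine_leaves before jax.tree.unflatten rebuilds the tree.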

register_fieldz

register_fieldz[T: type](
    cls: T,
    data_fields: Sequence[str] | None = None,
    meta_fields: Sequence[str] | None = None,
    auto_fields: Sequence[str] | None = None,
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> T
Source code in src/jarp/tree/attrs/_register.py
def register_fieldz[T: type](
    cls: T,
    data_fields: Sequence[str] | None = None,
    meta_fields: Sequence[str] | None = None,
    auto_fields: Sequence[str] | None = None,
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> T:
    if data_fields is None:
        data_fields: list[str] = _filter_field_names(cls, FieldType.DATA)
    if meta_fields is None:
        meta_fields: list[str] = _filter_field_names(cls, FieldType.META)
    if auto_fields is None:
        auto_fields: list[str] = _filter_field_names(cls, FieldType.AUTO)
    register_generic(
        cls,
        data_fields,
        meta_fields,
        auto_fields,
        filter_spec=filter_spec,
        bypass_setattr=bypass_setattr,
    )
    return cls

register_generic

register_generic(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> None
Source code in src/jarp/tree/codegen/_compile.py
def register_generic(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> None:
    flatten: Callable
    unflatten: Callable
    flatten_with_keys: Callable
    flatten, unflatten, flatten_with_keys = codegen_pytree_functions(
        cls,
        data_fields,
        meta_fields,
        auto_fields,
        filter_spec=filter_spec,
        bypass_setattr=bypass_setattr,
    )
    jtu.register_pytree_node(cls, flatten, unflatten, flatten_with_keys)

register_pytree_prelude cached

register_pytree_prelude() -> None
Source code in src/jarp/tree/prelude/_prelude.py
@functools.cache  # run only once
def register_pytree_prelude() -> None:
    register_pytree_method()
    register_warp_array()

static

static(**kwargs) -> Any
Source code in src/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def static(**kwargs) -> Any:
    # for consistency with `jax.tree_util.register_dataclass`
    kwargs.setdefault("static", True)
    return field(**kwargs)

attrs

attrs-based decorators and field helpers for JAX PyTree classes.

FieldType

Bases: StrEnum


Attributes:

AUTO class-attribute instance-attribute

AUTO = auto()

DATA class-attribute instance-attribute

DATA = auto()

META class-attribute instance-attribute

META = auto()

__bool__

__bool__() -> bool
Source code in src/jarp/tree/attrs/_field_specifiers.py
def __bool__(self) -> bool:
    match self:
        case FieldType.META:
            return True
        case FieldType.AUTO | FieldType.DATA:
            # for consistency with `jax.tree_util.register_dataclass`
            return False
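Making only META truthy lets a FieldType value occupy the same "static" metadata slot as the plain booleans written by static(), following the static=True convention of jax.tree_util.register_dataclass. A stand-in enum shows the resulting truth values; the lowercase string values are an assumption based on StrEnum's auto(), which lowercases member names.

```python
import enum


class FieldType(str, enum.Enum):
    """Hypothetical stand-in for jarp.tree.attrs.FieldType."""

    AUTO = "auto"
    DATA = "data"
    META = "meta"

    def __bool__(self) -> bool:
        # Only META behaves like static=True; AUTO and DATA behave like False.
        return self is FieldType.META


print({member.name: bool(member) for member in FieldType})
# {'AUTO': False, 'DATA': False, 'META': True}
```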

PyTreeType

Bases: StrEnum


Choose how a class should participate in JAX PyTree flattening.

Attributes:

DATA class-attribute instance-attribute

DATA = auto()

NONE class-attribute instance-attribute

NONE = auto()

STATIC class-attribute instance-attribute

STATIC = auto()

array

array(
    *,
    default: T = ...,
    validator: _ValidatorArgType[T] | None = ...,
    repr: _ReprArgType = ...,
    hash: bool | None = ...,
    init: bool = ...,
    metadata: Mapping[Any, Any] | None = ...,
    converter: _ConverterType
    | list[_ConverterType]
    | tuple[_ConverterType, ...]
    | None = ...,
    factory: Callable[[], T] | None = ...,
    kw_only: bool | None = ...,
    eq: _EqOrderType | None = ...,
    order: _EqOrderType | None = ...,
    on_setattr: _OnSetAttrArgType | None = ...,
    alias: str | None = ...,
    type: type | None = ...,
    static: FieldType | bool | None = ...,
) -> Array

Parameters:

  • default
    (T, default: ... ) –
  • validator
    (_ValidatorArgType[T] | None, default: ... ) –
  • repr
    (_ReprArgType, default: ... ) –
  • hash
    (bool | None, default: ... ) –
  • init
    (bool, default: ... ) –
  • metadata
    (Mapping[Any, Any] | None, default: ... ) –
  • converter
    (_ConverterType | list[_ConverterType] | tuple[_ConverterType, ...] | None, default: ... ) –
  • factory
    (Callable[[], T] | None, default: ... ) –
  • kw_only
    (bool | None, default: ... ) –
  • eq
    (_EqOrderType | None, default: ... ) –
  • order
    (_EqOrderType | None, default: ... ) –
  • on_setattr
    (_OnSetAttrArgType | None, default: ... ) –
  • alias
    (str | None, default: ... ) –
  • type
    (type | None, default: ... ) –
  • static
    (FieldType | bool | None, default: ... ) –
Source code in src/jarp/tree/attrs/_field_specifiers.py
def array(**kwargs: Unpack[FieldOptions[ArrayLike | None]]) -> Array:
    if "default" in kwargs and "factory" not in kwargs:
        default: ArrayLike | None = kwargs["default"]
        if not (default is None or isinstance(default, attrs.Factory)):  # ty:ignore[invalid-argument-type]
            default: Array = jnp.asarray(default)
            kwargs.pop("default")
            kwargs["factory"] = lambda: default
    return field(**kwargs)

auto

auto(**kwargs) -> Any
Source code in src/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def auto(**kwargs) -> Any:
    kwargs.setdefault("static", FieldType.AUTO)
    return field(**kwargs)

define

define[T: type](
    cls: T,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> T
define[T: type](
    cls: None = None,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> Callable[[T], T]

Define an attrs class and optionally register it as a PyTree.

Parameters:

  • maybe_cls
    (T | None, default: None ) –

    Class being decorated. When omitted, return a configured decorator.

  • **kwargs
    (Any, default: {} ) –

    Options forwarded to attrs.define, plus pytree to control JAX registration: pytree="data" registers the fields with fieldz semantics, "static" registers the whole instance as a static value, and "none" leaves the class unregistered. When pytree is omitted, the class is registered as a data PyTree.

Returns:

  • Any

    The decorated class or a class decorator.

Source code in src/jarp/tree/attrs/_define.py
def define[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define an ``attrs`` class and optionally register it as a PyTree.

    Args:
        maybe_cls: Class being decorated. When omitted, return a configured
            decorator.
        **kwargs: Options forwarded to :func:`attrs.define`, plus ``pytree`` to
            control JAX registration. ``pytree="data"`` registers fields with
            ``fieldz`` semantics, ``"static"`` registers the whole instance as a
            static value, and ``"none"`` leaves the class unregistered.

    Returns:
        The decorated class or a class decorator.
    """
    if maybe_cls is None:
        return functools.partial(define, **kwargs)
    pytree: PyTreeType = PyTreeType(kwargs.pop("pytree", None))
    frozen: bool = kwargs.get("frozen", False)
    if pytree is PyTreeType.STATIC and not frozen:
        warnings.warn(
            "Defining a static class that is not frozen may lead to unexpected behavior.",
            stacklevel=2,
        )
    cls: T = attrs.define(maybe_cls, **kwargs)  # ty:ignore[invalid-assignment]
    match pytree:
        case PyTreeType.DATA:
            register_fieldz(cls)
        case PyTreeType.STATIC:
            jtu.register_static(cls)
    return cls

field

field(**kwargs) -> Any
Source code in src/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def field(**kwargs) -> Any:
    if "static" in kwargs:
        kwargs["metadata"] = {
            "static": kwargs.pop("static"),
            **(kwargs.get("metadata") or {}),
        }
    return attrs.field(**kwargs)

frozen

frozen[T: type](
    cls: T,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> T
frozen[T: type](
    cls: None = None,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> Callable[[T], T]

Define a frozen attrs class and register it as a data PyTree.

Source code in src/jarp/tree/attrs/_define.py
def frozen[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define a frozen ``attrs`` class and register it as a data PyTree."""
    _warnings_hide = True
    if maybe_cls is None:
        return functools.partial(frozen, **kwargs)
    kwargs.setdefault("frozen", True)
    return define(maybe_cls, **kwargs)

frozen_static

frozen_static[T: type](
    cls: T,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> T
frozen_static[T: type](
    cls: None = None,
    /,
    *,
    kw_only: bool = False,
    **kwargs: Unpack[DefineOptions],
) -> Callable[[T], T]

Define a frozen attrs class and register it as a static PyTree.

Source code in src/jarp/tree/attrs/_define.py
def frozen_static[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define a frozen ``attrs`` class and register it as a static PyTree."""
    _warnings_hide = True
    if maybe_cls is None:
        return functools.partial(frozen_static, **kwargs)
    kwargs.setdefault("frozen", True)
    kwargs.setdefault("pytree", PyTreeType.STATIC)
    return define(maybe_cls, **kwargs)

register_fieldz

register_fieldz[T: type](
    cls: T,
    data_fields: Sequence[str] | None = None,
    meta_fields: Sequence[str] | None = None,
    auto_fields: Sequence[str] | None = None,
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> T
Source code in src/jarp/tree/attrs/_register.py
def register_fieldz[T: type](
    cls: T,
    data_fields: Sequence[str] | None = None,
    meta_fields: Sequence[str] | None = None,
    auto_fields: Sequence[str] | None = None,
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> T:
    if data_fields is None:
        data_fields: list[str] = _filter_field_names(cls, FieldType.DATA)
    if meta_fields is None:
        meta_fields: list[str] = _filter_field_names(cls, FieldType.META)
    if auto_fields is None:
        auto_fields: list[str] = _filter_field_names(cls, FieldType.AUTO)
    register_generic(
        cls,
        data_fields,
        meta_fields,
        auto_fields,
        filter_spec=filter_spec,
        bypass_setattr=bypass_setattr,
    )
    return cls

static

static(**kwargs) -> Any
Source code in src/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def static(**kwargs) -> Any:
    # for consistency with `jax.tree_util.register_dataclass`
    kwargs.setdefault("static", True)
    return field(**kwargs)

codegen

Code-generation helpers for custom PyTree registrations.

Classes:

  • PyTreeFunctions

    PyTreeFunctions(flatten, unflatten, flatten_with_keys)

PyTreeFunctions

Bases: NamedTuple


PyTreeFunctions(flatten, unflatten, flatten_with_keys)

Attributes:

flatten instance-attribute
flatten: Callable[[T], tuple[_Children, _AuxData]]
flatten_with_keys instance-attribute
flatten_with_keys: Callable[
    [T], tuple[_ChildrenWithKeys, _AuxData]
]
unflatten instance-attribute
unflatten: Callable[[_AuxData, _Children], T]

codegen_flatten

codegen_flatten(
    data_fields: Sequence[str],
    meta_fields: Sequence[str],
    auto_fields: Sequence[str],
) -> FunctionDef
Source code in src/jarp/tree/codegen/_codegen.py
def codegen_flatten(
    data_fields: Sequence[str], meta_fields: Sequence[str], auto_fields: Sequence[str]
) -> FunctionDef:
    body: list[stmt] = codegen_partition(auto_fields)
    children: list[expr] = codegen_children(data_fields, auto_fields)
    aux: list[expr] = codegen_aux(meta_fields, auto_fields)
    body.append(Return(Tuple([Tuple(children, Load()), Tuple(aux, Load())], Load())))
    return codegen_function_def("flatten", [arg("obj")], body)

codegen_flatten_with_keys

codegen_flatten_with_keys(
    data_fields: Sequence[str],
    meta_fields: Sequence[str],
    auto_fields: Sequence[str],
) -> FunctionDef
Source code in src/jarp/tree/codegen/_codegen.py
def codegen_flatten_with_keys(
    data_fields: Sequence[str], meta_fields: Sequence[str], auto_fields: Sequence[str]
) -> FunctionDef:
    body: list[stmt] = codegen_partition(auto_fields)
    children: list[expr] = codegen_children(data_fields, auto_fields)
    aux: list[expr] = codegen_aux(meta_fields, auto_fields)
    keys: list[expr] = [
        Name(f"_{name}_key", Load()) for name in (*data_fields, *auto_fields)
    ]
    children_with_keys: list[expr] = [
        Tuple([key, child], Load()) for key, child in zip(keys, children, strict=True)
    ]
    body.append(
        Return(Tuple([Tuple(children_with_keys, Load()), Tuple(aux, Load())], Load()))
    )
    return codegen_function_def("flatten_with_keys", [arg("obj")], body)

codegen_pytree_functions

codegen_pytree_functions(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> PyTreeFunctions
Source code in src/jarp/tree/codegen/_compile.py
def codegen_pytree_functions(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> PyTreeFunctions:
    if bypass_setattr is None:
        bypass_setattr = cls.__setattr__ is not object.__setattr__
    flatten_def: ast.FunctionDef = codegen_flatten(
        data_fields, meta_fields, auto_fields
    )
    flatten_with_keys_def: ast.FunctionDef = codegen_flatten_with_keys(
        data_fields, meta_fields, auto_fields
    )
    unflatten_def: ast.FunctionDef = codegen_unflatten(
        data_fields, meta_fields, auto_fields, bypass_setattr=bypass_setattr
    )
    module: ast.Module = ast.Module(
        body=[flatten_def, flatten_with_keys_def, unflatten_def], type_ignores=[]
    )
    module = ast.fix_missing_locations(module)
    source: str = ast.unparse(module)
    namespace: dict = {
        "_cls": cls,
        "_filter_spec": filter_spec,
        "_object_new": object.__new__,
        "_object_setattr": object.__setattr__,
        **_make_keys((*data_fields, *meta_fields, *auto_fields)),
    }
    filename: str = _make_filename(cls)
    # use unparse source so we have correct source code locations
    code: types.CodeType = compile(source, filename, "exec")
    exec(code, namespace)  # noqa: S102
    _update_linecache(source, filename)
    return PyTreeFunctions(
        _add_dunder(cls, namespace["flatten"]),
        _add_dunder(cls, namespace["unflatten"]),
        _add_dunder(cls, namespace["flatten_with_keys"]),
    )

codegen_unflatten

codegen_unflatten(
    data_fields: Sequence[str],
    meta_fields: Sequence[str],
    auto_fields: Sequence[str],
    *,
    bypass_setattr: bool = False,
) -> FunctionDef
Source code in src/jarp/tree/codegen/_codegen.py
def codegen_unflatten(
    data_fields: Sequence[str],
    meta_fields: Sequence[str],
    auto_fields: Sequence[str],
    *,
    bypass_setattr: bool = False,
) -> FunctionDef:
    body: list[stmt] = [
        Assign(
            [Name("obj", Store())],
            Call(Name("_object_new", Load()), [Name("_cls", Load())], []),
        )
    ]
    if bypass_setattr:
        body.extend(_codegen_unflatten_bypass(data_fields, meta_fields, auto_fields))
    else:
        body.extend(_codegen_unflatten_direct(data_fields, meta_fields, auto_fields))
    body.append(Return(Name("obj", Load())))
    return codegen_function_def("unflatten", [arg("aux"), arg("children")], body)
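codegen_unflatten builds its function from ast nodes. The body always starts with obj = _object_new(_cls) and ends with return obj. The field assignments come from _codegen_unflatten_bypass or _codegen_unflatten_direct, which are not shown here. The sketch below assembles a comparable function for a hypothetical frozen dataclass. It assumes the bypass branch assigns each field through _object_setattr, the name under which codegen_pytree_functions exposes object.__setattr__. For simplicity it treats every field as a data field read from children.

```python
import ast
import dataclasses


def build_unflatten(fields: list[str]) -> ast.FunctionDef:
    """Build ``def unflatten(aux, children)`` that restores ``fields`` from ``children``."""

    def load(ident: str) -> ast.Name:
        return ast.Name(id=ident, ctx=ast.Load())

    # obj = _object_new(_cls)
    body: list[ast.stmt] = [
        ast.Assign(
            targets=[ast.Name(id="obj", ctx=ast.Store())],
            value=ast.Call(func=load("_object_new"), args=[load("_cls")], keywords=[]),
        )
    ]
    for index, field in enumerate(fields):
        # _object_setattr(obj, '<field>', children[<index>])
        child = ast.Subscript(value=load("children"), slice=ast.Constant(index), ctx=ast.Load())
        call = ast.Call(
            func=load("_object_setattr"),
            args=[load("obj"), ast.Constant(field), child],
            keywords=[],
        )
        body.append(ast.Expr(value=call))
    body.append(ast.Return(value=load("obj")))
    args = ast.arguments(
        posonlyargs=[], args=[ast.arg(arg="aux"), ast.arg(arg="children")],
        vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[],
    )
    fn = ast.FunctionDef(name="unflatten", args=args, body=body, decorator_list=[], returns=None)
    if "type_params" in ast.FunctionDef._fields:  # field added in Python 3.12
        fn.type_params = []
    return fn


@dataclasses.dataclass(frozen=True)
class Point:  # hypothetical frozen class; its own __setattr__ raises
    x: float
    y: float


module = ast.fix_missing_locations(ast.Module(body=[build_unflatten(["x", "y"])], type_ignores=[]))
source = ast.unparse(module)
namespace = {"_cls": Point, "_object_new": object.__new__, "_object_setattr": object.__setattr__}
exec(compile(source, "<generated unflatten>", "exec"), namespace)
unflatten = namespace["unflatten"]

print(source)
print(unflatten(None, (1.0, 2.0)))  # Point(x=1.0, y=2.0)
```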

register_generic

register_generic(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> None
Source code in src/jarp/tree/codegen/_compile.py
def register_generic(
    cls: type,
    data_fields: Sequence[str] = (),
    meta_fields: Sequence[str] = (),
    auto_fields: Sequence[str] = (),
    *,
    filter_spec: Callable[[Any], bool] = is_data,
    bypass_setattr: bool | None = None,
) -> None:
    flatten: Callable
    unflatten: Callable
    flatten_with_keys: Callable
    flatten, unflatten, flatten_with_keys = codegen_pytree_functions(
        cls,
        data_fields,
        meta_fields,
        auto_fields,
        filter_spec=filter_spec,
        bypass_setattr=bypass_setattr,
    )
    jtu.register_pytree_node(cls, flatten, unflatten, flatten_with_keys)
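register_generic passes the three generated callables to jax.tree_util.register_pytree_node. As a rough, dependency-free sketch of their contract, here is a hand-written flatten/unflatten pair for a hypothetical class with one data field and one meta field. The sketch assumes, as the parameter names suggest, that data fields become children and meta fields become auxiliary data.

```python
from typing import Any


class Particle:
    """Hypothetical class: ``position`` is a data field, ``name`` a meta field."""

    def __init__(self, position: list[float], name: str) -> None:
        self.position = position
        self.name = name


def flatten(obj: Particle) -> tuple[tuple[Any, ...], tuple[Any, ...]]:
    # Data fields become children (leaves that JAX traverses and transforms);
    # meta fields become hashable aux data that is compared, not traced.
    return (obj.position,), (obj.name,)


def unflatten(aux: tuple[Any, ...], children: tuple[Any, ...]) -> Particle:
    # Like the generated code, allocate with object.__new__ so __init__
    # (and any validation in it) never runs on transformed leaves.
    obj = object.__new__(Particle)
    obj.position = children[0]
    obj.name = aux[0]
    return obj


p = Particle([0.0, 1.0], "electron")
children, aux = flatten(p)
q = unflatten(aux, children)
print(q.position, q.name)  # [0.0, 1.0] electron
```

With JAX installed, jax.tree_util.register_pytree_node(Particle, flatten, unflatten) would register this pair. register_generic also passes a flatten_with_keys function. That function is omitted here because its key objects come from JAX.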

prelude

PyTree-aware wrappers for callables and object proxies.

Classes:

  • Partial

    Store a partially applied callable as a PyTree-aware proxy.

  • PyTreeProxy

    Wrap an arbitrary object and flatten the wrapped value as a PyTree.

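Functions:

  • partial

    Partially apply a callable and keep the result compatible with JAX trees.

  • register_pytree_prelude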

Partial

Partial(
    func: Callable[..., T], /, *args: Any, **kwargs: Any
)

Bases: PartialCallableObjectProxy


Store a partially applied callable as a PyTree-aware proxy.

Methods:

  • __call__

Attributes:

  • __wrapped__

Source code in src/jarp/tree/prelude/_partial.py
def __init__(self, func: Callable[..., T], /, *args: Any, **kwargs: Any) -> None:
    """Create a proxy that records bound arguments for PyTree flattening."""
    super().__init__(func, *args, **kwargs)
    self._self_args = args
    self._self_kwargs = kwargs

__wrapped__ instance-attribute

__wrapped__: Callable[..., T]

__call__

__call__(*args: P.args, **kwargs: P.kwargs) -> T
Source code in src/jarp/tree/prelude/_partial.py
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

PyTreeProxy

Bases: BaseObjectProxy


Wrap an arbitrary object and flatten the wrapped value as a PyTree.

Attributes:

__wrapped__ instance-attribute
__wrapped__: T

partial

partial[T](
    func: Callable[..., T], /, *args: Any, **kwargs: Any
) -> Partial[..., T]

Partially apply a callable and keep the result compatible with JAX trees.

Source code in src/jarp/tree/prelude/_partial.py
def partial[T](func: Callable[..., T], /, *args: Any, **kwargs: Any) -> Partial[..., T]:
    """Partially apply a callable and keep the result compatible with JAX trees."""
    return Partial(func, *args, **kwargs)
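Partial keeps the bound positional and keyword arguments in _self_args and _self_kwargs so that they can take part in flattening. The flatten rules for Partial are not shown on this page. The dependency-free sketch below therefore only assumes a plausible split: bound argument values are children, while the callable and the argument layout are auxiliary data. TinyPartial, flatten, and unflatten are hypothetical. Unlike Partial, TinyPartial is a plain class rather than an object proxy.

```python
from typing import Any, Callable


class TinyPartial:
    """Hypothetical stand-in for Partial: a callable plus bound arguments."""

    def __init__(self, func: Callable[..., Any], /, *args: Any, **kwargs: Any) -> None:
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # functools.partial semantics: bound args first, call-time kwargs win.
        return self.func(*self.args, *args, **{**self.kwargs, **kwargs})


def flatten(p: TinyPartial) -> tuple[tuple[Any, ...], tuple[Any, ...]]:
    # Bound values are children; the callable and argument layout are aux data.
    keys = tuple(p.kwargs)
    children = (*p.args, *(p.kwargs[k] for k in keys))
    return children, (p.func, len(p.args), keys)


def unflatten(aux: tuple[Any, ...], children: tuple[Any, ...]) -> TinyPartial:
    func, num_args, keys = aux
    args = children[:num_args]
    kwargs = dict(zip(keys, children[num_args:]))
    return TinyPartial(func, *args, **kwargs)


def affine(scale: float, x: float, *, shift: float = 0.0) -> float:
    return scale * x + shift


f = TinyPartial(affine, 2.0, shift=1.0)
children, aux = flatten(f)
# Mimic a tree map over the leaves: multiply every bound value by 10.
g = unflatten(aux, tuple(10 * c for c in children))
print(f(3.0), g(3.0))  # 7.0 70.0
```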

register_pytree_prelude cached

register_pytree_prelude() -> None
Source code in src/jarp/tree/prelude/_prelude.py
@functools.cache  # run only once
def register_pytree_prelude() -> None:
    register_pytree_method()
    register_warp_array()
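register_pytree_prelude takes no arguments, so functools.cache stores its single None result after the first call. Later calls return immediately without repeating the registrations, which matters because JAX refuses to register the same node type twice. In the sketch below, register_once and the registrations list are hypothetical stand-ins for the real register_pytree_method and register_warp_array calls.

```python
import functools

registrations: list[str] = []


@functools.cache  # the body runs on the first call only
def register_once() -> None:
    # Hypothetical stand-in for the real registration calls.
    registrations.append("registered")


register_once()
register_once()
print(registrations)  # ['registered']
```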

warp

Interop helpers between JAX arrays and Warp kernels or callables.

Modules:

  • types

    Convenience accessors for Warp scalar, vector, and matrix dtypes.

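Classes:

  • FfiCallableProtocol

    Callable interface returned by jax_callable.

  • FfiKernelProtocol

    Callable interface returned by jax_kernel.

  • JaxCallableCallOptions

  • JaxCallableOptions

  • JaxKernelCallOptions

  • JaxKernelOptions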

Functions:

  • jax_callable

    Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.

  • jax_kernel

    Wrap warp.jax_experimental.jax_kernel with optional overload lookup.

  • to_warp

    Convert a supported array object into a warp.array.

FfiCallableProtocol

Bases: Protocol


Callable interface returned by jax_callable.

Methods:

__call__

__call__(
    *args: Array,
    output_dims: ShapeLike
    | dict[str, ShapeLike]
    | None = ...,
    vmap_method: VmapMethod | None = ...,
) -> Sequence[Array]

Parameters:

  • output_dims
    (ShapeLike | dict[str, ShapeLike] | None, default: ... ) –
  • vmap_method
    (VmapMethod | None, default: ... ) –
Source code in src/jarp/warp/_jax_callable.py
def __call__(
    self, *args: Array, **kwargs: Unpack[JaxCallableCallOptions]
) -> Sequence[Array]: ...

FfiKernelProtocol

Bases: Protocol


Callable interface returned by jax_kernel.

Methods:

__call__

__call__(
    *args: Array,
    output_dims: ShapeLike
    | dict[str, ShapeLike]
    | None = ...,
    launch_dims: ShapeLike | None = ...,
    vmap_method: VmapMethod | None = ...,
) -> Sequence[Array]

Parameters:

  • output_dims
    (ShapeLike | dict[str, ShapeLike] | None, default: ... ) –
  • launch_dims
    (ShapeLike | None, default: ... ) –
  • vmap_method
    (VmapMethod | None, default: ... ) –
Source code in src/jarp/warp/_jax_kernel.py
def __call__(
    self, *args: Array, **kwargs: Unpack[JaxKernelCallOptions]
) -> Sequence[Array]: ...

JaxCallableCallOptions typed-dict

JaxCallableCallOptions(
    *,
    output_dims: ShapeLike
    | dict[str, ShapeLike]
    | None = ...,
    vmap_method: VmapMethod | None = ...,
)

Bases: TypedDict


Parameters:

  • output_dims

    (ShapeLike | dict[str, ShapeLike] | None, default: ... ) –
  • vmap_method

    (VmapMethod | None, default: ... ) –

JaxCallableOptions typed-dict

JaxCallableOptions(
    *,
    num_outputs: int = ...,
    graph_mode: GraphMode = ...,
    vmap_method: str | None = ...,
    output_dims: dict[str, ShapeLike] | None = ...,
    in_out_argnames: Iterable[str] = ...,
    stage_in_argnames: Iterable[str] = ...,
    stage_out_argnames: Iterable[str] = ...,
    graph_cache_max: int | None = ...,
    module_preload_mode: ModulePreloadMode = ...,
)

Bases: TypedDict


Parameters:

  • num_outputs

    (int, default: ... ) –
  • graph_mode

    (GraphMode, default: ... ) –
  • vmap_method

    (str | None, default: ... ) –
  • output_dims

    (dict[str, ShapeLike] | None, default: ... ) –
  • in_out_argnames

    (Iterable[str], default: ... ) –
  • stage_in_argnames

    (Iterable[str], default: ... ) –
  • stage_out_argnames

    (Iterable[str], default: ... ) –
  • graph_cache_max

    (int | None, default: ... ) –
  • module_preload_mode

    (ModulePreloadMode, default: ... ) –

JaxKernelCallOptions typed-dict

JaxKernelCallOptions(
    *,
    output_dims: ShapeLike
    | dict[str, ShapeLike]
    | None = ...,
    launch_dims: ShapeLike | None = ...,
    vmap_method: VmapMethod | None = ...,
)

Bases: TypedDict


Parameters:

  • output_dims

    (ShapeLike | dict[str, ShapeLike] | None, default: ... ) –
  • launch_dims

    (ShapeLike | None, default: ... ) –
  • vmap_method

    (VmapMethod | None, default: ... ) –

JaxKernelOptions typed-dict

JaxKernelOptions(
    *,
    num_outputs: int = ...,
    vmap_method: VmapMethod = ...,
    launch_dims: ShapeLike | None = ...,
    output_dims: ShapeLike
    | dict[str, ShapeLike]
    | None = ...,
    in_out_argnames: Iterable[str] = ...,
    module_preload_mode: ModulePreloadMode = ...,
    enable_backward: bool = ...,
)

Bases: TypedDict


Parameters:

  • num_outputs

    (int, default: ... ) –
  • vmap_method

    (VmapMethod, default: ... ) –
  • launch_dims

    (ShapeLike | None, default: ... ) –
  • output_dims

    (ShapeLike | dict[str, ShapeLike] | None, default: ... ) –
  • in_out_argnames

    (Iterable[str], default: ... ) –
  • module_preload_mode

    (ModulePreloadMode, default: ... ) –
  • enable_backward

    (bool, default: ... ) –

jax_callable

jax_callable(
    func: _FfiCallableFunction,
    *,
    generic: Literal[False] = False,
    **kwargs: Unpack[JaxCallableOptions],
) -> FfiCallableProtocol
jax_callable(
    *,
    generic: Literal[False] = False,
    **kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFunction], FfiCallableProtocol]
jax_callable(
    func: _FfiCallableFactory,
    *,
    generic: Literal[True],
    **kwargs: Unpack[JaxCallableOptions],
) -> _FfiCallable
jax_callable(
    *,
    generic: Literal[True],
    **kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFactory], _FfiCallable]

Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.

Parameters:

  • func

    (Callable | None, default: None ) –

    Warp callable function or factory. When omitted, return a decorator.

  • generic

    (bool, default: False ) –

    When true, func is treated as a factory that receives Warp scalar dtypes inferred from the runtime JAX arguments and returns a concrete Warp callable implementation.

  • num_outputs

    (int, default: ... ) –
  • graph_mode

    (GraphMode, default: ... ) –
  • vmap_method

    (str | None, default: ... ) –
  • output_dims

    (dict[str, ShapeLike] | None, default: ... ) –
  • in_out_argnames

    (Iterable[str], default: ... ) –
  • stage_in_argnames

    (Iterable[str], default: ... ) –
  • stage_out_argnames

    (Iterable[str], default: ... ) –
  • graph_cache_max

    (int | None, default: ... ) –
  • module_preload_mode

    (ModulePreloadMode, default: ... ) –

Returns:

  • Any

    A callable compatible with JAX tracing, or a decorator producing one.

Source code in src/jarp/warp/_jax_callable.py
def jax_callable(
    func: Callable | None = None,
    *,
    generic: bool = False,
    **kwargs: Unpack[JaxCallableOptions],
) -> Any:
    """Wrap ``warp.jax_experimental.jax_callable`` with optional dtype dispatch.

    Args:
        func: Warp callable function or factory. When omitted, return a
            decorator.
        generic: When true, ``func`` is treated as a factory that receives Warp
            scalar dtypes inferred from the runtime JAX arguments and returns a
            concrete Warp callable implementation.
        **kwargs: Options forwarded to Warp's JAX callable adapter.

    Returns:
        A callable compatible with JAX tracing, or a decorator producing one.
    """
    if func is None:
        return functools.partial(jax_callable, generic=generic, **kwargs)
    if not generic:
        return warp.jax_experimental.jax_callable(func, **kwargs)
    factory: _FfiCallableFactory = functools.lru_cache(func)
    return _FfiCallable(factory=factory, options=kwargs)  # ty:ignore[invalid-argument-type]
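With generic=True, jax_callable wraps func in functools.lru_cache and hands it to _FfiCallable, whose dispatch logic is not shown here. Assume that the factory is called with the scalar dtypes inferred from each call's arguments. The caching effect can then be sketched without Warp or JAX. In the sketch, array.array typecodes stand in for JAX dtypes and dtype name strings stand in for Warp scalar types. make_double and dispatch are hypothetical.

```python
import array
import functools
from typing import Callable

# Hypothetical dtype names standing in for Warp scalar dtypes.
TYPECODE_TO_DTYPE = {"f": "float32", "d": "float64"}
DTYPE_TO_TYPECODE = {dtype: code for code, dtype in TYPECODE_TO_DTYPE.items()}
builds: list[str] = []


@functools.lru_cache  # cache one concrete implementation per dtype
def make_double(dtype: str) -> Callable[[array.array], array.array]:
    """Hypothetical factory returning an implementation specialised for ``dtype``."""
    builds.append(dtype)  # record each time a new implementation is built
    typecode = DTYPE_TO_TYPECODE[dtype]

    def double(x: array.array) -> array.array:
        return array.array(typecode, (2 * v for v in x))

    return double


def dispatch(x: array.array) -> array.array:
    # Infer the dtype from the runtime argument, then reuse the cached build.
    return make_double(TYPECODE_TO_DTYPE[x.typecode])(x)


dispatch(array.array("f", [1.0, 2.0]))
dispatch(array.array("f", [3.0]))
dispatch(array.array("d", [4.0]))
print(builds)  # ['float32', 'float64']
```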

jax_kernel

jax_kernel(
    *,
    arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
    | None = None,
    **kwargs: Unpack[JaxKernelOptions],
) -> Callable[[Callable], FfiKernelProtocol]
jax_kernel(
    kernel: Callable,
    *,
    arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
    | None = None,
    **kwargs: Unpack[JaxKernelOptions],
) -> FfiKernelProtocol

Wrap warp.jax_experimental.jax_kernel with optional overload lookup.

Parameters:

  • kernel

    (Callable | None, default: None ) –

    Warp kernel to expose to JAX. When omitted, return a decorator.

  • arg_types_factory

    (Callable[[WarpScalarDType], ArgTypes] | None, default: None ) –

    Optional callback that maps runtime Warp scalar dtypes to the overloaded kernel argument types expected by warp.overload.

  • num_outputs

    (int, default: ... ) –
  • vmap_method

    (VmapMethod, default: ... ) –
  • launch_dims

    (ShapeLike | None, default: ... ) –
  • output_dims

    (ShapeLike | dict[str, ShapeLike] | None, default: ... ) –
  • in_out_argnames

    (Iterable[str], default: ... ) –
  • module_preload_mode

    (ModulePreloadMode, default: ... ) –
  • enable_backward

    (bool, default: ... ) –

Returns:

  • Any

    A callable compatible with JAX tracing, or a decorator producing one.

Source code in src/jarp/warp/_jax_kernel.py
def jax_kernel(
    kernel: Callable | None = None,
    *,
    arg_types_factory: Callable[[WarpScalarDType], ArgTypes] | None = None,
    **kwargs: Unpack[JaxKernelOptions],
) -> Any:
    """Wrap ``warp.jax_experimental.jax_kernel`` with optional overload lookup.

    Args:
        kernel: Warp kernel to expose to JAX. When omitted, return a decorator.
        arg_types_factory: Optional callback that maps runtime Warp scalar dtypes
            to the overloaded kernel argument types expected by
            :func:`warp.overload`.
        **kwargs: Options forwarded to Warp's JAX kernel adapter.

    Returns:
        A callable compatible with JAX tracing, or a decorator producing one.
    """
    if kernel is None:
        return functools.partial(
            jax_kernel, arg_types_factory=arg_types_factory, **kwargs
        )
    if arg_types_factory is None:
        return warp.jax_experimental.jax_kernel(kernel, **kwargs)
    return _FfiKernel(
        kernel=cast("wp.Kernel", kernel),
        options=kwargs,  # ty:ignore[invalid-argument-type]
        arg_types_factory=arg_types_factory,
    )

to_warp

to_warp(
    arr: array | ndarray | Array, *_args, **_kwargs
) -> array

Convert a supported array object into a warp.array.

The generic dispatcher currently supports NumPy arrays and JAX arrays; any other input type raises TypeError. A dtype hint may be a concrete Warp dtype or a tuple that describes a vector or matrix dtype inferred from the trailing dimensions of arr.

Source code in src/jarp/warp/_to_warp.py
@functools.singledispatch
def to_warp(arr: Any, *_args, **_kwargs) -> wp.array:
    """Convert a supported array object into a :class:`warp.array`.

    The generic dispatcher currently supports NumPy arrays and JAX arrays. A
    ``dtype`` hint may be a concrete Warp dtype or a tuple that describes a
    vector or matrix dtype inferred from the trailing dimensions of ``arr``.
    """
    raise TypeError(arr)
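to_warp is a functools.singledispatch function whose base case raises TypeError. Support for specific array types comes from overloads registered elsewhere, and those registrations are not shown on this page. The sketch below shows the same dispatch structure with hypothetical converters for two standard-library types. It returns a plain list instead of a warp.array.

```python
import array
import functools
from typing import Any


@functools.singledispatch
def to_list(arr: Any, *_args: Any, **_kwargs: Any) -> list:
    """Base case: no converter is registered for ``type(arr)``."""
    raise TypeError(arr)


@to_list.register
def _(arr: array.array, *_args: Any, **_kwargs: Any) -> list:
    # Overload selected when the first argument is an array.array.
    return arr.tolist()


@to_list.register
def _(arr: tuple, *_args: Any, **_kwargs: Any) -> list:
    # Overload selected for tuples.
    return list(arr)


print(to_list(array.array("i", [1, 2])))  # [1, 2]
print(to_list((3, 4)))  # [3, 4]
try:
    to_list("not an array")
except TypeError as exc:
    print("unsupported input:", exc)
```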

types

Convenience accessors for Warp scalar, vector, and matrix dtypes.

Functions:

  • __getattr__

    Resolve dynamic shorthand names such as floating, vec3, or mat33.

  • matrix

    Build a Warp matrix dtype using the default floating scalar type.

  • vector

    Build a Warp vector dtype using the default floating scalar type.

Attributes:

floating module-attribute

floating: type

mat22 module-attribute

mat22: type

mat33 module-attribute

mat33: type

mat44 module-attribute

mat44: type

vec2 module-attribute

vec2: type

vec3 module-attribute

vec3: type

vec4 module-attribute

vec4: type

__getattr__

__getattr__(name: str) -> type

Resolve dynamic shorthand names such as floating, vec3, or mat33.

Source code in src/jarp/warp/types.py
def __getattr__(name: str) -> type:
    """Resolve dynamic shorthand names such as ``floating``, ``vec3``, or ``mat33``."""
    if name in {"float", "float_"}:
        warnings.warn(
            f"{__name__}.{name} is deprecated, use {__name__}.floating instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return _floating()
    if name == "floating":
        return _floating()
    if (result := re.fullmatch(r"vec(?P<length>[1-9])", name)) is not None:
        length = int(result.group("length"))
        return wp.types.vector(length, _floating())
    if (result := re.fullmatch(r"mat(?P<rows>[1-9])(?P<cols>[1-9])", name)) is not None:
        rows = int(result.group("rows"))
        cols = int(result.group("cols"))
        return wp.types.matrix((rows, cols), _floating())
    msg: str = f"module '{__name__}' has no attribute '{name}'"
    raise AttributeError(msg, name=name, obj=sys.modules[__name__])
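jarp.warp.types relies on a module-level __getattr__ (PEP 562), so names such as vec3 or mat33 are built on first access rather than defined up front. The sketch below reproduces only the name parsing from the source above and returns (kind, shape) tuples in place of Warp dtypes. The module name shorthand_types and the function _module_getattr are hypothetical.

```python
import re
import types


def _module_getattr(name: str) -> tuple[str, tuple[int, ...]]:
    """Parse vecN / matRC shorthands with the same patterns as the source above."""
    if (match := re.fullmatch(r"vec(?P<length>[1-9])", name)) is not None:
        return "vector", (int(match.group("length")),)
    if (match := re.fullmatch(r"mat(?P<rows>[1-9])(?P<cols>[1-9])", name)) is not None:
        return "matrix", (int(match.group("rows")), int(match.group("cols")))
    msg = f"module 'shorthand_types' has no attribute '{name}'"
    raise AttributeError(msg)


# A throwaway module object. PEP 562 consults the module's own __getattr__
# only after normal attribute lookup fails.
shorthand_types = types.ModuleType("shorthand_types")
shorthand_types.__getattr__ = _module_getattr

print(shorthand_types.vec3)   # ('vector', (3,))
print(shorthand_types.mat23)  # ('matrix', (2, 3))
print(hasattr(shorthand_types, "vec0"))  # False
```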

matrix

matrix(shape: tuple[int, int]) -> type

Build a Warp matrix dtype using the default floating scalar type.

Source code in src/jarp/warp/types.py
def matrix(shape: tuple[int, int]) -> type:
    """Build a Warp matrix dtype using the default floating scalar type."""
    return wp.types.matrix(shape, _floating())

vector

vector(length: int) -> type

Build a Warp vector dtype using the default floating scalar type.

Source code in src/jarp/warp/types.py
def vector(length: int) -> type:
    """Build a Warp vector dtype using the default floating scalar type."""
    return wp.types.vector(length, _floating())