jarp
¶
PyTree-oriented helpers for JAX, attrs, and Warp.
Modules:
- lax – Control-flow wrappers that mirror JAX APIs with small ergonomic additions.
- tree – Helpers for defining, flattening, and transforming JAX PyTrees.
- warp – Interop helpers between JAX arrays and Warp kernels or callables.
Classes:
- Partial – Store a partially applied callable as a PyTree-aware proxy.
- PyTreeProxy – Wrap an arbitrary object and flatten the wrapped value as a PyTree.
- Structure – Record how to flatten and rebuild a PyTree's dynamic leaves.
Functions:
- array
- auto
- define – Define an attrs class and optionally register it as a PyTree.
- field
- frozen – Define a frozen attrs class and register it as a data PyTree.
- frozen_static – Define a frozen attrs class and register it as a static PyTree.
- jax_callable – Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.
- jax_kernel – Wrap warp.jax_experimental.jax_kernel with optional overload lookup.
- jit – Compile a callable with JAX, optionally preserving static PyTree leaves.
- partial – Partially apply a callable and keep the result compatible with JAX trees.
- ravel – Flatten a PyTree's dynamic leaves into one vector.
- register_pytree_prelude
- static
- to_warp – Convert a supported array object into a :class:warp.array.
- while_loop – Run a loop with either jax.lax.while_loop or Python control flow.
Attributes:
- __version__ (str)
- __version_tuple__ (tuple[int | str, ...])
__version_tuple__
module-attribute
¶
Partial
¶
Bases: PartialCallableObjectProxy
Store a partially applied callable as a PyTree-aware proxy.
Methods:
- __call__
Attributes:
- __wrapped__ (Callable[..., T])
Source code in src/jarp/tree/prelude/_partial.py
PyTreeProxy
¶
Bases: BaseObjectProxy
Wrap an arbitrary object and flatten the wrapped value as a PyTree.
Attributes:
- __wrapped__ (T)
Structure
¶
Record how to flatten and rebuild a PyTree's dynamic leaves.
Instances are returned by :func:ravel and capture the original tree
definition, the static leaves that were removed from the flat vector, and
the offsets needed to reconstruct each dynamic leaf.
Parameters:
- dtype (str | type[Any] | dtype | SupportsDType)
- meta_leaves (tuple[Any, ...])
- offsets (tuple[int, ...])
- shapes (tuple[Shape | None, ...])
- treedef (PyTreeDef)
Methods:
- ravel – Flatten a compatible tree or flatten an array in-place.
- unravel – Rebuild the original tree shape from a flat vector.
Attributes:
- dtype (DTypeLike)
- is_leaf (bool) – Return whether the original tree was a single leaf.
- meta_leaves (tuple[Any, ...])
- offsets (tuple[int, ...])
- shapes (tuple[Shape | None, ...])
- treedef (PyTreeDef)
ravel
¶
ravel(tree: T | Array) -> Vector
Flatten a compatible tree or flatten an array in-place.
Parameters:
- tree (T | Array) – A tree with the same structure used to build this :class:Structure, or an already-flat array.
Returns:
- Vector – A one-dimensional array containing the dynamic leaves.
Source code in src/jarp/tree/_ravel.py
unravel
¶
Rebuild the original tree shape from a flat vector.
Parameters:
- flat (T | Array) – One-dimensional data produced by :meth:ravel, or a tree that already matches the recorded structure.
- dtype (DTypeLike | None, default: None) – Optional dtype override applied to the flat array before it is split and reshaped.
Returns:
- T – A tree with the same structure and static metadata as the original input to :func:ravel.
Source code in src/jarp/tree/_ravel.py
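The roles of offsets, shapes, and meta_leaves can be pictured with a small sketch. This is not jarp's implementation: sketch_ravel is a hypothetical helper, and it assumes that "dynamic" means "is a jax.Array" and that each offset is the end index of one dynamic leaf in the flat vector.

```python
import itertools

import jax
import jax.numpy as jnp


def sketch_ravel(tree):
    """Hypothetical stand-in for ravel; returns (flat, offsets, unravel)."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    is_dynamic = [isinstance(leaf, jax.Array) for leaf in leaves]
    dynamic = [leaf for leaf, dyn in zip(leaves, is_dynamic) if dyn]
    # Static leaves are kept aside; dynamic slots hold a None placeholder.
    meta_leaves = tuple(None if dyn else leaf for leaf, dyn in zip(leaves, is_dynamic))
    # Shapes are recorded only for dynamic leaves (None marks a static slot).
    shapes = tuple(leaf.shape if dyn else None for leaf, dyn in zip(leaves, is_dynamic))
    # Assumption: offsets are cumulative end indices into the flat vector.
    offsets = tuple(itertools.accumulate(leaf.size for leaf in dynamic))
    flat = jnp.concatenate([leaf.reshape(-1) for leaf in dynamic])

    def unravel(vector):
        starts = (0, *offsets[:-1])
        bounds = iter(zip(starts, offsets))
        rebuilt = []
        for meta, shape in zip(meta_leaves, shapes):
            if shape is None:
                rebuilt.append(meta)  # static leaf, restored unchanged
            else:
                start, stop = next(bounds)
                rebuilt.append(vector[start:stop].reshape(shape))
        return jax.tree_util.tree_unflatten(treedef, rebuilt)

    return flat, offsets, unravel


tree = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3), "name": "layer"}
flat, offsets, unravel = sketch_ravel(tree)
rebuilt = unravel(flat)
```

JAX sorts dictionary keys when flattening, so the leaves here are ordered b, name, w, and the flat vector holds the three entries of b followed by the six entries of w.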
array
¶
array(
*,
default: T = ...,
validator: _ValidatorArgType[T] | None = ...,
repr: _ReprArgType = ...,
hash: bool | None = ...,
init: bool = ...,
metadata: Mapping[Any, Any] | None = ...,
converter: _ConverterType
| list[_ConverterType]
| tuple[_ConverterType, ...]
| None = ...,
factory: Callable[[], T] | None = ...,
kw_only: bool | None = ...,
eq: _EqOrderType | None = ...,
order: _EqOrderType | None = ...,
on_setattr: _OnSetAttrArgType | None = ...,
alias: str | None = ...,
type: type | None = ...,
static: FieldType | bool | None = ...,
) -> Array
Parameters:
- default (T, default: ...)
- validator (_ValidatorArgType[T] | None, default: ...)
- repr (_ReprArgType, default: ...)
- hash (bool | None, default: ...)
- init (bool, default: ...)
- metadata (Mapping[Any, Any] | None, default: ...)
- converter (_ConverterType | list[_ConverterType] | tuple[_ConverterType, ...] | None, default: ...)
- factory (Callable[[], T] | None, default: ...)
- kw_only (bool | None, default: ...)
- eq (_EqOrderType | None, default: ...)
- order (_EqOrderType | None, default: ...)
- on_setattr (_OnSetAttrArgType | None, default: ...)
- alias (str | None, default: ...)
- type (type | None, default: ...)
- static (FieldType | bool | None, default: ...)
Source code in src/jarp/tree/attrs/_field_specifiers.py
define
¶
Define an attrs class and optionally register it as a PyTree.
Parameters:
- maybe_cls (T | None, default: None) – Class being decorated. When omitted, return a configured decorator.
- **kwargs (Any, default: {}) – Options forwarded to :func:attrs.define, plus pytree to control JAX registration. pytree="data" registers fields with fieldz semantics, "static" registers the whole instance as a static value, and "none" leaves the class unregistered.
Returns:
- Any – The decorated class or a class decorator.
Source code in src/jarp/tree/attrs/_define.py
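The difference between a "data" registration and a "static" registration can be illustrated with plain JAX. The sketch below does not use jarp or attrs. It registers two hypothetical dataclasses (DataPoint and Config) by hand with jax.tree_util.register_pytree_node, and it does not attempt to reproduce the fieldz semantics mentioned above.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class DataPoint:
    position: jax.Array
    label: str


# "data"-like registration: array fields become children (leaves),
# while non-array fields travel as auxiliary data.
jax.tree_util.register_pytree_node(
    DataPoint,
    lambda point: ((point.position,), point.label),
    lambda label, children: DataPoint(children[0], label),
)


@dataclasses.dataclass(frozen=True)
class Config:
    steps: int


# "static"-like registration: no children at all; the whole (hashable)
# instance is carried as auxiliary data in the tree definition.
jax.tree_util.register_pytree_node(
    Config,
    lambda config: ((), config),
    lambda config, _children: config,
)

point_leaves = jax.tree_util.tree_leaves(DataPoint(jnp.zeros(3), "a"))
config_leaves = jax.tree_util.tree_leaves(Config(steps=10))
```

A statically registered instance contributes no leaves, so JAX transformations treat it as part of the tree structure rather than as traced data.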
frozen
¶
Define a frozen attrs class and register it as a data PyTree.
Source code in src/jarp/tree/attrs/_define.py
frozen_static
¶
Define a frozen attrs class and register it as a static PyTree.
Source code in src/jarp/tree/attrs/_define.py
jax_callable
¶
jax_callable(
func: _FfiCallableFunction,
*,
generic: Literal[False] = False,
**kwargs: Unpack[JaxCallableOptions],
) -> FfiCallableProtocol
jax_callable(
*,
generic: Literal[False] = False,
**kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFunction], FfiCallableProtocol]
jax_callable(
func: _FfiCallableFactory,
*,
generic: Literal[True],
**kwargs: Unpack[JaxCallableOptions],
) -> _FfiCallable
jax_callable(
*,
generic: Literal[True],
**kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFactory], _FfiCallable]
Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.
Parameters:
- func (Callable | None, default: None) – Warp callable function or factory. When omitted, return a decorator.
- generic (bool, default: False) – When true, func is treated as a factory that receives Warp scalar dtypes inferred from the runtime JAX arguments and returns a concrete Warp callable implementation.
- num_outputs (int, default: ...)
- graph_mode (GraphMode, default: ...)
- vmap_method (str | None, default: ...)
- output_dims (dict[str, ShapeLike] | None, default: ...)
- in_out_argnames (Iterable[str], default: ...)
- stage_in_argnames (Iterable[str], default: ...)
- stage_out_argnames (Iterable[str], default: ...)
- graph_cache_max (int | None, default: ...)
- module_preload_mode (ModulePreloadMode, default: ...)
Returns:
- Any – A callable compatible with JAX tracing, or a decorator producing one.
Source code in src/jarp/warp/_jax_callable.py
jax_kernel
¶
jax_kernel(
*,
arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
| None = None,
**kwargs: Unpack[JaxKernelOptions],
) -> Callable[[Callable], FfiKernelProtocol]
jax_kernel(
kernel: Callable,
*,
arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
| None = None,
**kwargs: Unpack[JaxKernelOptions],
) -> FfiKernelProtocol
Wrap warp.jax_experimental.jax_kernel with optional overload lookup.
Parameters:
- kernel (Callable | None, default: None) – Warp kernel to expose to JAX. When omitted, return a decorator.
- arg_types_factory (Callable[[WarpScalarDType], ArgTypes] | None, default: None) – Optional callback that maps runtime Warp scalar dtypes to the overloaded kernel argument types expected by :func:warp.overload.
- num_outputs (int, default: ...)
- vmap_method (VmapMethod, default: ...)
- launch_dims (ShapeLike | None, default: ...)
- output_dims (ShapeLike | dict[str, ShapeLike] | None, default: ...)
- in_out_argnames (Iterable[str], default: ...)
- module_preload_mode (ModulePreloadMode, default: ...)
- enable_backward (bool, default: ...)
Returns:
- Any – A callable compatible with JAX tracing, or a decorator producing one.
Source code in src/jarp/warp/_jax_kernel.py
jit
¶
jit[**P, T](
fun: Callable[P, T],
/,
*,
filter: Literal[False] = False,
**kwargs: Unpack[JitOptions],
) -> Callable[P, T]
jit[**P, T](
*,
filter: Literal[False] = False,
**kwargs: Unpack[JitOptions],
) -> Callable[[Callable[P, T]], Callable[P, T]]
Compile a callable with JAX, optionally preserving static PyTree leaves.
When filter=False this is a thin wrapper around :func:jax.jit.
When filter=True the function and its inputs are partitioned into
dynamic array leaves and static metadata so mixed PyTrees can cross the JIT
boundary without requiring manual static_argnums wiring.
Parameters:
- fun (Callable[P, T] | None, default: None) – Callable to compile. When omitted, return a decorator.
- **kwargs (Any, default: {}) – Options forwarded to :func:jax.jit. With filter=True, only the subset in :class:FilterJitOptions is supported because static argument handling is managed internally.
Returns:
Source code in src/jarp/_jit.py
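The idea behind filter=True can be sketched in plain JAX. The code below is only an illustration under stated assumptions, not jarp's implementation: sketch_filter_jit, _run, and _DYNAMIC are hypothetical names, only the inputs (not the function itself) are partitioned, "dynamic" is taken to mean "is a jax.Array", and every static leaf is assumed to be hashable.

```python
from functools import partial

import jax
import jax.numpy as jnp

_DYNAMIC = object()  # placeholder marking where a dynamic leaf belongs


@partial(jax.jit, static_argnums=(0, 2))
def _run(fun, dynamic, static):
    # `static` is hashable: the tree definition plus the non-array leaves.
    treedef, meta = static
    dynamic_iter = iter(dynamic)
    leaves = [next(dynamic_iter) if m is _DYNAMIC else m for m in meta]
    args, kwargs = jax.tree_util.tree_unflatten(treedef, leaves)
    return fun(*args, **kwargs)


def sketch_filter_jit(fun):
    """Hypothetical filtered jit: arrays are traced, everything else is static."""

    def wrapper(*args, **kwargs):
        leaves, treedef = jax.tree_util.tree_flatten((args, kwargs))
        dynamic = [leaf for leaf in leaves if isinstance(leaf, jax.Array)]
        meta = tuple(
            _DYNAMIC if isinstance(leaf, jax.Array) else leaf for leaf in leaves
        )
        return _run(fun, dynamic, (treedef, meta))

    return wrapper


@sketch_filter_jit
def scale(x, mode):
    # `mode` is a plain string, so it must stay static under jit.
    return x * 2.0 if mode == "double" else x


out = scale(jnp.arange(3.0), "double")
```

Because the static part is passed through static_argnums, a call with a different mode string triggers a fresh compilation, which is the usual trade-off of static arguments.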
partial
¶
Partially apply a callable and keep the result compatible with JAX trees.
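For comparison, JAX itself ships a PyTree-aware partial, jax.tree_util.Partial, whose bound arguments are exposed as leaves. The sketch below uses that upstream utility rather than jarp.partial; how closely jarp's version matches it is an assumption, and affine and apply are hypothetical example functions.

```python
import jax
import jax.numpy as jnp


def affine(scale, offset, x):
    return scale * x + offset


# The bound arguments become PyTree leaves; the function itself is static.
bound = jax.tree_util.Partial(affine, jnp.float32(2.0), jnp.float32(1.0))
leaves = jax.tree_util.tree_leaves(bound)


@jax.jit
def apply(fn, x):
    # `fn` crosses the jit boundary as an ordinary PyTree argument.
    return fn(x)


result = apply(bound, jnp.arange(3.0))
```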
ravel
¶
Flatten a PyTree's dynamic leaves into one vector.
Non-array leaves are treated as static metadata and preserved in the
returned :class:Structure instead of being concatenated into the flat
array.
Parameters:
- tree (T) – PyTree to flatten.
Returns:
- tuple[Array, Structure[T]] – A tuple of (flat, structure) where flat is a one-dimensional JAX array and structure can rebuild compatible trees later.
Source code in src/jarp/tree/_ravel.py
register_pytree_prelude
cached
¶
to_warp
¶
Convert a supported array object into a :class:warp.array.
The generic dispatcher currently supports NumPy arrays and JAX arrays. A
dtype hint may be a concrete Warp dtype or a tuple that describes a
vector or matrix dtype inferred from the trailing dimensions of arr.
Source code in src/jarp/warp/_to_warp.py
while_loop
¶
while_loop[T](
cond_fun: Callable[[T], BooleanNumeric],
body_fun: Callable[[T], T],
init_val: T,
*,
jit: bool = True,
) -> T
Run a loop with either jax.lax.while_loop or Python control flow.
Parameters:
- cond_fun (Callable[[T], BooleanNumeric]) – Predicate evaluated on the loop state.
- body_fun (Callable[[T], T]) – Function that produces the next loop state.
- init_val (T) – Initial loop state.
- jit (bool, default: True) – When true, dispatch to :func:jax.lax.while_loop. When false, run an eager Python while loop with the same callbacks.
Returns:
- T – The final loop state.
Source code in src/jarp/lax/_while_loop.py
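The dispatch described above can be pictured as follows. This is a minimal sketch written against plain JAX, not the jarp source; sketch_while_loop, cond, and body are hypothetical names chosen for the example.

```python
import jax
import jax.numpy as jnp


def sketch_while_loop(cond_fun, body_fun, init_val, *, jit=True):
    """Hypothetical dispatcher mirroring the documented behaviour."""
    if jit:
        return jax.lax.while_loop(cond_fun, body_fun, init_val)
    # Eager fallback: an ordinary Python loop with the same callbacks.
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def cond(state):
    i, _total = state
    return i < 5


def body(state):
    i, total = state
    return i + 1, total + i


compiled = sketch_while_loop(cond, body, (jnp.int32(0), jnp.int32(0)))
eager = sketch_while_loop(cond, body, (0, 0), jit=False)
```

The eager path is convenient for debugging, since breakpoints and print statements inside cond and body see concrete Python values instead of tracers.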
lax
¶
Control-flow wrappers that mirror JAX APIs with small ergonomic additions.
Functions:
- while_loop – Run a loop with either jax.lax.while_loop or Python control flow.
while_loop
¶
while_loop[T](
cond_fun: Callable[[T], BooleanNumeric],
body_fun: Callable[[T], T],
init_val: T,
*,
jit: bool = True,
) -> T
Run a loop with either jax.lax.while_loop or Python control flow.
Parameters:
- cond_fun (Callable[[T], BooleanNumeric]) – Predicate evaluated on the loop state.
- body_fun (Callable[[T], T]) – Function that produces the next loop state.
- init_val (T) – Initial loop state.
- jit (bool, default: True) – When true, dispatch to :func:jax.lax.while_loop. When false, run an eager Python while loop with the same callbacks.
Returns:
- T – The final loop state.
Source code in src/jarp/lax/_while_loop.py
tree
¶
Helpers for defining, flattening, and transforming JAX PyTrees.
Modules:
- attrs – attrs-based decorators and field helpers for JAX PyTree classes.
- codegen – Code-generation helpers for custom PyTree registrations.
- prelude – PyTree-aware wrappers for callables and object proxies.
Classes:
- AuxData – Carry the static part of a partitioned PyTree.
- FieldType
- Partial – Store a partially applied callable as a PyTree-aware proxy.
- PyTreeProxy – Wrap an arbitrary object and flatten the wrapped value as a PyTree.
- PyTreeType – Choose how a class should participate in JAX PyTree flattening.
- Structure – Record how to flatten and rebuild a PyTree's dynamic leaves.
Functions:
- array
- auto
- codegen_pytree_functions
- combine – Rebuild a PyTree from dynamic leaves and recorded auxiliary data.
- combine_leaves – Zip dynamic leaves back together with their static counterparts.
- define – Define an attrs class and optionally register it as a PyTree.
- field
- frozen – Define a frozen attrs class and register it as a data PyTree.
- frozen_static – Define a frozen attrs class and register it as a static PyTree.
- is_data – Return whether an object should stay on the dynamic side of a partition.
- is_leaf – Return whether a leaf contributes data to a flattened vector.
- partial – Partially apply a callable and keep the result compatible with JAX trees.
- partition – Split a PyTree into dynamic leaves and static metadata.
- partition_leaves – Separate raw tree leaves into data leaves and metadata leaves.
- ravel – Flatten a PyTree's dynamic leaves into one vector.
- register_fieldz
- register_generic
- register_pytree_prelude
- static
AuxData
¶
FieldType
¶
Partial
¶
Bases: PartialCallableObjectProxy
Store a partially applied callable as a PyTree-aware proxy.
Methods:
- __call__
Attributes:
- __wrapped__ (Callable[..., T])
Source code in src/jarp/tree/prelude/_partial.py
PyTreeProxy
¶
Bases: BaseObjectProxy
Wrap an arbitrary object and flatten the wrapped value as a PyTree.
Attributes:
- __wrapped__ (T)
PyTreeType
¶
Bases: StrEnum
Choose how a class should participate in JAX PyTree flattening.
Attributes:
Structure
¶
Record how to flatten and rebuild a PyTree's dynamic leaves.
Instances are returned by :func:ravel and capture the original tree
definition, the static leaves that were removed from the flat vector, and
the offsets needed to reconstruct each dynamic leaf.
Parameters:
- dtype (str | type[Any] | dtype | SupportsDType)
- meta_leaves (tuple[Any, ...])
- offsets (tuple[int, ...])
- shapes (tuple[Shape | None, ...])
- treedef (PyTreeDef)
Methods:
- ravel – Flatten a compatible tree or flatten an array in-place.
- unravel – Rebuild the original tree shape from a flat vector.
Attributes:
- dtype (DTypeLike)
- is_leaf (bool) – Return whether the original tree was a single leaf.
- meta_leaves (tuple[Any, ...])
- offsets (tuple[int, ...])
- shapes (tuple[Shape | None, ...])
- treedef (PyTreeDef)
ravel
¶
ravel(tree: T | Array) -> Vector
Flatten a compatible tree or flatten an array in-place.
Parameters:
- tree (T | Array) – A tree with the same structure used to build this :class:Structure, or an already-flat array.
Returns:
- Vector – A one-dimensional array containing the dynamic leaves.
Source code in src/jarp/tree/_ravel.py
unravel
¶
Rebuild the original tree shape from a flat vector.
Parameters:
- flat (T | Array) – One-dimensional data produced by :meth:ravel, or a tree that already matches the recorded structure.
- dtype (DTypeLike | None, default: None) – Optional dtype override applied to the flat array before it is split and reshaped.
Returns:
- T – A tree with the same structure and static metadata as the original input to :func:ravel.
Source code in src/jarp/tree/_ravel.py
array
¶
array(
*,
default: T = ...,
validator: _ValidatorArgType[T] | None = ...,
repr: _ReprArgType = ...,
hash: bool | None = ...,
init: bool = ...,
metadata: Mapping[Any, Any] | None = ...,
converter: _ConverterType
| list[_ConverterType]
| tuple[_ConverterType, ...]
| None = ...,
factory: Callable[[], T] | None = ...,
kw_only: bool | None = ...,
eq: _EqOrderType | None = ...,
order: _EqOrderType | None = ...,
on_setattr: _OnSetAttrArgType | None = ...,
alias: str | None = ...,
type: type | None = ...,
static: FieldType | bool | None = ...,
) -> Array
Parameters:
- default (T, default: ...)
- validator (_ValidatorArgType[T] | None, default: ...)
- repr (_ReprArgType, default: ...)
- hash (bool | None, default: ...)
- init (bool, default: ...)
- metadata (Mapping[Any, Any] | None, default: ...)
- converter (_ConverterType | list[_ConverterType] | tuple[_ConverterType, ...] | None, default: ...)
- factory (Callable[[], T] | None, default: ...)
- kw_only (bool | None, default: ...)
- eq (_EqOrderType | None, default: ...)
- order (_EqOrderType | None, default: ...)
- on_setattr (_OnSetAttrArgType | None, default: ...)
- alias (str | None, default: ...)
- type (type | None, default: ...)
- static (FieldType | bool | None, default: ...)
Source code in src/jarp/tree/attrs/_field_specifiers.py
codegen_pytree_functions
¶
codegen_pytree_functions(
cls: type,
data_fields: Sequence[str] = (),
meta_fields: Sequence[str] = (),
auto_fields: Sequence[str] = (),
*,
filter_spec: Callable[[Any], bool] = is_data,
bypass_setattr: bool | None = None,
) -> PyTreeFunctions
Source code in src/jarp/tree/codegen/_compile.py
combine
¶
Rebuild a PyTree from dynamic leaves and recorded auxiliary data.
Source code in src/jarp/tree/_filters.py
combine_leaves
¶
Zip dynamic leaves back together with their static counterparts.
Source code in src/jarp/tree/_filters.py
define
¶
Define an attrs class and optionally register it as a PyTree.
Parameters:
- maybe_cls (T | None, default: None) – Class being decorated. When omitted, return a configured decorator.
- **kwargs (Any, default: {}) – Options forwarded to :func:attrs.define, plus pytree to control JAX registration. pytree="data" registers fields with fieldz semantics, "static" registers the whole instance as a static value, and "none" leaves the class unregistered.
Returns:
- Any – The decorated class or a class decorator.
Source code in src/jarp/tree/attrs/_define.py
frozen
¶
Define a frozen attrs class and register it as a data PyTree.
Source code in src/jarp/tree/attrs/_define.py
frozen_static
¶
Define a frozen attrs class and register it as a static PyTree.
Source code in src/jarp/tree/attrs/_define.py
is_data
¶
Return whether an object should stay on the dynamic side of a partition.
partial
¶
Partially apply a callable and keep the result compatible with JAX trees.
partition
¶
Split a PyTree into dynamic leaves and static metadata.
Source code in src/jarp/tree/_filters.py
partition_leaves
¶
Separate raw tree leaves into data leaves and metadata leaves.
Source code in src/jarp/tree/_filters.py
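The partition and combine pair can be sketched with plain jax.tree_util calls. The functions below are hypothetical stand-ins (sketch_partition and sketch_combine), not the jarp implementations, and they assume that the dynamic side consists exactly of jax.Array leaves.

```python
import jax
import jax.numpy as jnp


def sketch_partition(tree):
    """Split a tree into array leaves and (treedef, static leaves)."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    data = [leaf for leaf in leaves if isinstance(leaf, jax.Array)]
    # None is safe as a placeholder: JAX never yields None as a leaf.
    meta = tuple(None if isinstance(leaf, jax.Array) else leaf for leaf in leaves)
    return data, (treedef, meta)


def sketch_combine(data, aux):
    """Inverse of sketch_partition."""
    treedef, meta = aux
    data_iter = iter(data)
    leaves = [next(data_iter) if m is None else m for m in meta]
    return jax.tree_util.tree_unflatten(treedef, leaves)


tree = {"weights": jnp.ones(2), "activation": "relu", "depth": 3}
data, aux = sketch_partition(tree)
rebuilt = sketch_combine(data, aux)
```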
ravel
¶
Flatten a PyTree's dynamic leaves into one vector.
Non-array leaves are treated as static metadata and preserved in the
returned :class:Structure instead of being concatenated into the flat
array.
Parameters:
- tree (T) – PyTree to flatten.
Returns:
- tuple[Array, Structure[T]] – A tuple of (flat, structure) where flat is a one-dimensional JAX array and structure can rebuild compatible trees later.
Source code in src/jarp/tree/_ravel.py
register_fieldz
¶
register_fieldz[T: type](
cls: T,
data_fields: Sequence[str] | None = None,
meta_fields: Sequence[str] | None = None,
auto_fields: Sequence[str] | None = None,
*,
filter_spec: Callable[[Any], bool] = is_data,
bypass_setattr: bool | None = None,
) -> T
Source code in src/jarp/tree/attrs/_register.py
register_generic
¶
register_generic(
cls: type,
data_fields: Sequence[str] = (),
meta_fields: Sequence[str] = (),
auto_fields: Sequence[str] = (),
*,
filter_spec: Callable[[Any], bool] = is_data,
bypass_setattr: bool | None = None,
) -> None
Source code in src/jarp/tree/codegen/_compile.py
register_pytree_prelude
cached
¶
attrs
¶
attrs-based decorators and field helpers for JAX PyTree classes.
Classes:
- FieldType
- PyTreeType – Choose how a class should participate in JAX PyTree flattening.
Functions:
- array
- auto
- define – Define an attrs class and optionally register it as a PyTree.
- field
- frozen – Define a frozen attrs class and register it as a data PyTree.
- frozen_static – Define a frozen attrs class and register it as a static PyTree.
- register_fieldz
- static
FieldType
¶
PyTreeType
¶
Bases: StrEnum
Choose how a class should participate in JAX PyTree flattening.
Attributes:
array
¶
array(
*,
default: T = ...,
validator: _ValidatorArgType[T] | None = ...,
repr: _ReprArgType = ...,
hash: bool | None = ...,
init: bool = ...,
metadata: Mapping[Any, Any] | None = ...,
converter: _ConverterType
| list[_ConverterType]
| tuple[_ConverterType, ...]
| None = ...,
factory: Callable[[], T] | None = ...,
kw_only: bool | None = ...,
eq: _EqOrderType | None = ...,
order: _EqOrderType | None = ...,
on_setattr: _OnSetAttrArgType | None = ...,
alias: str | None = ...,
type: type | None = ...,
static: FieldType | bool | None = ...,
) -> Array
Parameters:
- default (T, default: ...)
- validator (_ValidatorArgType[T] | None, default: ...)
- repr (_ReprArgType, default: ...)
- hash (bool | None, default: ...)
- init (bool, default: ...)
- metadata (Mapping[Any, Any] | None, default: ...)
- converter (_ConverterType | list[_ConverterType] | tuple[_ConverterType, ...] | None, default: ...)
- factory (Callable[[], T] | None, default: ...)
- kw_only (bool | None, default: ...)
- eq (_EqOrderType | None, default: ...)
- order (_EqOrderType | None, default: ...)
- on_setattr (_OnSetAttrArgType | None, default: ...)
- alias (str | None, default: ...)
- type (type | None, default: ...)
- static (FieldType | bool | None, default: ...)
Source code in src/jarp/tree/attrs/_field_specifiers.py
define
¶
Define an attrs class and optionally register it as a PyTree.
Parameters:
- maybe_cls (T | None, default: None) – Class being decorated. When omitted, return a configured decorator.
- **kwargs (Any, default: {}) – Options forwarded to :func:attrs.define, plus pytree to control JAX registration. pytree="data" registers fields with fieldz semantics, "static" registers the whole instance as a static value, and "none" leaves the class unregistered.
Returns:
- Any – The decorated class or a class decorator.
Source code in src/jarp/tree/attrs/_define.py
frozen
¶
Define a frozen attrs class and register it as a data PyTree.
Source code in src/jarp/tree/attrs/_define.py
frozen_static
¶
Define a frozen attrs class and register it as a static PyTree.
Source code in src/jarp/tree/attrs/_define.py
register_fieldz
¶
register_fieldz[T: type](
cls: T,
data_fields: Sequence[str] | None = None,
meta_fields: Sequence[str] | None = None,
auto_fields: Sequence[str] | None = None,
*,
filter_spec: Callable[[Any], bool] = is_data,
bypass_setattr: bool | None = None,
) -> T
Source code in src/jarp/tree/attrs/_register.py
codegen
¶
Code-generation helpers for custom PyTree registrations.
Classes:
- PyTreeFunctions – PyTreeFunctions(flatten, unflatten, flatten_with_keys)
Functions:
- codegen_flatten
- codegen_flatten_with_keys
- codegen_pytree_functions
- codegen_unflatten
- register_generic
PyTreeFunctions
¶
Bases: NamedTuple
PyTreeFunctions(flatten, unflatten, flatten_with_keys)
Parameters:
- flatten (Callable[list, tuple[_Children, _AuxData]], default: None)
- unflatten (Callable[list, T], default: None)
- flatten_with_keys (Callable[list, tuple[_ChildrenWithKeys, _AuxData]], default: None)
Attributes:
codegen_flatten
¶
codegen_flatten(
data_fields: Sequence[str],
meta_fields: Sequence[str],
auto_fields: Sequence[str],
) -> FunctionDef
Source code in src/jarp/tree/codegen/_codegen.py
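As an illustration of the general technique, a flatten function can be generated from field names with the standard ast module. The sketch below assumes that FunctionDef refers to ast.FunctionDef, ignores auto_fields, and uses hypothetical names throughout (sketch_codegen_flatten, compile_function, Particle); it is not the code jarp generates.

```python
import ast


def _tuple_expr(fields):
    # Builds e.g. "(obj.a, obj.b, )"; an empty field list gives "()".
    return "(" + "".join(f"obj.{name}, " for name in fields) + ")"


def sketch_codegen_flatten(data_fields, meta_fields):
    """Return an ast.FunctionDef for flatten(obj) -> (children, aux_data)."""
    source = (
        "def flatten(obj):\n"
        f"    return {_tuple_expr(data_fields)}, {_tuple_expr(meta_fields)}\n"
    )
    return ast.parse(source).body[0]


def compile_function(function_def):
    """Compile a FunctionDef node into a callable."""
    module = ast.Module(body=[function_def], type_ignores=[])
    ast.fix_missing_locations(module)
    namespace = {}
    exec(compile(module, filename="<sketch_codegen>", mode="exec"), namespace)
    return namespace[function_def.name]


class Particle:
    def __init__(self, position, velocity, name):
        self.position = position
        self.velocity = velocity
        self.name = name


function_def = sketch_codegen_flatten(["position", "velocity"], ["name"])
flatten = compile_function(function_def)
children, aux = flatten(Particle(1.0, 2.0, "p0"))
```

Generating a specialised function per class avoids a loop over field names on every flatten call, which is one plausible motivation for a code-generation approach.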
codegen_flatten_with_keys
¶
codegen_flatten_with_keys(
data_fields: Sequence[str],
meta_fields: Sequence[str],
auto_fields: Sequence[str],
) -> FunctionDef
Source code in src/jarp/tree/codegen/_codegen.py
codegen_pytree_functions
¶
codegen_pytree_functions(
cls: type,
data_fields: Sequence[str] = (),
meta_fields: Sequence[str] = (),
auto_fields: Sequence[str] = (),
*,
filter_spec: Callable[[Any], bool] = is_data,
bypass_setattr: bool | None = None,
) -> PyTreeFunctions
Source code in src/jarp/tree/codegen/_compile.py
codegen_unflatten
¶
codegen_unflatten(
data_fields: Sequence[str],
meta_fields: Sequence[str],
auto_fields: Sequence[str],
*,
bypass_setattr: bool = False,
) -> FunctionDef
Source code in src/jarp/tree/codegen/_codegen.py
register_generic
¶
register_generic(
cls: type,
data_fields: Sequence[str] = (),
meta_fields: Sequence[str] = (),
auto_fields: Sequence[str] = (),
*,
filter_spec: Callable[[Any], bool] = is_data,
bypass_setattr: bool | None = None,
) -> None
Source code in src/jarp/tree/codegen/_compile.py
prelude
¶
PyTree-aware wrappers for callables and object proxies.
Classes:
- Partial – Store a partially applied callable as a PyTree-aware proxy.
- PyTreeProxy – Wrap an arbitrary object and flatten the wrapped value as a PyTree.
Functions:
- partial – Partially apply a callable and keep the result compatible with JAX trees.
- register_pytree_prelude
Partial
¶
Bases: PartialCallableObjectProxy
Store a partially applied callable as a PyTree-aware proxy.
Methods:
- __call__
Attributes:
- __wrapped__ (Callable[..., T])
Source code in src/jarp/tree/prelude/_partial.py
PyTreeProxy
¶
Bases: BaseObjectProxy
Wrap an arbitrary object and flatten the wrapped value as a PyTree.
Attributes:
- __wrapped__ (T)
partial
¶
Partially apply a callable and keep the result compatible with JAX trees.
warp
¶
Interop helpers between JAX arrays and Warp kernels or callables.
Modules:
- types – Convenience accessors for Warp scalar, vector, and matrix dtypes.
Classes:
- FfiCallableProtocol – Callable interface returned by :func:jax_callable.
- FfiKernelProtocol – Callable interface returned by :func:jax_kernel.
- JaxCallableCallOptions
- JaxCallableOptions
- JaxKernelCallOptions
- JaxKernelOptions
Functions:
- jax_callable – Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.
- jax_kernel – Wrap warp.jax_experimental.jax_kernel with optional overload lookup.
- to_warp – Convert a supported array object into a :class:warp.array.
FfiCallableProtocol
¶
Bases: Protocol
Callable interface returned by :func:jax_callable.
Methods:
- __call__
__call__
¶
__call__(
*args: Array,
output_dims: ShapeLike
| dict[str, ShapeLike]
| None = ...,
vmap_method: VmapMethod | None = ...,
) -> Sequence[Array]
FfiKernelProtocol
¶
Bases: Protocol
Callable interface returned by :func:jax_kernel.
Methods:
- __call__
__call__
¶
__call__(
*args: Array,
output_dims: ShapeLike
| dict[str, ShapeLike]
| None = ...,
launch_dims: ShapeLike | None = ...,
vmap_method: VmapMethod | None = ...,
) -> Sequence[Array]
JaxCallableCallOptions
typed-dict
¶
JaxCallableCallOptions(
*,
output_dims: ShapeLike
| dict[str, ShapeLike]
| None = ...,
vmap_method: VmapMethod | None = ...,
)
Bases: TypedDict
Parameters:
- output_dims (ShapeLike | dict[str, ShapeLike] | None, default: ...)
- vmap_method (VmapMethod | None, default: ...)
JaxCallableOptions
typed-dict
¶
JaxCallableOptions(
*,
num_outputs: int = ...,
graph_mode: GraphMode = ...,
vmap_method: str | None = ...,
output_dims: dict[str, ShapeLike] | None = ...,
in_out_argnames: Iterable[str] = ...,
stage_in_argnames: Iterable[str] = ...,
stage_out_argnames: Iterable[str] = ...,
graph_cache_max: int | None = ...,
module_preload_mode: ModulePreloadMode = ...,
)
Bases: TypedDict
Parameters:
- num_outputs (int, default: ...)
- graph_mode (GraphMode, default: ...)
- vmap_method (str | None, default: ...)
- output_dims (dict[str, ShapeLike] | None, default: ...)
- in_out_argnames (Iterable[str], default: ...)
- stage_in_argnames (Iterable[str], default: ...)
- stage_out_argnames (Iterable[str], default: ...)
- graph_cache_max (int | None, default: ...)
- module_preload_mode (ModulePreloadMode, default: ...)
JaxKernelCallOptions
typed-dict
¶
JaxKernelCallOptions(
*,
output_dims: ShapeLike
| dict[str, ShapeLike]
| None = ...,
launch_dims: ShapeLike | None = ...,
vmap_method: VmapMethod | None = ...,
)
Bases: TypedDict
Parameters:
- output_dims (ShapeLike | dict[str, ShapeLike] | None, default: ...)
- launch_dims (ShapeLike | None, default: ...)
- vmap_method (VmapMethod | None, default: ...)
JaxKernelOptions
typed-dict
¶
JaxKernelOptions(
*,
num_outputs: int = ...,
vmap_method: VmapMethod = ...,
launch_dims: ShapeLike | None = ...,
output_dims: ShapeLike
| dict[str, ShapeLike]
| None = ...,
in_out_argnames: Iterable[str] = ...,
module_preload_mode: ModulePreloadMode = ...,
enable_backward: bool = ...,
)
Bases: TypedDict
Parameters:
- num_outputs (int, default: ...)
- vmap_method (VmapMethod, default: ...)
- launch_dims (ShapeLike | None, default: ...)
- output_dims (ShapeLike | dict[str, ShapeLike] | None, default: ...)
- in_out_argnames (Iterable[str], default: ...)
- module_preload_mode (ModulePreloadMode, default: ...)
- enable_backward (bool, default: ...)
jax_callable
¶
jax_callable(
func: _FfiCallableFunction,
*,
generic: Literal[False] = False,
**kwargs: Unpack[JaxCallableOptions],
) -> FfiCallableProtocol
jax_callable(
*,
generic: Literal[False] = False,
**kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFunction], FfiCallableProtocol]
jax_callable(
func: _FfiCallableFactory,
*,
generic: Literal[True],
**kwargs: Unpack[JaxCallableOptions],
) -> _FfiCallable
jax_callable(
*,
generic: Literal[True],
**kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFactory], _FfiCallable]
Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.
Parameters:
- func (Callable | None, default: None) – Warp callable function or factory. When omitted, return a decorator.
- generic (bool, default: False) – When true, func is treated as a factory that receives Warp scalar dtypes inferred from the runtime JAX arguments and returns a concrete Warp callable implementation.
- num_outputs (int, default: ...)
- graph_mode (GraphMode, default: ...)
- vmap_method (str | None, default: ...)
- output_dims (dict[str, ShapeLike] | None, default: ...)
- in_out_argnames (Iterable[str], default: ...)
- stage_in_argnames (Iterable[str], default: ...)
- stage_out_argnames (Iterable[str], default: ...)
- graph_cache_max (int | None, default: ...)
- module_preload_mode (ModulePreloadMode, default: ...)
Returns:
- Any – A callable compatible with JAX tracing, or a decorator producing one.
Source code in src/jarp/warp/_jax_callable.py
jax_kernel
¶
jax_kernel(
*,
arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
| None = None,
**kwargs: Unpack[JaxKernelOptions],
) -> Callable[[Callable], FfiKernelProtocol]
jax_kernel(
kernel: Callable,
*,
arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
| None = None,
**kwargs: Unpack[JaxKernelOptions],
) -> FfiKernelProtocol
Wrap warp.jax_experimental.jax_kernel with optional overload lookup.
Parameters:
- kernel (Callable | None, default: None) – Warp kernel to expose to JAX. When omitted, return a decorator.
- arg_types_factory (Callable[[WarpScalarDType], ArgTypes] | None, default: None) – Optional callback that maps runtime Warp scalar dtypes to the overloaded kernel argument types expected by :func:warp.overload.
- num_outputs (int, default: ...)
- vmap_method (VmapMethod, default: ...)
- launch_dims (ShapeLike | None, default: ...)
- output_dims (ShapeLike | dict[str, ShapeLike] | None, default: ...)
- in_out_argnames (Iterable[str], default: ...)
- module_preload_mode (ModulePreloadMode, default: ...)
- enable_backward (bool, default: ...)
Returns:
- Any – A callable compatible with JAX tracing, or a decorator producing one.
Source code in src/jarp/warp/_jax_kernel.py
to_warp
¶
Convert a supported array object into a :class:warp.array.
The generic dispatcher currently supports NumPy arrays and JAX arrays. A
dtype hint may be a concrete Warp dtype or a tuple that describes a
vector or matrix dtype inferred from the trailing dimensions of arr.
Source code in src/jarp/warp/_to_warp.py
types
¶
Convenience accessors for Warp scalar, vector, and matrix dtypes.
Functions:
- __getattr__ – Resolve dynamic shorthand names such as floating, vec3, or mat33.
- matrix – Build a Warp matrix dtype using the default floating scalar type.
- vector – Build a Warp vector dtype using the default floating scalar type.
Attributes:
- floating (type)
- mat22 (type)
- mat33 (type)
- mat44 (type)
- vec2 (type)
- vec3 (type)
- vec4 (type)
__getattr__
¶
Resolve dynamic shorthand names such as floating, vec3, or mat33.
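Module-level __getattr__ (PEP 562) is the standard Python mechanism for resolving attribute names lazily. The sketch below shows the mechanism only, with no Warp dependency: the module sketch_types, the resolver _resolve, the DEFAULT_FLOATING value, and the tuple descriptions it returns are all hypothetical placeholders for real Warp dtypes.

```python
import re
import types

DEFAULT_FLOATING = "float32"  # hypothetical stand-in for a Warp scalar dtype


def _resolve(name):
    """Resolve shorthand names at attribute-access time."""
    if name == "floating":
        return DEFAULT_FLOATING
    match = re.fullmatch(r"vec([2-4])", name)
    if match:
        return ("vector", int(match.group(1)), DEFAULT_FLOATING)
    # The backreference requires square shapes such as mat22 or mat33.
    match = re.fullmatch(r"mat([2-4])\1", name)
    if match:
        size = int(match.group(1))
        return ("matrix", (size, size), DEFAULT_FLOATING)
    raise AttributeError(f"module 'sketch_types' has no attribute {name!r}")


# Python consults a module's __getattr__ only when normal lookup fails.
sketch_types = types.ModuleType("sketch_types")
sketch_types.__getattr__ = _resolve

vec3 = sketch_types.vec3
mat33 = sketch_types.mat33
```

Resolving names on access, rather than defining constants at import time, lets the result depend on state that may change after import, such as a configurable default floating type.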