README
jarp is a typed utility library for the glue code between JAX PyTrees,
attrs, and NVIDIA Warp. It packages the small adapters that keep mixed
array-and-metadata structures ergonomic in compiled code, tree transforms, and
Warp interop.
✨ Features
- 🔧 Filtered JIT for mixed PyTrees: `jarp.jit(filter=True)` partitions arrays from static Python metadata, so functions can cross `jax.jit` without hand-written static-argument plumbing.
- 🌳 PyTree-friendly class decorators: `jarp.define`, `jarp.frozen`, and `jarp.frozen_static` wrap `attrs` and register classes with JAX using field specifiers such as `array()`, `static()`, and `auto()`.
- 🔄 Tree flattening with round-tripping structure: `jarp.ravel` produces a flat vector plus a reusable `Structure` object that can rebuild the original PyTree, including static leaves.
- ⚡ Warp integration that matches JAX workflows: `jarp.to_warp`, `jarp.warp.jax_callable`, and `jarp.warp.jax_kernel` bridge NumPy or JAX arrays into Warp and support dtype-driven generic wrappers.
- 🐍 Modern Python support and typed APIs: the package targets Python 3.12 through 3.14 and ships inline type information.
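To make the filtered-JIT idea concrete, here is a minimal sketch in plain JAX of what "partitioning arrays from static metadata" can look like. It is not jarp's implementation: `filter_jit` is a hypothetical helper, and it assumes that only `jax.Array`/`np.ndarray` leaves should be traced while every other leaf is hashable and can become part of the compilation cache key. In this sketch the function's outputs must still be arrays (or PyTrees of arrays).

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def filter_jit(fn):
    """Hypothetical helper: jit `fn`, tracing only array leaves of its inputs."""

    @functools.partial(jax.jit, static_argnums=(0, 1, 2))
    def call(treedef, mask, static, dynamic):
        # Re-interleave traced arrays and static leaves in their original order.
        static_it, dynamic_it = iter(static), iter(dynamic)
        leaves = [next(dynamic_it) if is_arr else next(static_it) for is_arr in mask]
        args, kwargs = jax.tree_util.tree_unflatten(treedef, leaves)
        return fn(*args, **kwargs)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        leaves, treedef = jax.tree_util.tree_flatten((args, kwargs))
        mask = tuple(isinstance(leaf, (jax.Array, np.ndarray)) for leaf in leaves)
        # Non-array leaves are passed as static args, so they join the cache key.
        static = tuple(leaf for leaf, is_arr in zip(leaves, mask) if not is_arr)
        dynamic = [leaf for leaf, is_arr in zip(leaves, mask) if is_arr]
        return call(treedef, mask, static, dynamic)

    return wrapper


trace_count = 0


@filter_jit
def scale(payload):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX is tracing
    return payload["x"] * len(payload["name"])


scale({"x": jnp.arange(3.0), "name": "abc"})  # first call: trace and compile
scale({"x": jnp.ones(3), "name": "abc"})      # same static leaves and shapes: cache hit
scale({"x": jnp.ones(3), "name": "abcd"})     # new static leaf: retrace
print(trace_count)  # 2
```

Changing array values reuses the compiled function, while changing a static leaf such as `name` triggers a retrace; that is the trade-off any filtered JIT makes.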
📦 Installation
> [!NOTE]
> `jarp` requires Python 3.12 or newer.
Install the published package with uv:
If you want a CUDA-enabled JAX extra, pick the matching wheel set:
🚀 Quick Start
This example shows the two pieces jarp is built around: filtered JIT for
mixed PyTrees and attrs-style classes that flatten cleanly under JAX.
```python
import jax.numpy as jnp

import jarp


@jarp.define
class Batch:
    values: object = jarp.array()
    label: str = jarp.static()


@jarp.jit(filter=True)
def normalize(batch: Batch) -> Batch:
    centered = batch.values - jnp.mean(batch.values)
    return Batch(values=centered, label=batch.label)


batch = Batch(values=jnp.array([1.0, 2.0, 3.0]), label="train")
result = normalize(batch)
```
The array payload stays traceable, while the string label is preserved as static metadata.
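For comparison, a rough plain-JAX analogue of the `Batch` class registers a standard dataclass with `jax.tree_util.register_dataclass` and lists `label` as a metadata field. This is a hedged sketch of the underlying JAX mechanism, assuming a recent JAX release that provides `register_dataclass`; it is not a description of how `jarp.define` works internally.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Batch:
    values: jax.Array
    label: str


# `values` becomes a traced child; `label` is stored in the treedef as static metadata.
jax.tree_util.register_dataclass(Batch, data_fields=["values"], meta_fields=["label"])


@jax.jit
def normalize(batch: Batch) -> Batch:
    centered = batch.values - jnp.mean(batch.values)
    return Batch(values=centered, label=batch.label)


batch = Batch(values=jnp.array([1.0, 2.0, 3.0]), label="train")
result = normalize(batch)
print(result.label)                           # train
print(len(jax.tree_util.tree_leaves(batch)))  # 1: only the array is a leaf
```

Spelling out the field lists by hand is the kind of plumbing that field specifiers such as `jarp.array()` and `jarp.static()` appear designed to replace.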
`jarp.ravel` handles the other common workflow: turn a PyTree into one vector
and keep enough structure around to rebuild it later.
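```python
import jax.numpy as jnp

import jarp

payload = {"a": jnp.zeros((3,)), "b": jnp.ones((4,)), "static": "foo"}
flat, structure = jarp.ravel(payload)  # flat: one vector built from the array leaves
round_trip = structure.unravel(flat)   # same dict structure, static leaf included
```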
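The round trip can also be sketched in plain JAX. `ravel_with_static` below is a hypothetical helper, not jarp's `Structure` API: it flattens the array leaves into one 1-D vector, remembers each leaf's shape and dtype, and sets non-array leaves (like the `"static": "foo"` entry) aside so `unravel` can restore them. JAX's own `jax.flatten_util.ravel_pytree` covers the array-only case but expects every leaf to be array-like.

```python
import math

import jax
import jax.numpy as jnp
import numpy as np


def ravel_with_static(tree):
    """Hypothetical helper: flatten array leaves into one vector, keep other leaves static."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    is_array = [isinstance(leaf, (jax.Array, np.ndarray)) for leaf in leaves]
    arrays = [jnp.asarray(leaf) for leaf, arr in zip(leaves, is_array) if arr]
    specs = [(a.shape, a.dtype) for a in arrays]
    # Concatenate the flattened arrays; a tree with no arrays yields an empty vector.
    flat = jnp.concatenate([a.ravel() for a in arrays]) if arrays else jnp.zeros((0,))

    def unravel(vector):
        rebuilt, offset, spec_iter = [], 0, iter(specs)
        for leaf, arr in zip(leaves, is_array):
            if arr:
                shape, dtype = next(spec_iter)
                size = math.prod(shape)  # 1 for scalar leaves
                rebuilt.append(vector[offset : offset + size].reshape(shape).astype(dtype))
                offset += size
            else:
                rebuilt.append(leaf)  # static leaf reused as-is
        return jax.tree_util.tree_unflatten(treedef, rebuilt)

    return flat, unravel


payload = {"a": jnp.zeros((3,)), "b": jnp.ones((4,)), "static": "foo"}
flat, unravel = ravel_with_static(payload)
round_trip = unravel(flat)
print(flat.shape)            # (7,)
print(round_trip["static"])  # foo
```

Because the slice offsets are plain Python integers fixed at flatten time, `unravel` in this sketch can also be called inside a jitted function.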
🛠️ Local Development
Clone the repository, sync the workspace with uv, and use nox to list and run
the maintained automation sessions:
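```shell
git clone https://github.com/liblaf/jarp.git
cd jarp
uv sync --all-groups  # install every dependency group
nox --list-sessions   # show the available sessions
nox --tags test       # run the sessions tagged "test"
```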
To build the documentation site locally:
Benchmarks and API docs live under `docs/`, and the published
site is available at [liblaf.github.io/jarp](https://liblaf.github.io/jarp).
🤝 Contributing
Issues and pull requests are welcome, especially around PyTree ergonomics, Warp integration, and edge cases that show up in real JAX code.
