
README

jarp is a typed utility library for the glue code between JAX PyTrees, attrs, and NVIDIA Warp. It packages the small adapters that keep mixed array-and-metadata structures ergonomic in compiled code, tree transforms, and Warp interop.

✨ Features

  • 🧠 Filtered JIT for mixed PyTrees: jarp.jit(filter=True) partitions arrays from static Python metadata so functions can cross jax.jit without hand-written static argument plumbing.
  • 🌳 PyTree-friendly class decorators: jarp.define, jarp.frozen, and jarp.frozen_static wrap attrs and register classes with JAX using field specifiers such as array(), static(), and auto().
  • 📐 Tree flattening with round-tripping structure: jarp.ravel produces a flat vector plus a reusable Structure object that can rebuild the original PyTree, including static leaves.
  • ⚡ Warp integration that matches JAX workflows: jarp.to_warp, jarp.warp.jax_callable, and jarp.warp.jax_kernel bridge NumPy or JAX arrays into Warp and support dtype-driven generic wrappers.
  • 🐍 Modern Python support and typed APIs: the package targets Python 3.12 through 3.14 and ships inline type information.
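
To make the filtered JIT idea concrete, here is a minimal sketch in plain JAX of partitioning a PyTree into array leaves and static leaves before calling jax.jit. It is not jarp's implementation: the helpers partition, combine, and filtered_jit are hypothetical names used only for illustration, and the sketch assumes the wrapped function returns arrays only.

```python
from functools import partial

import jax
import jax.numpy as jnp


def partition(tree):
    """Split a PyTree into array leaves and hashable non-array leaves."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    mask = tuple(isinstance(leaf, jax.Array) for leaf in leaves)
    arrays = [leaf for leaf, is_arr in zip(leaves, mask) if is_arr]
    static = tuple(leaf for leaf, is_arr in zip(leaves, mask) if not is_arr)
    # The spec is hashable, so jax.jit can treat it as a static argument.
    return arrays, (static, mask, treedef)


def combine(arrays, spec):
    """Rebuild the original PyTree from array leaves and the static spec."""
    static, mask, treedef = spec
    arrays_it, static_it = iter(arrays), iter(static)
    leaves = [next(arrays_it) if is_arr else next(static_it) for is_arr in mask]
    return jax.tree_util.tree_unflatten(treedef, leaves)


def filtered_jit(fn):
    """Hypothetical wrapper: trace array leaves, keep everything else static."""

    @partial(jax.jit, static_argnums=1)
    def inner(arrays, spec):
        return fn(combine(arrays, spec))

    def wrapper(tree):
        arrays, spec = partition(tree)
        return inner(arrays, spec)

    return wrapper


@filtered_jit
def scaled_values(batch):
    # The label is an ordinary Python string at trace time.
    factor = 2.0 if batch["label"] == "train" else 1.0
    return batch["values"] * factor


train_out = scaled_values({"values": jnp.arange(3.0), "label": "train"})
eval_out = scaled_values({"values": jnp.arange(3.0), "label": "eval"})
```

Because the static leaves are part of the hashable spec, changing the label triggers a retrace rather than an error, which is the behavior hand-written static_argnums plumbing would otherwise have to provide.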

📦 Installation

Note: jarp requires Python 3.12 or newer.

Install the published package with uv:

uv add jarp

If you want a CUDA-enabled JAX extra, pick the matching wheel set:

uv add 'jarp[cuda12]'
uv add 'jarp[cuda13]'

🚀 Quick Start

This example shows the two pieces jarp is built around: filtered JIT for mixed PyTrees and attrs-style classes that flatten cleanly under JAX.

import jax.numpy as jnp
import jarp


@jarp.define
class Batch:
    values: object = jarp.array()  # array field: traced under jit
    label: str = jarp.static()  # static field: kept as Python metadata


@jarp.jit(filter=True)
def normalize(batch: Batch) -> Batch:
    centered = batch.values - jnp.mean(batch.values)
    return Batch(values=centered, label=batch.label)


batch = Batch(values=jnp.array([1.0, 2.0, 3.0]), label="train")
result = normalize(batch)

The array payload stays traceable, while the string label is preserved as static metadata.
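
For comparison, a similar split between traced data and static metadata can be written with plain JAX and the standard library. This is a sketch under the assumption that jax.tree_util.register_dataclass is available in your JAX version; it does not show how jarp.define works internally, and the names PlainBatch and normalize_plain are illustrative only.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp


# data_fields become PyTree leaves; meta_fields go into the static treedef.
@partial(
    jax.tree_util.register_dataclass,
    data_fields=["values"],
    meta_fields=["label"],
)
@dataclass
class PlainBatch:
    values: jax.Array
    label: str


@jax.jit
def normalize_plain(batch: PlainBatch) -> PlainBatch:
    centered = batch.values - jnp.mean(batch.values)
    return PlainBatch(values=centered, label=batch.label)


plain = PlainBatch(values=jnp.array([1.0, 2.0, 3.0]), label="train")
plain_result = normalize_plain(plain)

# Only the array shows up as a leaf; the label is carried as metadata.
plain_leaves = jax.tree_util.tree_leaves(plain)
```

The field lists have to be kept in sync with the class by hand here, which is the kind of bookkeeping the array() and static() field specifiers are meant to replace.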

jarp.ravel handles the other common workflow: turn a PyTree into one vector and keep enough structure around to rebuild it later.

import jax.numpy as jnp
import jarp


payload = {"a": jnp.zeros((3,)), "b": jnp.ones((4,)), "static": "foo"}
flat, structure = jarp.ravel(payload)
round_trip = structure.unravel(flat)
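
As a point of reference, plain JAX offers jax.flatten_util.ravel_pytree for the array-only case. The sketch below uses it on the two array entries from the payload above. The string leaf is left out on the assumption that ravel_pytree cannot place a non-numeric leaf into the flat vector, which is the gap jarp.ravel is described as covering.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Array-only PyTree: three zeros and four ones.
arrays_only = {"a": jnp.zeros((3,)), "b": jnp.ones((4,))}

# flat_vec is one 1-D vector; unravel_fn rebuilds the original dict from it.
flat_vec, unravel_fn = ravel_pytree(arrays_only)
restored = unravel_fn(flat_vec)
```

The flat vector has seven entries, and the rebuilt dict has the same keys and shapes as the input.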

🛠️ Local Development

Clone the repository, sync the workspace, and run the maintained automation through nox:

git clone https://github.com/liblaf/jarp.git
cd jarp
uv sync --all-groups
nox --list-sessions
nox --tags test

To build the documentation site locally:

uv run zensical build

Benchmarks and API docs live under docs/, and the published site is available at liblaf.github.io/jarp.

🤝 Contributing

Issues and pull requests are welcome, especially around PyTree ergonomics, Warp integration, and edge cases that show up in real JAX code.

📝 License

Copyright © 2026 liblaf.
This project is MIT licensed.