# `JaxSim` as a multibody dynamics library

JaxSim was initially developed as a **hardware-accelerated physics engine**. Over time, it has evolved, adding new features to become a comprehensive **JAX-based multibody dynamics library**.

In this notebook, you'll explore the main APIs for loading robot models and computing key quantities for applications such as control, planning, and more.

A key advantage of JaxSim is its ability to create fully differentiable closed-loop systems, enabling end-to-end optimization. Combined with the flexibility to parameterize model kinematics and dynamics, JaxSim can serve as an excellent playground for robot learning applications.

<a target="_blank" href="https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_multibody_dynamics_library.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>


## Prepare environment

First, we need to install the necessary packages and import their resources.

In [None]:
# @title Imports and setup
from IPython.display import clear_output
import sys

IS_COLAB = "google.colab" in sys.modules

# Install JAX, sdformat, and other notebook dependencies.
if IS_COLAB:
    !{sys.executable} -m pip install --pre -qU jaxsim
    !{sys.executable} -m pip install "robot_descriptions>=1.16.0"
    !apt install -qq lsb-release wget gnupg
    !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg
    !echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null
    !apt -qq update
    !apt install -qq --no-install-recommends libsdformat13 gz-tools2

    clear_output()

import os
import pathlib

import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.math
import robot_descriptions
from jaxsim import logging
from jaxsim import VelRepr

logging.set_logging_level(logging.LoggingLevel.WARNING)
print(f"Running on {jax.devices()}")

## Robot model

JaxSim allows loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files.

In this example, we will use the [ErgoCub][ergocub] humanoid robot model. If you have a URDF/SDF file for your robot that is compatible with [`gazebosim/sdformat`][sdformat_github][1], it should work out-of-the-box with JaxSim.

[sdformat]: http://sdformat.org/
[urdf]: http://wiki.ros.org/urdf/
[ergocub]: https://ergocub.eu/
[sdformat_github]: https://github.com/gazebosim/sdformat

---

[1]: JaxSim validates robot descriptions using the command `gz sdf -p /path/to/file.urdf`. Ensure this command runs successfully on your file.


In [None]:
# @title Fetch the URDF file

try:
    os.environ["ROBOT_DESCRIPTION_COMMIT"] = "v0.7.7"

    import robot_descriptions.ergocub_description

finally:
    _ = os.environ.pop("ROBOT_DESCRIPTION_COMMIT", None)

model_description_path = pathlib.Path(
    robot_descriptions.ergocub_description.URDF_PATH.replace(
        "ergoCubSN002", "ergoCubSN001"
    )
)

clear_output()

### Create the model and its data

The dynamics of a generic floating-base model are governed by the following equations of motion:

$$
M(\mathbf{q}) \dot{\boldsymbol{\nu}} + \mathbf{h}(\mathbf{q}, \boldsymbol{\nu}) = B \boldsymbol{\tau} + \sum_{L_i \in \mathcal{L}} J_{W,L_i}^\top(\mathbf{q}) \: \mathbf{f}_i
.
$$

Here, the system state is represented by:

- $\mathbf{q} = ({}^W \mathbf{H}_B, \, \mathbf{s}) \in \text{SE}(3) \times \mathbb{R}^n$ is the generalized position, composed of the base pose and the joint positions.
- $\boldsymbol{\nu} = (\boldsymbol{v}_{W,B}, \, \boldsymbol{\omega}_{W,B}, \, \dot{\mathbf{s}}) \in \mathbb{R}^{6+n}$ is the generalized velocity.

The inputs to the system are:

- $\boldsymbol{\tau} \in \mathbb{R}^n$ are the joint torques.
- $\mathbf{f}_i \in \mathbb{R}^6$ is the 6D force applied to the link $L_i$.

JaxSim exposes functional APIs to operate over the following two main data structures:

- **`JaxSimModel`** stores all the constant information parsed from the model description.
- **`JaxSimModelData`** holds the state of the model.

Additionally, JaxSim includes a utility class, **`JaxSimModelReferences`**, for managing and manipulating system inputs.

---

This notebook uses the notation summarized in the following report. Please refer to this document if you have any questions or if something is unclear.

> Traversaro and Saccon, **Multibody dynamics notation**, 2019, [URL](https://pure.tue.nl/ws/portalfiles/portal/139293126/A_Multibody_Dynamics_Notation_Revision_2_.pdf).

In [None]:
# Create the model from the model description.
# JaxSim removes all fixed joints by lumping together their parent and child links.
full_model = js.model.JaxSimModel.build_from_model_description(
    model_description=model_description_path
)

It is often useful to work with only a subset of joints, referred to as the _considered joints_. JaxSim allows reducing a model to these joints, which simplifies the computation of rigid-body dynamics quantities.

By default, the positions of the removed joints are considered to be zero. If this is not the case, the `reduce` function accepts a dictionary `dict[str, float]` to specify custom joint positions.

In [None]:
model = js.model.reduce(
    model=full_model,
    considered_joints=tuple(
        j
        for j in full_model.joint_names()
        # Remove sensor joints.
        if "camera" not in j
        # Remove head and hands.
        and "neck" not in j
        and "wrist" not in j
        and "thumb" not in j
        and "index" not in j
        and "middle" not in j
        and "ring" not in j
        and "pinkie" not in j
        # Remove upper body.
        and "torso" not in j and "elbow" not in j and "shoulder" not in j
    ),
)

In [None]:
# Print model quantities.
print(f"Model name: {model.name()}")
print(f"Number of links: {model.number_of_links()}")
print(f"Number of joints: {model.number_of_joints()}")

print()
print(f"Links:\n{model.link_names()}")

print()
print(f"Joints:\n{model.joint_names()}")

print()
print(f"Frames:\n{model.frame_names()}")

In [None]:
# Create a random data object from the reduced model.
data = js.data.random_model_data(model=model)

# Print the generalized position and velocity of the random state.
W_H_B, s = data.generalized_position
ν = data.generalized_velocity

print(f"W_H_B: shape={W_H_B.shape}\n{W_H_B}\n")
print(f"s: shape={s.shape}\n{s}\n")
print(f"ν: shape={ν.shape}\n{ν}\n")  # noqa: RUF001

In [None]:
# Create a random link forces matrix.
link_forces = jax.random.uniform(
    minval=-10.0,
    maxval=10.0,
    shape=(model.number_of_links(), 6),
    key=jax.random.PRNGKey(0),
)

# Create a random joint force references vector.
# Note that these are called 'references' because the actual joint forces that
# are actuated might differ due to effects like joint friction.
joint_force_references = jax.random.uniform(
    minval=-10.0, maxval=10.0, shape=(model.dofs(),), key=jax.random.PRNGKey(0)
)

# Create the references object.
references = js.references.JaxSimModelReferences.build(
    model=model,
    data=data,
    link_forces=link_forces,
    joint_force_references=joint_force_references,
)

print(f"link_forces: shape={references.link_forces(model=model, data=data).shape}")
print(f"joint_force_references: shape={references.joint_force_references(model=model).shape}")

## Robot Kinematics

JaxSim offers functional APIs for computing kinematic quantities:

- **`jaxsim.api.model`**: vectorized functions operating on the whole model.
- **`jaxsim.api.link`**: functions operating on individual links.
- **`jaxsim.api.frame`**: functions operating on individual frames. 

Due to JAX limitations on vectorizable data types, many APIs operate on indices instead of names. Since using indices can be error prone, JaxSim provides conversion functions for both links:

- **`jaxsim.api.link.names_to_idxs()`**
- **`jaxsim.api.link.idxs_to_names()`**

and frames:

- **`jaxsim.api.frame.names_to_idxs()`**
- **`jaxsim.api.frame.idxs_to_names()`**

We recommend using names whenever possible to avoid hard-to-trace errors.


In [None]:
# Find the index of a link.
link_name = "l_ankle_2"
link_index = js.link.name_to_idx(model=model, link_name=link_name)

In [None]:
# @title Link Pose

# Compute its pose w.r.t. the world frame through forward kinematics.
W_H_L = js.link.transform(model=model, data=data, link_index=link_index)

print(f"Transform of '{link_name}': shape={W_H_L.shape}\n{W_H_L}")

In [None]:
# @title Link 6D Velocity

# JaxSim allows selecting the representation of the link 6D velocity.
L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body)
LW_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Mixed)
W_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial)

print(f"Body-fixed velocity:     L_v_WL={L_v_WL}")
print(f"Mixed velocity:         LW_v_WL={LW_v_WL}")
print(f"Inertial-fixed velocity: W_v_WL={W_v_WL}")

# These can also be computed through the link free-floating Jacobian.
# This type of Jacobian has an input velocity representation that corresponds
# to the velocity representation of ν, and an output velocity representation
# that corresponds to the velocity representation of the desired 6D velocity.

# You can use the following context manager to easily switch between representations.
with data.switch_velocity_representation(VelRepr.Body):

    # Body-fixed generalized velocity.
    B_ν = data.generalized_velocity

    # Free-floating Jacobian accepting a body-fixed generalized velocity and
    # returning an inertial-fixed link velocity.
    W_J_WL_B = js.link.jacobian(
        model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial
    )

# Now the following relation should hold.
assert jnp.allclose(W_v_WL, W_J_WL_B @ B_ν)
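
To make the three representations more concrete, the following NumPy sketch converts a body-fixed 6D velocity into its mixed and inertial-fixed counterparts using only the link pose. It does not call JaxSim: the helpers `skew` and `adjoint`, the pose, and the velocity values are illustrative assumptions, with 6D velocities ordered as (linear, angular).

```python
import numpy as np


def skew(p):
    """Return the 3x3 matrix S(p) such that S(p) @ x == np.cross(p, x)."""
    return np.array(
        [[0.0, -p[2], p[1]], [p[2], 0.0, -p[0]], [-p[1], p[0], 0.0]]
    )


def adjoint(R, p):
    """Build the 6x6 velocity transform from a rotation R and a translation p."""
    X = np.zeros((6, 6))
    X[0:3, 0:3] = R
    X[0:3, 3:6] = skew(p) @ R
    X[3:6, 3:6] = R
    return X


# Hypothetical link pose W_H_L: 90 degrees about z, plus an offset.
W_R_L = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
W_p_L = np.array([1.0, 2.0, 3.0])

# Hypothetical body-fixed velocity, ordered as (linear, angular).
L_v_WL = np.array([0.1, 0.2, 0.3, 0.0, 0.0, 1.0])

# Inertial-fixed: apply the full adjoint of W_H_L.
W_v_WL = adjoint(W_R_L, W_p_L) @ L_v_WL

# Mixed: only rotate into the world orientation, keeping the link origin.
LW_v_WL = adjoint(W_R_L, np.zeros(3)) @ L_v_WL

print(f"Mixed:          {LW_v_WL}")
print(f"Inertial-fixed: {W_v_WL}")
```

In this sketch the angular part is the same in the mixed and inertial-fixed representations, while the linear parts differ by the cross product of the link position and the angular velocity.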

In [None]:
# Find the index of a frame.
frame_name = "l_foot_front"
frame_index = js.frame.name_to_idx(model=model, frame_name=frame_name)

In [None]:
# @title Frame Pose

# Compute its pose w.r.t. the world frame through forward kinematics.
W_H_F = js.frame.transform(model=model, data=data, frame_index=frame_index)

print(f"Transform of '{frame_name}': shape={W_H_F.shape}\n{W_H_F}")

In [None]:
# @title Frame 6D Velocity

# JaxSim allows selecting the representation of the frame 6D velocity.
F_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Body)
FW_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Mixed)
W_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial)

print(f"Body-fixed velocity:     F_v_WF={F_v_WF}")
print(f"Mixed velocity:         FW_v_WF={FW_v_WF}")
print(f"Inertial-fixed velocity: W_v_WF={W_v_WF}")

# These can also be computed through the frame free-floating Jacobian.
# This type of Jacobian has an input velocity representation that corresponds
# to the velocity representation of ν, and an output velocity representation
# that corresponds to the velocity representation of the desired 6D velocity.

# You can use the following context manager to easily switch between representations.
with data.switch_velocity_representation(VelRepr.Body):

    # Body-fixed generalized velocity.
    B_ν = data.generalized_velocity

    # Free-floating Jacobian accepting a body-fixed generalized velocity and
    # returning an inertial-fixed frame velocity.
    W_J_WF_B = js.frame.jacobian(
        model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial
    )

# Now the following relation should hold.
assert jnp.allclose(W_v_WF, W_J_WF_B @ B_ν)

## Robot Dynamics

JaxSim provides all the quantities involved in the equations of motion, restated here:

$$
M(\mathbf{q}) \dot{\boldsymbol{\nu}} + \mathbf{h}(\mathbf{q}, \boldsymbol{\nu}) = B \boldsymbol{\tau} + \sum_{L_i \in \mathcal{L}} J_{W,L_i}^\top(\mathbf{q}) \: \mathbf{f}_i
.
$$

Specifically, it can compute:

- $M(\mathbf{q}) \in \mathbb{R}^{(6+n)\times(6+n)}$: the mass matrix.
- $\mathbf{h}(\mathbf{q}, \boldsymbol{\nu}) \in \mathbb{R}^{6+n}$: the vector of bias forces.
- $B \in \mathbb{R}^{(6+n) \times n}$: the joint selector matrix.
- $J_{W,L} \in \mathbb{R}^{6 \times (6+n)}$: the Jacobian of link $L$.

Often, for convenience, link Jacobians are stacked together. Since JaxSim efficiently computes the Jacobians for all links, using the stacked version is recommended when needed:

$$
M(\mathbf{q}) \dot{\boldsymbol{\nu}} + \mathbf{h}(\mathbf{q}, \boldsymbol{\nu}) = B \boldsymbol{\tau} + J_{W,\mathcal{L}}^\top(\mathbf{q}) \: \mathbf{f}_\mathcal{L}
.
$$

Furthermore, some applications require unpacking the vector of bias forces as follows:

$$
\mathbf{h}(\mathbf{q}, \boldsymbol{\nu}) = C(\mathbf{q}, \boldsymbol{\nu}) \boldsymbol{\nu} + \mathbf{g}(\mathbf{q})
,
$$

where:

- $\mathbf{g}(\mathbf{q}) \in \mathbb{R}^{6+n}$: the vector of gravity forces.
- $C(\mathbf{q}, \boldsymbol{\nu}) \in \mathbb{R}^{(6+n)\times(6+n)}$: the Coriolis matrix.

The functions that compute all these quantities are shown in the following cell. Note that all quantities depend on the active velocity representation of `data`. As with the link velocity, the representation of every computed quantity can be changed by operating within the corresponding context manager. The cell uses the default representation of `data`.

In [None]:
print("Velocity representation of data:", data.velocity_representation, "\n")

# Compute the mass matrix.
M = js.model.free_floating_mass_matrix(model=model, data=data)
print(f"M:   shape={M.shape}")

# Compute the vector of bias forces.
h = js.model.free_floating_bias_forces(model=model, data=data)
print(f"h:   shape={h.shape}")

# Compute the vector of gravity forces.
g = js.model.free_floating_gravity_forces(model=model, data=data)
print(f"g:   shape={g.shape}")

# Compute the Coriolis matrix.
C = js.model.free_floating_coriolis_matrix(model=model, data=data)
print(f"C:   shape={C.shape}")

# Create the joint selector matrix.
B = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T
print(f"B:   shape={B.shape}")

# Compute the stacked tensor of link Jacobians.
J = js.model.generalized_free_floating_jacobian(model=model, data=data)
print(f"J:   shape={J.shape}")

# Extract the joint forces from the references object.
τ = references.joint_force_references(model=model)
print(f"τ:   shape={τ.shape}")

# Extract the link forces from the references object.
f_L = references.link_forces(model=model, data=data)
print(f"f_L: shape={f_L.shape}")

# The following relation should hold.
assert jnp.allclose(h, C @ ν + g)

### Forward Dynamics

$$
\dot{\boldsymbol{\nu}} = \text{FD}(\mathbf{q}, \boldsymbol{\nu}, \boldsymbol{\tau}, \mathbf{f}_{\mathcal{L}})
$$

JaxSim provides two alternative methods to compute the forward dynamics:

1. Operate on the quantities of the equations of motion.
2. Call the recursive Articulated Body Algorithm (ABA).

The physics engine provided by JaxSim exploits the efficient calculation of the forward dynamics with ABA for simulating the trajectories of the system dynamics.
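
The first method amounts to a linear solve of the equations of motion. The following NumPy sketch illustrates it with randomly generated stand-ins for $M$, $\mathbf{h}$, $J$, and the inputs. None of these values come from JaxSim, and the sizes and the seed are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, n_links = 3, 2  # hypothetical number of joints and links

# Random symmetric positive-definite stand-in for the mass matrix M(q).
A = rng.normal(size=(6 + n, 6 + n))
M = A @ A.T + (6 + n) * np.eye(6 + n)

h = rng.normal(size=6 + n)  # stand-in for the bias forces
tau = rng.normal(size=n)  # joint torques
J = rng.normal(size=(n_links, 6, 6 + n))  # stacked link Jacobians
f_L = rng.normal(size=(n_links, 6))  # 6D link forces

# Joint selector matrix: zeros for the base rows, identity for the joint rows.
B = np.vstack([np.zeros((6, n)), np.eye(n)])

# Sum of J_i^T f_i over all links.
JT_f = np.einsum("lsg,ls->g", J, f_L)

# Solve M nu_dot = B tau - h + J^T f without forming an explicit inverse.
nu_dot = np.linalg.solve(M, B @ tau - h + JT_f)

print(f"nu_dot: shape={nu_dot.shape}")
```

A linear solve is used here because the stand-in mass matrix is positive definite by construction.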

In [None]:
ν̇_eom = jnp.linalg.pinv(M) @ (B @ τ - h + jnp.einsum("l6g,l6->g", J, f_L))

v̇_WB, s̈ = js.model.forward_dynamics_aba(
    model=model, data=data, link_forces=f_L, joint_forces=joint_force_references
)

ν̇_aba = jnp.hstack([v̇_WB, s̈])
print(f"ν̇: shape={ν̇_aba.shape}")  # noqa: RUF001

# The following relation should hold.
assert jnp.allclose(ν̇_eom, ν̇_aba)

### Inverse Dynamics

$$
(\boldsymbol{\tau}, \, \mathbf{f}_B) = \text{ID}(\mathbf{q}, \boldsymbol{\nu}, \dot{\boldsymbol{\nu}}, \mathbf{f}_{\mathcal{L}})
$$

JaxSim offers two methods to compute inverse dynamics:

- Directly use the quantities from the equations of motion.
- Use the Recursive Newton-Euler Algorithm (RNEA).

Unlike many other implementations, JaxSim's RNEA for floating-base systems is the true inverse of $\text{FD}$. It also computes the 6D force applied to the base link that generates the base acceleration.
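
The first option can be sketched as evaluating the left-hand side of the equations of motion and splitting the result into its base and joint components. The NumPy example below uses random stand-ins rather than JaxSim quantities. It assumes that the base force is expressed consistently with the base velocity, so that it enters the first six rows of the equations directly.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4  # hypothetical number of joints

# Random stand-ins for the mass matrix, the bias forces, and the
# link forces already projected as J^T f.
A = rng.normal(size=(6 + n, 6 + n))
M = A @ A.T + (6 + n) * np.eye(6 + n)
h = rng.normal(size=6 + n)
JT_f = rng.normal(size=6 + n)

# Reference inputs: a 6D force on the base and the joint torques.
f_B_ref = rng.normal(size=6)
tau_ref = rng.normal(size=n)

# Forward dynamics with the base force acting on the first six rows.
nu_dot = np.linalg.solve(M, np.hstack([f_B_ref, tau_ref]) - h + JT_f)

# Inverse dynamics: evaluate the generalized forces and split them.
generalized_forces = M @ nu_dot + h - JT_f
f_B, tau = generalized_forces[:6], generalized_forces[6:]

print(f"f_B: shape={f_B.shape}")
print(f"tau: shape={tau.shape}")
```

Feeding the acceleration returned by the forward dynamics back into the inverse dynamics recovers both the base force and the joint torques, which is the round trip that the next cell checks with JaxSim.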

In [None]:
f_B, τ_rnea = js.model.inverse_dynamics(
    model=model,
    data=data,
    base_acceleration=v̇_WB,
    joint_accelerations=s̈,
    # To check that f_B works, let's remove the force applied
    # to the base link from the link forces.
    link_forces=f_L.at[0].set(jnp.zeros(6))
)

print(f"f_B:    shape={f_B.shape}")
print(f"τ_rnea: shape={τ_rnea.shape}")

# The following relations should hold.
assert jnp.allclose(τ_rnea, τ)
assert jnp.allclose(f_B, link_forces[0])

### Centroidal Dynamics

Centroidal dynamics is a useful simplification often employed in planning and control applications. It represents the dynamics projected onto a mixed frame associated with the center of mass (CoM):

$$
G = G[W] = ({}^W \mathbf{p}_{\text{CoM}}, [W])
.
$$

The governing equations for centroidal dynamics take into account the 6D centroidal momentum:

$$
{}_G \mathbf{h} =
\begin{bmatrix}
{}_G \mathbf{h}^l \\ {}_G \mathbf{h}^\omega
\end{bmatrix} =
\begin{bmatrix}
m \, {}^W \dot{\mathbf{p}}_\text{CoM} \\ {}_G \mathbf{h}^\omega
\end{bmatrix}
\in \mathbb{R}^6
.
$$

The equations of centroidal dynamics can be expressed as:

$$
{}_G \dot{\mathbf{h}} =
m \,
\begin{bmatrix}
{}^W \mathbf{g} \\ \mathbf{0}_3
\end{bmatrix} +
\sum_{C_i \in \mathcal{C}} {}_G \mathbf{X}^{C_i} \, {}_{C_i} \mathbf{f}_i
.
$$

While centroidal dynamics can function independently by considering the total mass $m \in \mathbb{R}$ of the robot and the transformations for 6D contact forces ${}_G \mathbf{X}^{C_i}$ corresponding to the pose ${}^G \mathbf{H}_{C_i} \in \text{SE}(3)$ of the contact frames, advanced kino-dynamic methods may require a relationship between full kinematics and centroidal dynamics. This is typically achieved through the _Centroidal Momentum Matrix_ (also known as the _centroidal momentum Jacobian_):

$$
{}_G \mathbf{h} = J_\text{CMM}(\mathbf{q}) \, \boldsymbol{\nu}
.
$$

JaxSim offers APIs to compute all these quantities (and many more) in the `jaxsim.api.com` package.
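
As a small worked instance of the centroidal dynamics equation, the NumPy sketch below evaluates the momentum rate for a hypothetical robot standing on two point contacts. The mass, the positions, and the forces are made-up values. The contact frames are assumed to have the world orientation, so each linear force contributes its own value to the linear part and its moment about the CoM to the angular part.

```python
import numpy as np

m = 10.0  # hypothetical total mass [kg]
W_g = np.array([0.0, 0.0, -9.81])  # gravitational acceleration in W
W_p_CoM = np.array([0.0, 0.0, 0.8])  # CoM position in W

# Two hypothetical contact points, symmetric about the CoM along y.
W_p_C = np.array([[0.0, 0.1, 0.0], [0.0, -0.1, 0.0]])

# Pure linear contact forces: each contact carries half of the weight.
C_fl = np.array([[0.0, 0.0, 0.5 * m * 9.81], [0.0, 0.0, 0.5 * m * 9.81]])

# Gravity only affects the linear part of the centroidal momentum.
G_h_dot = m * np.hstack([W_g, np.zeros(3)])

# A linear force f applied at p contributes f to the linear part and
# (p - p_CoM) x f to the angular part.
for p, f in zip(W_p_C, C_fl):
    G_h_dot += np.hstack([f, np.cross(p - W_p_CoM, f)])

print(f"G_h_dot: {G_h_dot}")
```

With this symmetric force distribution the contact forces balance gravity and their moments cancel, so the centroidal momentum rate vanishes up to floating-point error.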

In [None]:
# Number of contact points.
n_cp = len(model.kin_dyn_parameters.contact_parameters.body)
print("Number of contact points:", n_cp, "\n")

# Compute the centroidal momentum.
J_CMM = js.com.centroidal_momentum_jacobian(model=model, data=data)
G_h = J_CMM @ ν
print(f"G_h:    shape={G_h.shape}")
print(f"J_CMM:  shape={J_CMM.shape}")

# The following relation should hold.
assert jnp.allclose(G_h, js.com.centroidal_momentum(model=model, data=data))

# If we consider all contact points of the model as active
# (discouraged, since there might be too many), the 6D transforms of
# collidable points can be computed as follows:
W_H_C = js.contact.transforms(model=model, data=data)

# Compute the pose of the G frame.
W_p_CoM = js.com.com_position(model=model, data=data)
G_H_W = jaxsim.math.Transform.inverse(jnp.eye(4).at[0:3, 3].set(W_p_CoM))

# Convert from SE(3) to the transforms for 6D forces.
G_Xf_C = jax.vmap(
    lambda W_H_Ci: jaxsim.math.Adjoint.from_transform(
        transform=G_H_W @ W_H_Ci, inverse=True
    )
)(W_H_C)
print(f"G_Xf_C: shape={G_Xf_C.shape}")

# Let's create random 3D linear forces applied to the contact points.
C_fl = jax.random.uniform(
    minval=-10.0,
    maxval=10.0,
    shape=(n_cp, 3),
    key=jax.random.PRNGKey(0),
)

# Compute the total mass of the robot.
m = js.model.total_mass(model=model)

# The centroidal dynamics can be computed as follows.
G_ḣ = 0
G_ḣ += m * jnp.hstack([0, 0, model.gravity, 0, 0, 0])
G_ḣ += jnp.einsum("c66,c6->6", G_Xf_C, jnp.hstack([C_fl, jnp.zeros_like(C_fl)]))
print(f"G_ḣ:    shape={G_ḣ.shape}")

## Contact Frames

Many control and planning applications require projecting the floating-base dynamics into the contact space or computing quantities related to active contact points, such as enforcing holonomic constraints.

The underlying theory for these applications becomes clearer in a mixed representation. Specifically, the position, linear velocity, and linear acceleration of contact points in their corresponding mixed frame align with the numerical derivatives of their coordinate vectors.

Key methodologies in this area may involve the Delassus matrix:

$$
\Psi(\mathbf{q}) = J_{W,C}(\mathbf{q}) \, M(\mathbf{q})^{-1} \, J_{W,C}^\top(\mathbf{q})
$$

or the linear acceleration of a contact point:

$$
{}^W \ddot{\mathbf{p}}_C = \frac{\text{d} (J^l_{W,C} \boldsymbol{\nu})}{\text{d}t}
= \dot{J}^l_{W,C} \boldsymbol{\nu} + J^l_{W,C} \dot{\boldsymbol{\nu}}
.
$$

JaxSim offers APIs to compute all these quantities (and many more) in the `jaxsim.api.contact` package.
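
Both expressions reduce to a few lines of linear algebra once the Jacobians are available. The NumPy sketch below uses random stand-ins for the mass matrix, the linear contact Jacobians, and their time derivatives. The sizes and the seed are arbitrary assumptions, and no JaxSim function is involved.

```python
import numpy as np

rng = np.random.default_rng(2)
n_dofs, n_c = 8, 3  # hypothetical generalized velocity size and contact count

# Random stand-ins for the mass matrix and the linear contact Jacobians.
A = rng.normal(size=(n_dofs, n_dofs))
M = A @ A.T + n_dofs * np.eye(n_dofs)
Jl = rng.normal(size=(n_c, 3, n_dofs))
Jl_dot = rng.normal(size=(n_c, 3, n_dofs))
nu = rng.normal(size=n_dofs)
nu_dot = rng.normal(size=n_dofs)

# Stack the per-contact Jacobians into a (3 * n_c, n_dofs) matrix.
J = Jl.reshape(3 * n_c, n_dofs)

# Delassus matrix J M^{-1} J^T, computed with a linear solve.
Psi = J @ np.linalg.solve(M, J.T)

# Linear acceleration of each contact point: J_dot nu + J nu_dot.
p_ddot = Jl_dot @ nu + Jl @ nu_dot

print(f"Psi:    shape={Psi.shape}")
print(f"p_ddot: shape={p_ddot.shape}")
```

By construction, the resulting matrix is symmetric and positive semidefinite, with one row and column block per contact point.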

In [None]:
with (
    data.switch_velocity_representation(VelRepr.Mixed),
    references.switch_velocity_representation(VelRepr.Mixed),
):

    # Compute the mixed generalized velocity.
    BW_ν = data.generalized_velocity

    # Compute the mixed generalized acceleration.
    BW_ν̇ = jnp.hstack(
        js.model.forward_dynamics(
            model=model,
            data=data,
            link_forces=references.link_forces(model=model, data=data),
            joint_forces=references.joint_force_references(model=model),
        )
    )

    # Compute the mass matrix in mixed representation.
    BW_M = js.model.free_floating_mass_matrix(model=model, data=data)

    # Compute the contact Jacobian and its derivative.
    Jl_WC = js.contact.jacobian(model=model, data=data)[:, 0:3, :]
    J̇l_WC = js.contact.jacobian_derivative(model=model, data=data)[:, 0:3, :]

# Compute the Delassus matrix.
Ψ = jnp.vstack(Jl_WC) @ jnp.linalg.lstsq(BW_M, jnp.vstack(Jl_WC).T)[0]
print(f"Ψ:     shape={Ψ.shape}")

# Compute the transforms of the mixed frames implicitly associated
# to each collidable point.
W_H_C = js.contact.transforms(model=model, data=data)
print(f"W_H_C: shape={W_H_C.shape}")

# Compute the linear velocity of the collidable points.
with data.switch_velocity_representation(VelRepr.Mixed):
    W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data)[:, 0:3]
    print(f"W_ṗ_C: shape={W_ṗ_C.shape}")

# Compute the linear acceleration of the collidable points.
W_p̈_C = 0
W_p̈_C += jnp.einsum("c3g,g->c3", J̇l_WC, BW_ν)
W_p̈_C += jnp.einsum("c3g,g->c3", Jl_WC, BW_ν̇)
print(f"W_p̈_C: shape={W_p̈_C.shape}")

## Conclusions

This notebook provided an overview of the main APIs in JaxSim for its use as a multibody dynamics library. Here are a few key points to remember:

- Explore all the modules in the `jaxsim.api` package to discover the full range of APIs available. Many more functionalities exist beyond what was covered in this notebook.
- All APIs follow a functional approach, consistent with the JAX programming style.
- This functional design allows for easy application of `jax.vmap` to execute functions in parallel on hardware accelerators.
- Since the entire multibody dynamics library is built with JAX, it natively supports `jax.grad`, `jax.jacfwd`, and `jax.jacrev` transformations, enabling automatic differentiation through complex logic without additional effort.

Have fun!