adam.jax.jax_like
=================

.. py:module:: adam.jax.jax_like


Classes
-------

.. autoapisummary::

   adam.jax.jax_like.JaxLike
   adam.jax.jax_like.JaxLikeFactory
   adam.jax.jax_like.SpatialMath


Module Contents
---------------

.. py:class:: JaxLike

   Bases: :py:obj:`adam.core.array_api_math.ArrayAPILike`


   Wrapper class for JAX types.


   .. py:attribute:: array
      :type:  jax.numpy.array


.. py:class:: JaxLikeFactory(spec: adam.core.array_api_math.ArraySpec | None = None)

   Bases: :py:obj:`adam.core.array_api_math.ArrayAPIFactory`


   Generic factory parameterized by (a) a ``Like`` wrapper class and (b) an ``xp`` array
   namespace (``array_api_compat.*`` if available; otherwise the library module itself).
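
   A small sketch can illustrate this factory pattern. ``ToyLike`` and ``ToyFactory`` are
   hypothetical names, not adam's API. The sketch passes ``jax.numpy`` or ``numpy``
   directly as the ``xp`` namespace instead of going through ``array_api_compat``, and
   adam's real ``ArrayAPIFactory`` may expose different methods:

   .. code-block:: python

      from dataclasses import dataclass
      from typing import Any

      import jax.numpy as jnp
      import numpy as np


      @dataclass
      class ToyLike:
          """Hypothetical Like wrapper: stores one raw backend array in ``array``."""

          array: Any

          def __matmul__(self, other: "ToyLike") -> "ToyLike":
              # Unwrap, compute with the backend, then re-wrap the result.
              return ToyLike(self.array @ other.array)


      class ToyFactory:
          """Hypothetical factory: builds Like objects using a given xp namespace."""

          def __init__(self, like_cls: type, xp: Any) -> None:
              self.like_cls = like_cls
              self.xp = xp

          def zeros(self, *shape: int) -> Any:
              return self.like_cls(self.xp.zeros(shape))

          def eye(self, n: int) -> Any:
              return self.like_cls(self.xp.eye(n))

          def asarray(self, x: Any) -> Any:
              return self.like_cls(self.xp.asarray(x))


      # The same factory code targets different backends via the namespace.
      jax_factory = ToyFactory(ToyLike, jnp)
      np_factory = ToyFactory(ToyLike, np)

      I = jax_factory.eye(3)
      v = jax_factory.asarray([[1.0], [2.0], [3.0]])
      print((I @ v).array.shape)  # (3, 1)

   Passing the namespace in, rather than importing it inside the factory, is what lets
   one implementation serve several array libraries.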


.. py:class:: SpatialMath(spec: adam.core.array_api_math.ArraySpec | None = None)

   Bases: :py:obj:`adam.core.array_api_math.ArrayAPISpatialMath`


   A drop-in ``SpatialMath`` that implements ``sin``, ``cos``, ``outer``, ``concat``, and
   ``skew`` with the Array API.

   Works for NumPy, PyTorch, and JAX; CasADi should keep its own subclass.
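
   As an illustration of writing such an operation once against the Array API, the
   following is a hypothetical standalone ``skew`` that takes raw arrays and an ``xp``
   namespace. It is a sketch under those assumptions, not adam's implementation, which
   operates on ``ArrayAPILike`` wrappers:

   .. code-block:: python

      import jax.numpy as jnp
      import numpy as np


      def skew(v, xp=jnp):
          """Hypothetical helper: 3x3 skew-symmetric matrix S of a 3-vector v.

          S is built so that S @ w equals the cross product v x w. Only
          ``xp.zeros_like`` and ``xp.stack`` are used, both part of the Array API.
          """
          z = xp.zeros_like(v[0])  # scalar zero with v's dtype
          return xp.stack([
              xp.stack([z, -v[2], v[1]]),
              xp.stack([v[2], z, -v[0]]),
              xp.stack([-v[1], v[0], z]),
          ])


      v = jnp.array([1.0, 2.0, 3.0])
      w = jnp.array([4.0, 5.0, 6.0])
      S_jax = skew(v)                    # JAX backend (default namespace)
      S_np = skew(np.asarray(v), xp=np)  # NumPy backend, same function body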


   .. py:method:: solve(A: adam.core.array_api_math.ArrayAPILike, B: adam.core.array_api_math.ArrayAPILike) -> adam.core.array_api_math.ArrayAPILike

      Override ``solve`` to handle JAX's batched solve API correctly.

      For batched solves, JAX requires ``b`` to have shape ``(..., N, M)`` rather than just
      ``(..., N)``. Following JAX's recommendation, 1-D right-hand sides are solved as
      ``solve(a, b[..., None]).squeeze(-1)``.
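
      A minimal sketch of that pattern, assuming plain ``jax.numpy`` arrays rather than
      ``ArrayAPILike`` wrappers. ``batched_vector_solve`` is a hypothetical helper, not part
      of adam's API:

      .. code-block:: python

         import jax.numpy as jnp


         def batched_vector_solve(A, b):
             """Hypothetical helper: solve A @ x = b where b is a (batch of) vector(s).

             A has shape (..., N, N) and b has shape (..., N); x has shape (..., N).
             """
             # Promote b to a column matrix (..., N, 1) so JAX treats it as the
             # right-hand side of a (batched) matrix solve, then drop the extra axis.
             return jnp.linalg.solve(A, b[..., None]).squeeze(-1)


         A = jnp.stack([jnp.eye(3), 2.0 * jnp.eye(3)])      # shape (2, 3, 3)
         b = jnp.array([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]])  # shape (2, 3)
         x = batched_vector_solve(A, b)
         print(x.shape)  # (2, 3); each row of x is [1., 2., 3.]

      The same call also covers the unbatched case: for ``A`` of shape ``(N, N)`` and ``b``
      of shape ``(N,)``, the result has shape ``(N,)``.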



