jaxKAN: A Unified JAX Framework for Kolmogorov-Arnold Networks

In this open-access paper in the Journal of Open Source Software (doi:10.21105/joss.07830), we introduce jaxKAN, a high-performance JAX-based library for building and training Kolmogorov-Arnold Networks (KANs). Designed for both general-purpose and scientific applications, jaxKAN supports adaptive training, hybrid KAN architectures and Physics-Informed Kolmogorov-Arnold Networks (PIKANs).
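For readers new to KANs, the following is a minimal sketch in plain JAX of what a single KAN layer computes when each edge function is expanded in Chebyshev polynomials. It does not use jaxKAN and is not its implementation; the function names (`chebyshev_basis`, `init_params`, `kan_layer`), the `tanh` input squashing and the initialization scale are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def chebyshev_basis(x, degree):
    """Evaluate T_0..T_degree at tanh(x); returns shape x.shape + (degree + 1,)."""
    x = jnp.tanh(x)  # assumption: squash inputs into (-1, 1), the Chebyshev domain
    polys = [jnp.ones_like(x), x]
    for _ in range(2, degree + 1):
        # Three-term recurrence: T_{k+1}(x) = 2 x T_k(x) - T_{k-1}(x)
        polys.append(2.0 * x * polys[-1] - polys[-2])
    return jnp.stack(polys[: degree + 1], axis=-1)


def init_params(key, n_in, n_out, degree):
    """One coefficient vector per (input, output) edge; the scale is illustrative."""
    scale = 1.0 / (n_in * (degree + 1))
    return scale * jax.random.normal(key, (n_in, n_out, degree + 1))


def kan_layer(coeffs, x):
    """Map x of shape (batch, n_in) to (batch, n_out).

    Each output is a sum over inputs of a learnable univariate function,
    here a Chebyshev series with coefficients coeffs[i, o, :].
    """
    basis = chebyshev_basis(x, coeffs.shape[-1] - 1)  # (batch, n_in, degree + 1)
    return jnp.einsum("bik,iok->bo", basis, coeffs)
```

Unlike a dense layer, the learnable parameters here live on the edges as univariate functions rather than as scalar weights followed by a fixed activation.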

Highlights

  • Multiple Layer Types: Includes B-spline, Chebyshev, Legendre, RBF, Fourier and other layer types, all available through a consistent API.
  • Adaptive Training Routines: Grid updates, basis order extension and optimizer-aware state transitions are built in.
  • Specialized PDE Tools: High-level utilities for solving PDEs using PIKANs are included, e.g., the train_PIKAN function.
  • Performance: Achieves significant speedups over prior work and supports batched least-squares operations not natively available in JAX.
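To illustrate what a batched least-squares operation looks like in JAX, here is a minimal sketch that solves many independent small systems at once by mapping a single-system solver over a batch axis with `jax.vmap`. This is not jaxKAN's routine: the names `ridge_lstsq` and `batched_lstsq`, the regularized normal-equations formulation and the `reg` parameter are assumptions chosen for brevity.

```python
import jax
import jax.numpy as jnp


def ridge_lstsq(A, b, reg):
    """Solve min_x ||A x - b||^2 + reg * ||x||^2 for one system.

    A has shape (m, n) and b has shape (m,). The normal equations are
    used for simplicity; they square the condition number of A.
    """
    gram = A.T @ A + reg * jnp.eye(A.shape[1], dtype=A.dtype)
    return jnp.linalg.solve(gram, A.T @ b)


@jax.jit
def batched_lstsq(A, b, reg=1e-6):
    """Solve a batch of systems: A is (batch, m, n), b is (batch, m)."""
    # vmap maps the single-system solver over the leading axis of A and b,
    # while the regularization strength is shared across the batch.
    return jax.vmap(ridge_lstsq, in_axes=(0, 0, None))(A, b, reg)


# Usage: recover known coefficients for 8 independent 20x4 systems.
key_a, key_x = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (8, 20, 4))
x_true = jax.random.normal(key_x, (8, 4))
b = jnp.einsum("bmn,bn->bm", A, x_true)
x_hat = batched_lstsq(A, b)  # shape (8, 4)
```

For ill-conditioned systems, a QR- or SVD-based solver would be a more robust choice than the normal equations used in this sketch.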
@article{Rigas2025,
  title = {jax{KAN}: A unified {JAX} framework for {K}olmogorov-{A}rnold Networks},
  journal = {Journal of Open Source Software},
  year = {2025},
  volume = {10},
  number = {108},
  pages = {7830},
  doi = {10.21105/joss.07830}
}