Running JAX with cuDNN provided by Lambda Stack

I am trying to install JAX on a Lambda machine that has Lambda Stack installed, but I cannot locate cuDNN and make it visible to my JAX installation.
Where should libcudnn be located if the stack was installed correctly? Is there a recommended way of installing JAX with Lambda Stack?

Thanks,
Simon

You should find it under:

/usr/lib/python3/dist-packages/tensorflow/
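As a minimal sketch (not part of the original answer), the following Python helper searches a few candidate directories for libcudnn shared libraries and prints the directories that contain them. The function name `find_cudnn_dirs` is hypothetical. Only the TensorFlow directory comes from the answer above; the other two paths are common library locations included purely as guesses.

```python
import glob
import os
from typing import Iterable, List


def find_cudnn_dirs(roots: Iterable[str]) -> List[str]:
    """Hypothetical helper: return the sorted, de-duplicated directories under
    `roots` that contain a file matching libcudnn*.so* (e.g. libcudnn.so.8)."""
    found = set()
    for root in roots:
        # With recursive=True, "**" matches zero or more subdirectories,
        # so files directly inside `root` are found too.
        pattern = os.path.join(root, "**", "libcudnn*.so*")
        for path in glob.glob(pattern, recursive=True):
            found.add(os.path.dirname(os.path.abspath(path)))
    return sorted(found)


if __name__ == "__main__":
    candidates = [
        "/usr/lib/python3/dist-packages/tensorflow",  # from the answer
        "/usr/lib/x86_64-linux-gnu",  # assumption: common system lib dir
        "/usr/local/cuda/lib64",  # assumption: common CUDA toolkit dir
    ]
    # Directories that do not exist simply yield no matches.
    for d in find_cudnn_dirs(candidates):
        print(d)
```

The exact directory printed is the one worth exporting, since the dynamic loader does not search subdirectories of the entries in LD_LIBRARY_PATH.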

Make sure to add that path to the LD_LIBRARY_PATH variable, which is used at runtime, and to the LIBRARY_PATH variable, which is used for compile-time linking, by exporting both in your bash profile or .bashrc file.
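Here is a small sketch, assuming a bash shell, that generates the two `export` lines described above for a given directory. The helper name `bashrc_exports` is hypothetical. It prepends the directory and keeps any existing value, using bash's `${VAR:+:$VAR}` expansion so no stray colon appears when the variable was previously empty.

```python
import shlex
from typing import List


def bashrc_exports(lib_dir: str) -> List[str]:
    """Hypothetical helper: return bash lines that prepend `lib_dir` to
    LD_LIBRARY_PATH (runtime loader) and LIBRARY_PATH (compile-time linker)."""
    quoted = shlex.quote(lib_dir)  # quote paths containing spaces, etc.
    lines = []
    for var in ("LD_LIBRARY_PATH", "LIBRARY_PATH"):
        # Expands to ":<old value>" only if the variable is set and non-empty.
        expansion = "${" + var + ":+:$" + var + "}"
        lines.append(f'export {var}={quoted}"{expansion}"')
    return lines


if __name__ == "__main__":
    for line in bashrc_exports("/usr/lib/python3/dist-packages/tensorflow"):
        print(line)
```

Appending the printed lines to ~/.bashrc affects only shells started afterwards (or after `source ~/.bashrc`). Setting LD_LIBRARY_PATH through `os.environ` inside an already running Python process generally does not change where that process's dynamic loader looks, so the variable needs to be set before Python (and JAX) starts.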