jax.tree_util.tree_structure

jax.tree_util.tree_structure(tree)

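Gets the treedef (a PyTreeDef) for a pytree. The treedef describes the tree's container structure independently of the values stored at its leaves.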
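The following is a minimal sketch, assuming a recent JAX install. It shows what a treedef captures. The tree contents are arbitrary illustration values. `tree_unflatten` and the `num_leaves` property are related parts of `jax.tree_util` that do not appear on this page. Two trees with the same containers but different leaf values share one treedef, and the treedef can rebuild a tree from a new list of leaves.

```python
import jax.numpy as jnp
from jax import tree_util

# Illustrative pytree: a dict holding a tuple and a list.
# None counts as an empty subtree, not as a leaf.
tree = {"a": (1, 2), "b": [jnp.zeros(3), None]}

# The treedef records the containers and drops the leaf values.
treedef = tree_util.tree_structure(tree)
print(treedef)
print(treedef.num_leaves)  # leaves are 1, 2 and the zeros array

# Same containers with different leaves give an equal treedef.
other = {"a": (3, 4), "b": [jnp.ones(3), None]}
same = tree_util.tree_structure(other) == treedef
print(same)

# Rebuild a tree of the same shape from new leaves (dict keys in sorted order).
rebuilt = tree_util.tree_unflatten(treedef, [10, 20, 30])
print(rebuilt)
```

Treating `None` as a node rather than a leaf is why only three leaves are counted here.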