jax.tree_util.treedef_tuple

jax.tree_util.treedef_tuple(treedefs)

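Makes a tuple treedef from a list of child treedefs. The i-th child of the returned treedef is treedefs[i], so the result describes the same structure that jax.tree_util.tree_structure would return for a tuple of trees with those structures.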
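A minimal sketch of how the function composes with the rest of jax.tree_util. It assumes a JAX version where jax.tree_util.tree_structure and jax.tree_util.tree_unflatten are available. The trees x and y are illustrative placeholders. The example builds a tuple treedef from two separately computed treedefs and checks that it matches the structure of the tuple (x, y):

```python
import jax

# Two illustrative pytrees with different structures (placeholders).
x = [1, 2, 3]
y = {"a": 4, "b": 5}

# Compute each child structure separately.
x_def = jax.tree_util.tree_structure(x)
y_def = jax.tree_util.tree_structure(y)

# Combine the child treedefs into a single tuple treedef.
xy_def = jax.tree_util.treedef_tuple([x_def, y_def])

# The combined treedef equals the structure of the tuple (x, y).
same = xy_def == jax.tree_util.tree_structure((x, y))
print(same)               # True
print(xy_def.num_leaves)  # 5

# The combined treedef can rebuild a tuple of trees from a flat leaf list.
# Dict leaves are ordered by sorted key ("a", then "b").
rebuilt = jax.tree_util.tree_unflatten(xy_def, [1, 2, 3, 4, 5])
print(rebuilt)            # ([1, 2, 3], {'a': 4, 'b': 5})
```

This is useful when child structures are already known, for example after flattening several trees independently. In that case you can describe their tuple without constructing the tuple of values itself.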