jax.tree_util.keystr

jax.tree_util.keystr(keys, *, simple=False, separator='')

Helper to pretty-print a tuple of keys.

Parameters:
  • keys (KeyPath) – A tuple of KeyEntry objects, or of any other objects that can be converted to a string.

  • simple (bool) – If True, use a simplified string representation for keys. The simple representation is more compact than the default, but it is ambiguous in some cases: for example, “0” might refer to the first item in a list, or to a dictionary key that is either the integer 0 or the string “0”.

  • separator (str) – The separator used to join the string representations of the keys. Defaults to the empty string.

Returns:

A string formed by joining the string representations of all keys with separator.

Return type:

str

Examples

>>> import jax
>>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
>>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
...   print(jax.tree_util.keystr(path))
['foo']['bar']['bat'][0]
['foo']['bar']['bat'][1]
['foo']['bar']['baz']
>>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
...   print(jax.tree_util.keystr(path, simple=True, separator='/'))
foo/bar/bat/0
foo/bar/bat/1
foo/bar/baz