Add with_path to allow modules to be flattened with key paths.
```
>>> class MyModule(tf.Module):
... @property
... def state_dict(self):
... return dict(self._flatten(
... predicate=lambda v: isinstance(v, tf.Variable), with_path=True))
>>> mod = MyModule()
>>> mod.encoder = Encoder()
>>> mod.decoder = mod.encoder
>>> mod.state_dict
{('encoder', 'w'), <tf.Variable: ...>,
('decoder', 'w'), <tf.Variable: ...>}
```
h/t tensorflow/community#56
PiperOrigin-RevId: 232908045
Loading
Please sign in to comment