Switch tf.keras.Model.save_weights to TensorFlow format for graph networks.
Some tweaks to support loading checkpoints into modified Python programs. Relaxes the Checkpointable consistency check for object matching: if the same object in the checkpoint matches two different Python objects, it will just choose the first one that matches (based on whichever traversal of the Python dependency graph it's doing). assert_consumed() on the status object will fail, but this gives users the option of continuing anyway. Adds a "weight-bearing layer index" dependency to graph networks which skips Layers without weights, in addition to the regular layer index. This allows users to add Layers without weights while not breaking checkpoints, as they could when matching with flattened weights from HDF5 format. Eventually I'd like to add a dependency structure which matches the topology of the graph itself (so a Layer would have checkpoint dependencies on other Layers it outputs to), but there are some subtleties before that's useful (it'd need something like a secondary check that the Python classes match). I think the scheme in this CL is robust enough for general use, and adding more dependencies later can make it more robust now that users won't run into consistency check errors (previously more dependencies would only make matching more picky). PiperOrigin-RevId: 194277075
Loading
Please sign in to comment