Minimal Linear Growing Layers
This example builds a minimal network of linear growing layers, gathers the statistics needed for growth, and applies the optimal update of one layer to grow the network.
# Authors: Theo Rudkiewicz <theo.rudkiewicz@inria.fr>
# Sylvain Chevallier <sylvain.chevallier@universite-paris-saclay.fr>
Setup
Importing the modules
import torch
from gromo.modules.linear_growing_module import LinearGrowingModule
from gromo.utils.utils import global_device
Define three linear growing layers (1 input and 1 output feature each); the first two are followed by a ReLU activation
l1 = LinearGrowingModule(
    1, 1, use_bias=True, post_layer_function=torch.nn.ReLU(), name="l1"
)
l2 = LinearGrowingModule(
    1,
    1,
    use_bias=True,
    previous_module=l1,
    post_layer_function=torch.nn.ReLU(),
    name="l2",
)
l3 = LinearGrowingModule(1, 1, use_bias=True, previous_module=l2, name="l3")
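For orientation, the fixed-size network these growing layers start from is equivalent to the following plain PyTorch model (a comparison sketch only; it plays no role in the growing mechanics):

# Fixed-size counterpart of the growing network defined above (uses the
# `torch` imported earlier); shown only for comparison.
static_net = torch.nn.Sequential(
    torch.nn.Linear(1, 1),
    torch.nn.ReLU(),
    torch.nn.Linear(1, 1),
    torch.nn.ReLU(),
    torch.nn.Linear(1, 1),
)
print(static_net)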
Generate random data, initialize the computation, and compute the optimal updates
x = torch.randn(200, 1, device=global_device())
net = torch.nn.Sequential(l1, l2, l3)
print(net)
for layer in net:
    layer.init_computation()
for layer in net:
    print(layer.__str__(verbose=1))
y = net(x)
loss = torch.norm(y)
print(f"loss: {loss}")
loss.backward()
for layer in net:
    layer.update_computation()
    layer.compute_optimal_updates()
for layer in net:
    layer.reset_computation()
l1.delete_update()  # discard the candidate updates of l1 and l3
l3.delete_update()
l2.scaling_factor = 1  # scaling factor applied to l2's candidate extension
Sequential(
(0): LinearGrowingModule(LinearGrowingModule(l1))(in_features=1, out_features=1, use_bias=True)
(1): LinearGrowingModule(LinearGrowingModule(l2))(in_features=1, out_features=1, use_bias=True)
(2): LinearGrowingModule(LinearGrowingModule(l3))(in_features=1, out_features=1, use_bias=True)
)
LinearGrowingModule(l1) module with 2 parameters (self._allow_growing=False, self.store_input=True, self.store_pre_activity=True).
LinearGrowingModule(l2) module with 2 parameters (self._allow_growing=False, self.store_input=True, self.store_pre_activity=True).
LinearGrowingModule(l3) module with 2 parameters (self._allow_growing=False, self.store_input=True, self.store_pre_activity=True).
loss: 2.412005662918091
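The S and M statistics printed above were accumulated from a single batch. Assuming these statistics accumulate across successive calls to update_computation (an assumption, not shown by this example), the same pass can be run over several mini-batches from a hypothetical DataLoader named loader before computing the updates once:

for layer in net:
    layer.init_computation()
for batch in loader:  # hypothetical DataLoader yielding (n, 1) tensors
    net.zero_grad()
    loss = torch.norm(net(batch))
    loss.backward()
    for layer in net:
        layer.update_computation()
for layer in net:
    layer.compute_optimal_updates()
for layer in net:
    layer.reset_computation()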
Inspect the candidate update, evaluate it with an extended forward pass, apply it, and print the parameters before and after the change
print(f"{l2.first_order_improvement=}")
print(f"{l2.weight=}")
print(f"{l2.bias=}")
print(f"{l2.optimal_delta_layer=}")
print(f"{l2.parameter_update_decrease=}")
print(f"{l2.extended_input_layer=}")
print(f"{l2.extended_input_layer.weight=}")
print(f"{l2.extended_input_layer.bias=}")
print(f"{l1.extended_output_layer=}")
print(f"{l2.eigenvalues_extension=}")
x_ext = None
for layer in net:
    x, x_ext = layer.extended_forward(x, x_ext)
new_loss = torch.norm(x)
print(f"loss: {new_loss}, {loss - new_loss} improvement")
l2.apply_change()
print("------- New weights -------")
print(f"{l1.weight=}")
print(f"{l2.weight=}")
print(f"{l3.weight=}")
print("------- New biases -------")
print(f"{l1.bias=}")
print(f"{l2.bias=}")
print(f"{l3.bias=}")
for layer in net:
    layer.init_computation()
for layer in net:
    print(layer.__str__(verbose=2))
y = net(x)
loss = torch.norm(y)
print(f"loss: {loss}")
loss.backward()
for layer in net:
    layer.update_computation()
    layer.compute_optimal_updates()
l2.first_order_improvement=tensor(0.0016)
l2.weight=Parameter containing:
tensor([[-0.9119]], requires_grad=True)
l2.bias=Parameter containing:
tensor([0.9209], requires_grad=True)
l2.optimal_delta_layer=Linear(in_features=1, out_features=1, bias=True)
l2.parameter_update_decrease=tensor(0.0003)
l2.extended_input_layer=Linear(in_features=1, out_features=1, bias=True)
l2.extended_input_layer.weight=Parameter containing:
tensor([[0.1893]], requires_grad=True)
l2.extended_input_layer.bias=Parameter containing:
tensor([0.], requires_grad=True)
l1.extended_output_layer=Linear(in_features=1, out_features=1, bias=True)
l2.eigenvalues_extension=tensor([0.0359])
loss: 2.2087595462799072, 0.2032461166381836 improvement
------- New weights -------
l1.weight=Parameter containing:
tensor([[-0.4660],
[-0.0719]], requires_grad=True)
l2.weight=Parameter containing:
tensor([[-0.8903, 0.1893]], requires_grad=True)
l3.weight=Parameter containing:
tensor([[0.2901]], requires_grad=True)
------- New biases -------
l1.bias=Parameter containing:
tensor([0.2415, 0.1837], requires_grad=True)
l2.bias=Parameter containing:
tensor([0.9312], requires_grad=True)
l3.bias=Parameter containing:
tensor([-0.3382], requires_grad=True)
LinearGrowingModule(l1) module with 4 parameters.
Layer : Linear(in_features=1, out_features=2, bias=True)
Post layer function : ReLU()
Allow growing : False
Store input : True
self._internal_store_input=True
Store pre-activity : True
self._internal_store_pre_activity=True
Tensor S (internal) : S(LinearGrowingModule(l1)) tensor of shape (2, 2) with 0 samples
Tensor S : S(LinearGrowingModule(l1)) tensor of shape (2, 2) with 0 samples
Tensor M : M(LinearGrowingModule(l1)) tensor of shape (2, 2) with 0 samples
Optimal delta layer : None
Extended input layer : None
Extended output layer : Linear(in_features=1, out_features=1, bias=True)
LinearGrowingModule(l2) module with 3 parameters.
Layer : Linear(in_features=2, out_features=1, bias=True)
Post layer function : ReLU()
Allow growing : False
Store input : True
self._internal_store_input=True
Store pre-activity : True
self._internal_store_pre_activity=True
Tensor S (internal) : S(LinearGrowingModule(l2)) tensor of shape (3, 3) with 0 samples
Tensor S : S(LinearGrowingModule(l2)) tensor of shape (3, 3) with 0 samples
Tensor M : M(LinearGrowingModule(l2)) tensor of shape (3, 1) with 0 samples
Optimal delta layer : Linear(in_features=1, out_features=1, bias=True)
Extended input layer : Linear(in_features=1, out_features=1, bias=True)
Extended output layer : None
LinearGrowingModule(l3) module with 2 parameters.
Layer : Linear(in_features=1, out_features=1, bias=True)
Post layer function : Identity()
Allow growing : False
Store input : True
self._internal_store_input=True
Store pre-activity : True
self._internal_store_pre_activity=True
Tensor S (internal) : S(LinearGrowingModule(l3)) tensor of shape (2, 2) with 0 samples
Tensor S : S(LinearGrowingModule(l3)) tensor of shape (2, 2) with 0 samples
Tensor M : M(LinearGrowingModule(l3)) tensor of shape (2, 1) with 0 samples
Optimal delta layer : None
Extended input layer : None
Extended output layer : None
loss: 1.9270763397216797
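As a final sanity check, the shapes read off the printed weights above show that the hidden width between l1 and l2 grew from one to two neurons while l3 kept its shape (a sketch based on the output above):

# Shape check after apply_change (expected values taken from the printed weights).
assert l1.weight.shape == (2, 1)  # l1 now produces two hidden features
assert l2.weight.shape == (1, 2)  # l2 now consumes two hidden features
assert l3.weight.shape == (1, 1)  # l3 is unchanged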
Total running time of the script: (0 minutes 1.234 seconds)