gromo.linear_growing_module.LinearAdditionGrowingModule

class gromo.linear_growing_module.LinearAdditionGrowingModule(post_addition_function: Module = Identity(), previous_modules=None, next_modules=None, allow_growing: bool = False, in_features: int = None, device: device | None = None, name: str = None)

compute_optimal_delta(update: bool = True, return_deltas: bool = False) -> list[tuple[Tensor, Tensor]] | None

Compute the optimal delta for each previous layer using current S and M tensors.

dW* = M S[-1]^{-1} (the pseudo-inverse of S[-1] is used if needed)

Compute dW* (and dBias* if needed) and update the optimal_delta_layer attribute.

Parameters:
  • update (bool) – if True, update the optimal_delta_layer attribute

  • return_deltas (bool) – if True, return the deltas

Returns:

optimal delta for the weights and the biases if needed

Return type:

list[tuple[torch.Tensor, torch.Tensor]] | None
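
The formula above can be illustrated with plain PyTorch. This is a minimal sketch, not the library's implementation: the helper name optimal_delta_sketch, the tensor shapes, and the convention that the last input column is the constant bias feature are all assumptions made for the example. It always uses the pseudo-inverse, which coincides with the ordinary inverse whenever S is invertible.

```python
import torch


def optimal_delta_sketch(m: torch.Tensor, s: torch.Tensor, use_bias: bool = True):
    """Hypothetical helper: dW* = M S^+, with S^+ the pseudo-inverse of S."""
    # pinv equals the true inverse whenever S is invertible.
    delta = m @ torch.linalg.pinv(s)
    if use_bias:
        # Assumed layout: the last input column is the constant-1 bias feature.
        return delta[:, :-1], delta[:, -1]
    return delta, None


torch.manual_seed(0)
b = torch.randn(64, 4, dtype=torch.float64)  # activity: 3 features + bias column
b[:, -1] = 1.0
s = b.T @ b                                  # S, shape (4, 4)
target = torch.randn(2, 4, dtype=torch.float64)
m = target @ s                               # M built so that M S^{-1} == target
delta_w, delta_b = optimal_delta_sketch(m, s)
```

Here delta_w has one row per output feature and one column per input feature, and delta_b has one entry per output feature.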

compute_previous_m_update() -> tuple[Tensor, int]

Compute the update of the tensor M for the input of all previous modules. Here B is the full activity tensor and M = dLoss/dA^T B.

Returns:

  • torch.Tensor – update of the tensor M

  • int – number of samples used to compute the update
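
As a rough illustration of the product M = dLoss/dA^T B and of the (tensor, sample count) return pair, the following sketch uses hand-written tensors. The function name m_update_sketch and the shapes (n, out) for the gradient and (n, in) for B are assumptions for the example, not the library's code.

```python
import torch


def m_update_sketch(grad_a: torch.Tensor, b: torch.Tensor):
    """Hypothetical helper: return (dLoss/dA^T B, n)."""
    n = b.shape[0]  # number of samples in this batch
    return grad_a.T @ b, n


grad_a = torch.tensor([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])  # (n=3, out=2)
b = torch.tensor([[1.0, 1.0], [2.0, 1.0], [3.0, 1.0]])       # (n=3, in=2)
m_update, n_samples = m_update_sketch(grad_a, b)
# m_update is [[4., 2.], [7., 3.]] and n_samples is 3
```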

compute_previous_s_update() -> tuple[Tensor, int]

Compute the update of the tensor S for the input of all previous modules. Here B is the full activity tensor and S = B^T B.

Returns:

  • torch.Tensor – update of the tensor S

  • int – number of samples used to compute the update
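
Returning the sample count alongside B^T B makes it possible to accumulate the statistic over several batches. The sketch below shows one way this could be done; the helper s_update_sketch and the accumulate-then-normalise loop are assumptions for illustration and may differ from how the library stores S.

```python
import torch


def s_update_sketch(b: torch.Tensor):
    """Hypothetical helper: return (B^T B, n) for a batch of shape (n, features)."""
    return b.T @ b, b.shape[0]


torch.manual_seed(0)
full_b = torch.randn(10, 3, dtype=torch.float64)

# Accumulate per-batch updates, then normalise by the total sample count.
s_sum = torch.zeros(3, 3, dtype=torch.float64)
n_total = 0
for batch in full_b.split(4):  # batches of 4, 4 and 2 samples
    update, n = s_update_sketch(batch)
    s_sum += update
    n_total += n
s_mean = s_sum / n_total
```

Summing the per-batch updates gives the same matrix as computing B^T B once on all samples.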

compute_s_update()

Compute the update of the tensor S. With the input tensor X, the update is U^{j k} = X^{i j} X^{i k}, with an implicit sum over the repeated sample index i.

Returns:

update of the tensor S

Return type:

torch.Tensor
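
The index expression above maps directly onto torch.einsum. The snippet below is only a worked example of the formula on a small hand-written X, not the library's code.

```python
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # (n=3, features=2)

# U^{jk} = sum_i X^{ij} X^{ik}, which is the same as X^T X
u = torch.einsum("ij,ik->jk", x, x)
# u is [[35., 44.], [44., 56.]]
```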

construct_full_activity()

Construct the full activity tensor B from the input of all previous modules. B = (B_1, B_2, …, B_k) has shape (n, C_1 + C_2 + … + C_k), with C_i the number of features of the i-th module. If the bias is used, B_i = (X_i, 1) has shape (n, C_i' + 1).

Returns:

full activity tensor

Return type:

torch.Tensor
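
The concatenation described above can be sketched as follows. The function full_activity_sketch and its list-based arguments are hypothetical. The real method gathers the inputs from the previous modules itself, which is not reproduced here.

```python
import torch


def full_activity_sketch(inputs, use_bias):
    """Hypothetical helper: concatenate per-module inputs along the feature axis."""
    blocks = []
    for x, bias in zip(inputs, use_bias):
        if bias:
            # Append a column of ones so that B_i = (X_i, 1).
            ones = torch.ones(x.shape[0], 1, dtype=x.dtype)
            x = torch.cat([x, ones], dim=1)
        blocks.append(x)
    return torch.cat(blocks, dim=1)


x1 = torch.zeros(5, 3)        # module 1: 3 features, with bias
x2 = torch.full((5, 2), 2.0)  # module 2: 2 features, no bias
b = full_activity_sketch([x1, x2], use_bias=[True, False])
# b has shape (5, 6): 3 features + 1 bias column + 2 features
```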

set_next_modules(next_modules: list[AdditionGrowingModule | GrowingModule]) -> None

Set the next modules of the current module.

Parameters:

next_modules – list of next modules

set_previous_modules(previous_modules: list[AdditionGrowingModule | GrowingModule]) -> None

Set the previous modules of the current module.

Parameters:

previous_modules – list of previous modules