gromo.linear_growing_module.LinearAdditionGrowingModule
- class gromo.linear_growing_module.LinearAdditionGrowingModule(post_addition_function: Module = Identity(), previous_modules=None, next_modules=None, allow_growing: bool = False, in_features: int = None, device: device | None = None, name: str = None)
- compute_optimal_delta(update: bool = True, return_deltas: bool = False) → list[tuple[Tensor, Tensor]] | None
Compute the optimal delta for each previous layer using the current S and M tensors of the inputs of all previous modules:
dW* = M S^{-1}
If S is not invertible, its pseudo-inverse is used instead. Compute dW* (and dBias* if needed) and update the optimal_delta_layer attribute.
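The formula above can be illustrated with a minimal sketch, assuming S is the (C, C) matrix B^T B and M the (out_features, C) matrix (dLoss/dA)^T B built from the concatenated inputs of the previous modules. The helper `optimal_delta_sketch` and its `in_features`/`use_bias` arguments are hypothetical and not part of gromo; the sketch only shows how dW* = M S^{-1} can be computed with a pseudo-inverse and split back into one (dW, dBias) pair per previous module, following the column layout described under construct_full_activity.

```python
from __future__ import annotations

import torch


def optimal_delta_sketch(
    tensor_s: torch.Tensor,
    tensor_m: torch.Tensor,
    in_features: list[int],
    use_bias: list[bool],
) -> list[tuple[torch.Tensor, torch.Tensor | None]]:
    """Hypothetical helper (not gromo's API): dW* = M S^{-1}, split per module.

    tensor_s: S = B^T B, shape (C, C) with C = C_1 + ... + C_k.
    tensor_m: M = (dLoss/dA)^T B, shape (out_features, C).
    """
    # pinv equals S^{-1} when S is invertible and still gives a
    # well-defined (minimum-norm) solution when S is singular.
    delta = tensor_m @ torch.linalg.pinv(tensor_s)  # (out_features, C)

    deltas = []
    start = 0
    for n_in, bias in zip(in_features, use_bias):
        width = n_in + int(bias)  # C_i = C_i' + 1 when the bias is used
        block = delta[:, start : start + width]
        start += width
        if bias:
            # The last column multiplies the constant 1 in B_i = (X_i, 1).
            deltas.append((block[:, :-1], block[:, -1]))
        else:
            deltas.append((block, None))
    if start != delta.shape[1]:
        raise ValueError("in_features/use_bias do not match the size of S")
    return deltas


# Usage: two previous modules, the first with a bias (3 + 1 columns),
# the second without one (2 columns).
torch.manual_seed(0)
n = 64
x1 = torch.randn(n, 3, dtype=torch.float64)
x2 = torch.randn(n, 2, dtype=torch.float64)
b = torch.cat([x1, torch.ones(n, 1, dtype=torch.float64), x2], dim=1)
grad = torch.randn(n, 4, dtype=torch.float64)  # stands in for dLoss/dA
deltas = optimal_delta_sketch(b.T @ b, grad.T @ b, [3, 2], [True, False])
```

Using `torch.linalg.pinv` rather than `torch.linalg.inv` keeps the sketch defined when S is singular, for example when the inputs of two previous modules are collinear.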
- compute_previous_m_update() → tuple[Tensor, int]
Compute the update of the tensor M for the input of all previous modules. With B the full activity tensor, M = (dLoss/dA)^T B.
- Returns:
torch.Tensor – update of the tensor M
int – number of samples used to compute the update
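As a minimal sketch of the formula M = (dLoss/dA)^T B for one mini-batch, assuming B has shape (n, C) and dLoss/dA has shape (n, out_features), the update is a single matrix product contracted over the sample dimension. The helper `previous_m_update_sketch` is hypothetical and only mirrors the documented return value (update tensor, number of samples).

```python
from __future__ import annotations

import torch


def previous_m_update_sketch(
    full_activity: torch.Tensor, grad_a: torch.Tensor
) -> tuple[torch.Tensor, int]:
    """Hypothetical helper (not gromo's API): M update = (dLoss/dA)^T B.

    full_activity: B, shape (n, C).
    grad_a: dLoss/dA, shape (n, out_features).
    Returns the (out_features, C) update and the number of samples n.
    """
    if full_activity.shape[0] != grad_a.shape[0]:
        raise ValueError("B and dLoss/dA must have the same number of samples")
    return grad_a.T @ full_activity, full_activity.shape[0]


torch.manual_seed(0)
b = torch.randn(16, 5)       # full activity tensor B
grad_a = torch.randn(16, 3)  # stands in for dLoss/dA
update, n = previous_m_update_sketch(b, grad_a)
print(update.shape, n)  # torch.Size([3, 5]) 16
```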
- compute_previous_s_update() → tuple[Tensor, int]
Compute the update of the tensor S for the input of all previous modules. With B the full activity tensor, S = B^T B.
- Returns:
torch.Tensor – update of the tensor S
int – number of samples used to compute the update
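Returning the sample count alongside the update suggests that the caller normalizes the accumulated statistic. The sketch below assumes this (a running sum of B^T B divided by the total number of samples) and shows that accumulating per-batch updates reproduces the full-batch B^T B / n. The helper `previous_s_update_sketch` and the averaging loop are illustrative assumptions, not gromo's implementation.

```python
from __future__ import annotations

import torch


def previous_s_update_sketch(full_activity: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Hypothetical helper (not gromo's API): S update = B^T B for B of shape (n, C)."""
    return full_activity.T @ full_activity, full_activity.shape[0]


# Assumption: updates and sample counts are summed over mini-batches and the
# statistic is their ratio, i.e. an average over all samples seen so far.
torch.manual_seed(0)
batches = [torch.randn(n, 5, dtype=torch.float64) for n in (8, 16, 4)]
total = torch.zeros(5, 5, dtype=torch.float64)
count = 0
for batch in batches:
    update, n = previous_s_update_sketch(batch)
    total += update
    count += n
tensor_s = total / count  # symmetric (C, C) matrix
```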
- compute_s_update()
Compute the update of the tensor S. With the input tensor X, the update is U^{jk} = X^{ij} X^{ik}, summed over the sample index i (that is, U = X^T X).
- Returns:
update of the tensor S
- Return type:
torch.Tensor
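The index notation above maps directly onto `torch.einsum`. This short sketch, using a random stand-in for the input tensor X, checks that the contraction over the sample index i equals the matrix product X^T X.

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 6)  # input tensor X: 32 samples (index i), 6 features (j, k)

# U^{jk} = X^{ij} X^{ik}: the repeated index i is summed over.
u_einsum = torch.einsum("ij,ik->jk", x, x)
# The same contraction written as a matrix product, U = X^T X, shape (6, 6).
u_matmul = x.T @ x
```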
- construct_full_activity()
Construct the full activity tensor B from the input of all previous modules: B = (B_1, B_2, …, B_k) of shape (n, C_1 + C_2 + … + C_k), where C_i is the number of columns of B_i. If the i-th module uses a bias, B_i = (X_i, 1) of shape (n, C_i' + 1), where X_i is its input with C_i' features; otherwise B_i = X_i.
- Returns:
full activity tensor
- Return type:
torch.Tensor
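The block layout described above can be sketched as a column-wise concatenation, assuming each previous module's input X_i is a 2-D tensor of shape (n, C_i') and that a bias is represented by an appended column of ones. The helper `full_activity_sketch` and its `use_bias` argument are hypothetical, not part of gromo.

```python
from __future__ import annotations

import torch


def full_activity_sketch(inputs: list[torch.Tensor], use_bias: list[bool]) -> torch.Tensor:
    """Hypothetical helper (not gromo's API): B = (B_1, ..., B_k)."""
    blocks = []
    for x, bias in zip(inputs, use_bias):
        if bias:
            # B_i = (X_i, 1): append a column of ones, shape (n, C_i' + 1).
            ones = torch.ones(x.shape[0], 1, dtype=x.dtype, device=x.device)
            blocks.append(torch.cat([x, ones], dim=1))
        else:
            blocks.append(x)  # B_i = X_i, shape (n, C_i')
    return torch.cat(blocks, dim=1)  # shape (n, C_1 + ... + C_k)


torch.manual_seed(0)
x1 = torch.randn(10, 3)  # module with a bias: 3 + 1 columns
x2 = torch.randn(10, 4)  # module without a bias: 4 columns
b = full_activity_sketch([x1, x2], [True, False])
print(b.shape)  # torch.Size([10, 8])
```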
- set_next_modules(next_modules: list[AdditionGrowingModule | GrowingModule]) → None
Set the next modules of the current module.
- Parameters:
next_modules – list of next modules
- set_previous_modules(previous_modules: list[AdditionGrowingModule | GrowingModule]) → None
Set the previous modules of the current module.
- Parameters:
previous_modules – list of previous modules