deephyper.keras.layers.SparseMPNN
class deephyper.keras.layers.SparseMPNN(*args: Any, **kwargs: Any)
Bases: tensorflow.keras.layers.
Message passing cell.
Parameters
state_dim (int) – number of output channels.
T (int) – number of message-passing repetitions.
attn_heads (int) – number of attention heads.
attn_method (str) – type of attention method.
aggr_method (str) – type of aggregation method.
activation (str) – type of activation function.
update_method (str) – type of update function.
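The aggr_method parameter selects how the messages arriving at the same node are combined. The following NumPy sketch illustrates three common choices (sum, mean, max) as a scatter operation over target node IDs. The function name `aggregate` and the option strings "sum", "mean" and "max" are hypothetical; the values actually accepted by SparseMPNN are not listed on this page.

```python
import numpy as np


def aggregate(messages, targets, num_nodes, aggr_method="sum"):
    """Scatter per-edge messages onto their target nodes (hypothetical helper).

    messages: (num_edges, dim) array, one message per edge.
    targets:  (num_edges,) int array of target node IDs.
    aggr_method: hypothetical option names "sum", "mean" or "max".
    """
    dim = messages.shape[1]
    # Number of incoming edges per node, used by "mean" and to fix empty rows in "max".
    counts = np.bincount(targets, minlength=num_nodes)
    if aggr_method == "max":
        out = np.full((num_nodes, dim), -np.inf)
        np.maximum.at(out, targets, messages)
        out[counts == 0] = 0.0  # nodes without incoming edges receive zeros
        return out
    out = np.zeros((num_nodes, dim))
    np.add.at(out, targets, messages)  # unbuffered sum, handles repeated targets
    if aggr_method == "sum":
        return out
    if aggr_method == "mean":
        return out / np.maximum(counts, 1)[:, None]
    raise ValueError(f"unknown aggr_method: {aggr_method!r}")


msgs = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]])
targets = np.array([0, 0, 2])  # two edges end at node 0, one at node 2
print(aggregate(msgs, targets, num_nodes=3, aggr_method="mean"))
```

`np.add.at` and `np.maximum.at` are used instead of fancy-index assignment because they accumulate correctly when several edges share the same target node.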
Methods
build
call – Apply the layer on input tensors.
__call__(*args: Any, **kwargs: Any) → Any
Call self as a function.
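In Keras-style layers, calling the layer object runs `build` once, when the input shape is first known, and then delegates to `call`. The toy class below is a simplified, framework-free sketch of that pattern; `ToyLayer` and its scaling rule are hypothetical stand-ins, and the real `tf.keras` `__call__` performs considerably more bookkeeping.

```python
class ToyLayer:
    """Hypothetical minimal layer mimicking the build-on-first-call pattern."""

    def __init__(self):
        self.built = False
        self.build_calls = 0

    def build(self, input_shape):
        # Create "weights" once the input shape is known; here a single scale factor.
        self.build_calls += 1
        self.scale = float(input_shape[-1])
        self.built = True

    def call(self, inputs):
        # The actual computation: scale every entry.
        return [x * self.scale for x in inputs]

    def __call__(self, inputs):
        # Calling the object builds lazily on first use, then delegates to call().
        if not self.built:
            self.build((len(inputs),))
        return self.call(inputs)


layer = ToyLayer()
print(layer([1.0, 2.0]))  # first call triggers build()
print(layer([3.0, 1.0]))  # later calls reuse the already-built state
```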
call(inputs, **kwargs)
Apply the layer on input tensors.
Parameters
inputs (list) – list of input tensors:
  X (tensor): node feature tensor (batch size * # nodes * # node features).
  A (tensor): edge pair tensor (batch size * # edges * 2); for each edge, one entry is the source node ID and the other is the target node ID.
  E (tensor): edge feature tensor (batch size * # edges * # edge features).
  mask (tensor): node mask tensor that masks out non-existent nodes (batch size * # nodes).
  degree (tensor): node degree tensor for GCN attention (batch size * # edges).
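The inputs list bundles five tensors whose shapes are tied together by the batch size, the number of nodes and the number of edges. The NumPy sketch below builds dummy arrays with the documented shapes and checks that they agree. The helper names, the list order X, A, E, mask, degree (taken from the order in which they are documented), the integer dtype for A, and the convention that the last node of each graph is padding are all assumptions for illustration, not DeepHyper requirements.

```python
import numpy as np


def make_dummy_inputs(batch_size=2, num_nodes=4, num_edges=5,
                      node_dim=3, edge_dim=2, seed=0):
    """Build [X, A, E, mask, degree] with the documented shapes (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(batch_size, num_nodes, node_dim)).astype(np.float32)
    # Assumption: the last node is padding, so edges only use IDs 0..num_nodes-2.
    A = rng.integers(0, num_nodes - 1, size=(batch_size, num_edges, 2))
    E = rng.normal(size=(batch_size, num_edges, edge_dim)).astype(np.float32)
    mask = np.ones((batch_size, num_nodes), dtype=np.float32)
    mask[:, -1] = 0.0  # 0 marks the non-existent (padding) node
    # One degree value per edge, as documented; placeholder ones here.
    degree = np.ones((batch_size, num_edges), dtype=np.float32)
    return [X, A, E, mask, degree]


def check_input_shapes(inputs):
    """Raise ValueError if the five tensors disagree on batch, node or edge counts."""
    X, A, E, mask, degree = inputs
    b, n, _ = X.shape
    m = A.shape[1]
    expected = {
        "A": (A.shape, (b, m, 2)),
        "E": (E.shape[:2], (b, m)),
        "mask": (mask.shape, (b, n)),
        "degree": (degree.shape, (b, m)),
    }
    for name, (got, want) in expected.items():
        if tuple(got) != want:
            raise ValueError(f"{name} has shape {tuple(got)}, expected {want}")
    if A.min() < 0 or A.max() >= n:
        raise ValueError("A contains node IDs outside [0, # nodes)")


inputs = make_dummy_inputs()
check_input_shapes(inputs)  # passes silently for consistent shapes
print([t.shape for t in inputs])  # [(2, 4, 3), (2, 5, 2), (2, 5, 2), (2, 4), (2, 5)]
```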
Returns
X, the node feature tensor after T repetitions of the edge network, attention, aggregation, and update functions (batch size * # nodes * # node features).
Return type
tensor
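To make the returned value concrete, here is a heavily simplified NumPy sketch of T rounds of sparse message passing: gather source-node features through A, combine them with the edge features, weight each message by its per-edge degree value, sum-aggregate onto target nodes, and update the node states. It is not DeepHyper's implementation: it omits attention heads and the configurable activation and update methods, uses hypothetical weight matrices `W_msg` and `W_upd`, treats `degree` as a plain per-edge weight, and keeps the feature dimension fixed.

```python
import numpy as np


def sparse_message_passing(X, A, E, mask, degree, W_msg, W_upd, T=2):
    """Simplified sketch: T rounds of gather, message, scatter-sum, update.

    X: (B, N, F) node features, A: (B, M, 2) int (source, target) pairs,
    E: (B, M, D) edge features, mask: (B, N), degree: (B, M).
    W_msg: (F + D, F) and W_upd: (2 * F, F) are hypothetical weight matrices.
    """
    B, N, F = X.shape
    for _ in range(T):
        new_X = np.zeros_like(X)
        for b in range(B):
            src, tgt = A[b, :, 0], A[b, :, 1]
            # Edge-network stand-in: linear map of [source features, edge features].
            msg = np.concatenate([X[b, src], E[b]], axis=-1) @ W_msg
            # Assumption: degree acts as a per-edge normalization weight.
            msg = msg * degree[b][:, None]
            # Sum-aggregate messages onto their target nodes.
            agg = np.zeros((N, F))
            np.add.at(agg, tgt, msg)
            # Update: combine the old state with the aggregated messages.
            new_X[b] = np.tanh(np.concatenate([X[b], agg], axis=-1) @ W_upd)
        # Zero out padding nodes so they never carry state.
        X = new_X * mask[..., None]
    return X


rng = np.random.default_rng(0)
B, N, M, F, D = 2, 4, 5, 3, 2
X = rng.normal(size=(B, N, F))
A = rng.integers(0, N - 1, size=(B, M, 2))
E = rng.normal(size=(B, M, D))
mask = np.ones((B, N))
mask[:, -1] = 0.0
degree = np.ones((B, M))
W_msg = rng.normal(size=(F + D, F))
W_upd = rng.normal(size=(2 * F, F))
out = sparse_message_passing(X, A, E, mask, degree, W_msg, W_upd, T=3)
print(out.shape)  # (2, 4, 3)
```

Masking after every round keeps padding nodes at zero, and the output keeps the (batch size * # nodes * # node features) layout of X.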