deephyper.keras.layers.MessagePassing

class deephyper.keras.layers.MessagePassing(*args: Any, **kwargs: Any)

Bases: Layer

Message passing layer.

Parameters:
  • state_dim (int) – number of output channels.

  • attn_heads (int) – number of attention heads.

  • attn_method (str) – type of attention method.

  • aggr_method (str) – type of aggregation method.

  • activation (str) – type of activation function.

  • update_method (str) – type of update function.
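The parameter list names an aggregation method without enumerating the options. As a rough illustration of what a setting like `aggr_method` controls, the NumPy sketch below combines per-edge messages into per-node vectors using three common choices (`"sum"`, `"mean"`, `"max"`). The helper `aggregate` and these option strings are hypothetical, chosen for illustration only; they are not the deephyper API, and the values the layer actually accepts should be checked in its source.

```python
import numpy as np


def aggregate(messages, targets, num_nodes, aggr_method="sum"):
    """Hypothetical helper: combine per-edge messages into per-node vectors.

    messages: (num_edges, dim) float array, one message per edge
    targets:  (num_edges,) int array, target node ID of each edge
    """
    dim = messages.shape[1]
    counts = np.bincount(targets, minlength=num_nodes)  # incoming edges per node
    if aggr_method == "sum":
        out = np.zeros((num_nodes, dim))
        np.add.at(out, targets, messages)  # unbuffered scatter-add
        return out
    if aggr_method == "mean":
        out = np.zeros((num_nodes, dim))
        np.add.at(out, targets, messages)
        # Divide by the in-degree; clamp to 1 so isolated nodes stay at zero
        return out / np.maximum(counts, 1)[:, None]
    if aggr_method == "max":
        out = np.full((num_nodes, dim), -np.inf)
        np.maximum.at(out, targets, messages)  # unbuffered scatter-max
        out[counts == 0] = 0.0  # nodes with no incoming edges get zeros
        return out
    raise ValueError(f"unknown aggr_method: {aggr_method!r}")


messages = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]])
targets = np.array([0, 0, 2])  # two edges into node 0, none into node 1, one into node 2
print(aggregate(messages, targets, num_nodes=3, aggr_method="mean"))
# mean per node: [2, 3], [0, 0], [5, 0]
```

Sum preserves information about how many neighbors a node has, mean normalizes it away, and max keeps only the strongest signal per feature; which one works best is typically a hyperparameter choice.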

Methods

  • build

  • call – Apply the layer to input tensors.

__call__(*args: Any, **kwargs: Any) → Any

Call self as a function.

call(inputs, **kwargs)

Apply the layer to input tensors.

Parameters:

inputs (list) – list of input tensors:

  • X (tensor): node feature tensor (batch size × # nodes × state dimension).

  • A (tensor): edge pair tensor (batch size × # edges × 2); for each edge, one entry is the source node ID and the other is the target node ID.

  • E (tensor): edge feature tensor (batch size × # edges × # edge features).

  • mask (tensor): node mask tensor that masks out non-existent nodes (batch size × # nodes).

  • degree (tensor): node degree tensor used for GCN attention (batch size × # edges).

Returns:

updated_nodes (tensor) – node states after the edge network, attention, aggregation, and update functions have been applied (batch size × # nodes × state dimension).

Return type:

tensor
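To make the documented shapes concrete, the NumPy sketch below builds dummy `X`, `A`, `E`, and `mask` inputs with the shapes listed for `call` and runs one toy message-passing step that returns a (batch size × # nodes × state dimension) array. It is not deephyper's implementation. Several pieces are assumptions made for illustration: column 0 of `A` is taken as the source ID and column 1 as the target ID, and the `W_edge` gating (edge network), constant attention, sum aggregation, and residual `tanh` update are simplified stand-ins. The `degree` input is omitted because it is only described as being used for GCN attention.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, M, D, F = 2, 4, 5, 3, 2  # batch size, nodes, edges, state dim, edge features

X = rng.normal(size=(B, N, D))              # node features
A = rng.integers(0, N - 1, size=(B, M, 2))  # assumed [source ID, target ID]; skips padded node
E = rng.normal(size=(B, M, F))              # edge features
mask = np.ones((B, N))
mask[:, -1] = 0.0                           # treat the last node as non-existent padding
W_edge = rng.normal(size=(F, D))            # hypothetical edge-network weights


def toy_message_passing(X, A, E, mask, W_edge):
    """One simplified message-passing step (illustrative only)."""
    src, tgt = A[..., 0], A[..., 1]                     # (B, M) each
    # Gather the state of each edge's source node: (B, M, D)
    x_src = np.take_along_axis(X, src[..., None], axis=1)
    # Hypothetical edge network: gate messages by a projection of edge features
    messages = x_src * np.tanh(E @ W_edge)              # (B, M, D)
    # Constant attention (every edge weighted 1) with sum aggregation per target
    agg = np.zeros_like(X)
    for b in range(X.shape[0]):
        np.add.at(agg[b], tgt[b], messages[b])          # scatter-add into agg[b]
    # Hypothetical update: residual tanh, then zero out masked nodes
    updated = np.tanh(X + agg)
    return updated * mask[..., None]                    # (B, N, D)


updated_nodes = toy_message_passing(X, A, E, mask, W_edge)
print(updated_nodes.shape)  # (2, 4, 3)
```

The per-batch loop keeps the scatter step easy to follow; a vectorized version would flatten batch and node indices into a single index before the scatter-add.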