Neural Network¤

The main object of the connex library.

connex.NeuralNetwork ¤

A neural network whose structure is specified by a DAG.

__init__(self, graph_data: Any, input_neurons: Sequence[Any], output_neurons: Sequence[Any], hidden_activation: Callable = jax.nn.gelu, output_transformation: Callable = _identity, dropout_p: Union[float, Mapping[Any, float]] = 0.0, use_topo_norm: bool = False, use_topo_self_attention: bool = False, use_neuron_self_attention: bool = False, use_adaptive_activations: bool = False, topo_sort: Optional[Sequence[Any]] = None, *, key: Optional[PRNGKey] = None) ¤

Arguments:

  • graph_data: A networkx.DiGraph, or data that can be turned into a networkx.DiGraph by calling networkx.DiGraph(graph_data) (such as an adjacency dict) representing the DAG structure of the neural network. All nodes of the graph must have the same type.
  • input_neurons: A Sequence of nodes from the graph indicating the input neurons. The order here matters, as the input data will be passed into the input neurons in the order specified here.
  • output_neurons: A Sequence of nodes from the graph indicating the output neurons. The order here matters, as the output data will be read from the output neurons in the order specified here.
  • hidden_activation: The activation function applied element-wise to the hidden (i.e. non-input, non-output) neurons. It can itself be a trainable equinox.Module.
  • output_transformation: The transformation applied group-wise to the output neurons, e.g. jax.nn.softmax. It can itself be a trainable equinox.Module.
  • dropout_p: Dropout probability. If a single float, the same dropout probability will be applied to all hidden neurons. If a Mapping[Any, float], dropout_p[i] refers to the dropout probability of neuron i. All neurons default to zero unless otherwise specified. Note that this allows dropout to be applied to input and output neurons as well.
  • use_topo_norm: A bool indicating whether to apply a topological-batch version of Layer Norm,

    Cite

    Layer Normalization

    @article{ba2016layer,
    author={Jimmy Lei Ba and Jamie Ryan Kiros and Geoffrey E. Hinton},
        title={Layer Normalization},
        year={2016},
        journal={arXiv:1607.06450},
    }
    

    where the collective inputs of each topological batch are standardized (made to have mean 0 and variance 1), with learnable elementwise-affine parameters gamma and beta.

  • use_topo_self_attention: A bool indicating whether to apply (single-headed) self-attention to each topological batch's collective inputs.

    Cite

    Attention is All You Need

    @inproceedings{vaswani2017attention,
        author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and
                Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and
                Kaiser, {\L}ukasz and Polosukhin, Illia},
        booktitle={Advances in Neural Information Processing Systems},
        publisher={Curran Associates, Inc.},
        title={Attention is All You Need},
        volume={30},
        year={2017}
    }
    

  • use_neuron_self_attention: A bool indicating whether to apply neuron-wise self-attention, where each neuron applies (single-headed) self-attention to its inputs. If both use_neuron_self_attention and use_topo_norm are True, normalization is applied before self-attention.

    Warning

    Neuron-level self-attention will use significantly more memory than topo-level self-attention.

  • use_adaptive_activations: A bool indicating whether to use neuron-wise adaptive activations, where all hidden activations transform as σ(x) -> a * σ(b * x), where a, b are trainable scalar parameters unique to each neuron.

    Cite

    Locally adaptive activation functions with slope recovery term for deep and physics-informed neural networks

    @article{Jagtap_2020,
    author={Ameya D. Jagtap and Kenji Kawaguchi and George Em Karniadakis},
        title={Locally adaptive activation functions with slope recovery
               term for deep and physics-informed neural networks},
        year={2020},
        publisher={The Royal Society},
        journal={Proceedings of the Royal Society A: Mathematical, Physical
        and Engineering Sciences},
    }
    

  • topo_sort: An optional sequence of neurons indicating a topological sort of the graph. If None, a topological sort will be performed on the graph, which may be time-consuming for some networks.

  • key: The jax.random.PRNGKey used for parameter initialization and dropout. Optional, keyword-only argument. Defaults to jax.random.PRNGKey(0).
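
For orientation, here is a minimal construction sketch. It assumes connex exposes NeuralNetwork at the top level as documented above; the adjacency dict, neuron labels, and dropout values are purely illustrative.

    # Hypothetical example: a small DAG with two input neurons (0, 1),
    # two hidden neurons (2, 3), and one output neuron (4).
    import jax.random as jr
    import connex

    graph_data = {
        0: [2, 3],
        1: [2, 3],
        2: [4],
        3: [4],
        4: [],
    }

    net = connex.NeuralNetwork(
        graph_data=graph_data,
        input_neurons=[0, 1],
        output_neurons=[4],
        dropout_p={2: 0.1, 3: 0.1},  # per-neuron dropout; unspecified neurons default to 0
        key=jr.PRNGKey(0),
    )
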
__call__(self, x: Array, *, key: Optional[PRNGKey] = None) -> Array ¤

The forward pass of the network. Neurons are "fired" in topological batch order -- see Section 2.2 of

Cite

Directed Acyclic Graph Neural Networks

@inproceedings{thost2021directed,
    author={Veronika Thost and Jie Chen},
    booktitle={International Conference on Learning Representations},
    publisher={Curran Associates, Inc.},
    title={Directed Acyclic Graph Neural Networks},
    year={2021}
}

Arguments:

  • x: The input array to the network for the forward pass. The individual values will be written to the input neurons in the order passed in during initialization.
  • key: A jax.random.PRNGKey used for dropout. Optional, keyword-only argument. If None, a key will be generated using the current time.

Returns:

The result array from the forward pass. The order of the array elements will be the order of the output neurons passed in during initialization.
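
As a sketch of the forward pass, continuing the hypothetical net constructed above (the input values and key are illustrative):

    import jax.numpy as jnp
    import jax.random as jr

    x = jnp.array([0.5, -1.2])     # one value per input neuron, in initialization order
    y = net(x, key=jr.PRNGKey(1))  # pass a key explicitly so dropout is reproducible
    # y contains one element per output neuron, in initialization order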

to_networkx_weighted_digraph(self) -> DiGraph ¤

Returns a networkx.DiGraph representation of the network, with the network's weights saved as edge attributes.
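
A brief sketch of inspecting the exported graph, again using the hypothetical net from above; the exact edge attribute names are not specified here, so the loop simply prints whatever attributes each edge carries:

    digraph = net.to_networkx_weighted_digraph()
    for u, v, attrs in digraph.edges(data=True):
        print(u, v, attrs)  # each edge's saved weight appears among its attributes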