Neural Network

The main object of the connex library.

connex.NeuralNetwork
A neural network whose structure is specified by a DAG.
__init__(self, graph_data: Any, input_neurons: Sequence[Any], output_neurons: Sequence[Any], hidden_activation: Callable = jax.nn.gelu, output_transformation: Callable = _identity, dropout_p: Union[float, Mapping[Any, float]] = 0.0, use_topo_norm: bool = False, use_topo_self_attention: bool = False, use_neuron_self_attention: bool = False, use_adaptive_activations: bool = False, topo_sort: Optional[Sequence[Any]] = None, *, key: Optional[PRNGKey] = None)
Arguments:

- graph_data: A networkx.DiGraph, or data that can be turned into a networkx.DiGraph by calling networkx.DiGraph(graph_data) (such as an adjacency dict), representing the DAG structure of the neural network. All nodes of the graph must have the same type. (A construction sketch follows this list.)
- input_neurons: A Sequence of nodes from graph indicating the input neurons. The order here matters, as the input data will be passed into the input neurons in the order specified here.
- output_neurons: A Sequence of nodes from graph indicating the output neurons. The order here matters, as the output data will be read from the output neurons in the order specified here.
- hidden_activation: The activation function applied element-wise to the hidden (i.e. non-input, non-output) neurons. It can itself be a trainable equinox.Module.
- output_transformation: The transformation applied group-wise to the output neurons, e.g. jax.nn.softmax. It can itself be a trainable equinox.Module.
- dropout_p: Dropout probability. If a single float, the same dropout probability will be applied to all hidden neurons. If a Mapping[Any, float], dropout_p[i] refers to the dropout probability of neuron i. All neurons default to zero unless otherwise specified. Note that this allows dropout to be applied to input and output neurons as well.
- use_topo_norm: A bool indicating whether to apply a topological-batch version of Layer Norm, where the collective inputs of each topological batch are standardized (made to have mean 0 and variance 1), with learnable elementwise-affine parameters gamma and beta.

  Cite: Layer Normalization

  @article{ba2016layer,
      author={Jimmy Lei Ba and Jamie Ryan Kiros and Geoffrey E. Hinton},
      title={Layer Normalization},
      journal={arXiv:1607.06450},
      year={2016}
  }

- use_topo_self_attention: A bool indicating whether to apply (single-headed) self-attention to each topological batch's collective inputs.

  Cite: Attention is All You Need

  @inproceedings{vaswani2017attention,
      author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and Kaiser, {\L}ukasz and Polosukhin, Illia},
      booktitle={Advances in Neural Information Processing Systems},
      publisher={Curran Associates, Inc.},
      title={Attention is All You Need},
      volume={30},
      year={2017}
  }

- use_neuron_self_attention: A bool indicating whether to apply neuron-wise self-attention, where each neuron applies (single-headed) self-attention to its inputs. If both use_neuron_self_attention and use_topo_norm are True, normalization is applied before self-attention.

  Warning: Neuron-level self-attention will use significantly more memory than topo-level self-attention.

- use_adaptive_activations: A bool indicating whether to use neuron-wise adaptive activations, where all hidden activations transform as σ(x) -> a * σ(b * x), with a and b trainable scalar parameters unique to each neuron.

  Cite: Locally adaptive activation functions with slope recovery term for deep and physics-informed neural networks

  @article{Jagtap_2020,
      author={Ameya D. Jagtap and Kenji Kawaguchi and George Em Karniadakis},
      title={Locally adaptive activation functions with slope recovery term for deep and physics-informed neural networks},
      journal={Proceedings of the Royal Society A: Mathematical, Physical and Engineering Sciences},
      publisher={The Royal Society},
      year={2020}
  }

- topo_sort: An optional sequence of neurons indicating a topological sort of the graph. If None, a topological sort will be performed on the graph, which may be time-consuming for some networks.
- key: The jax.random.PRNGKey used for parameter initialization and dropout. Optional, keyword-only argument. Defaults to jax.random.PRNGKey(0).
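Below is a minimal construction sketch. The graph, the neuron names (0-3), and the specific argument values are illustrative assumptions, not part of the library; only the NeuralNetwork constructor and its documented parameters come from the signature above.

```python
import jax
import networkx as nx

import connex

# Illustrative DAG: inputs 0 and 1 feed hidden neuron 2, which feeds output 3;
# input 0 also has a skip connection to the output. Node names are arbitrary,
# but all nodes must share the same type.
graph_data = {0: [2, 3], 1: [2], 2: [3], 3: []}

# Optionally precompute the topological sort so it need not be recomputed at init time.
topo_sort = list(nx.topological_sort(nx.DiGraph(graph_data)))

net = connex.NeuralNetwork(
    graph_data,
    input_neurons=[0, 1],            # input values are assigned in this order
    output_neurons=[3],              # outputs are read in this order
    hidden_activation=jax.nn.relu,   # any element-wise callable (or equinox.Module)
    dropout_p={2: 0.1},              # per-neuron dropout; unlisted neurons default to 0
    topo_sort=topo_sort,
    key=jax.random.PRNGKey(0),
)
```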
__call__(self, x: Array, *, key: Optional[PRNGKey] = None) -> Array
The forward pass of the network. Neurons are "fired" in topological batch order -- see Section 2.2 of the paper cited below.

Cite: Directed Acyclic Graph Neural Networks

@inproceedings{thost2021directed,
    author={Veronika Thost and Jie Chen},
    booktitle={International Conference on Learning Representations},
    publisher={Curran Associates, Inc.},
    title={Directed Acyclic Graph Neural Networks},
    year={2021}
}
Arguments:

- x: The input array to the network for the forward pass. The individual values will be written to the input neurons in the order passed in during initialization.
- key: A jax.random.PRNGKey used for dropout. Optional, keyword-only argument. If None, a key will be generated using the current time.
Returns:
The result array from the forward pass. The order of the array elements will be the order of the output neurons passed in during initialization.
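A usage sketch for the forward pass, reusing the hypothetical net built in the construction example above. The jax.vmap batching pattern is the usual convention for Equinox-style modules (one example per call) and is an assumption here, not part of this method's documentation.

```python
import jax
import jax.numpy as jnp

# Single example: x has one value per input neuron, in the order given at init.
x = jnp.array([0.5, -1.2])
y = net(x, key=jax.random.PRNGKey(1))  # the key drives dropout; if omitted, one is generated from the current time

# Batched inputs: map the per-example call over the leading axis, one key per example.
xs = jnp.stack([x, -x])
keys = jax.random.split(jax.random.PRNGKey(2), xs.shape[0])
ys = jax.vmap(lambda xi, ki: net(xi, key=ki))(xs, keys)
```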
to_networkx_weighted_digraph(self) -> DiGraph
Returns a networkx.DiGraph representation of the network with neuron weights saved as edge attributes.
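A short sketch of inspecting the exported graph, reusing the hypothetical net from the earlier examples. The exact edge-attribute keys are whatever the library stores, so the example prints them generically rather than assuming a name.

```python
import networkx as nx

# Convert the network back into a networkx graph.
G = net.to_networkx_weighted_digraph()
assert isinstance(G, nx.DiGraph)

# Each edge (u, v) carries the corresponding connection weight in its attribute dict.
for u, v, attrs in G.edges(data=True):
    print(u, v, attrs)
```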