
Artificial Neuroplasticity

The brain has a remarkable ability, known as plasticity, to rewire itself under the right conditions. This includes, among other processes, the formation of new synapses (synaptogenesis), the removal of synapses (synaptic pruning), the formation of new neurons (neurogenesis), and the removal of neurons (programmed cell death).

Furthermore, specific neurons or clusters of neurons can be made more or less likely to fire, a process known as neuromodulation.

We provide the following functions to mirror these processes. Each returns a copy of the network and leaves the input network unmodified.

connex.add_connections(network: NeuralNetwork, connections: Union[Sequence[Tuple[Any, Any]], Mapping[Any, Sequence[Any]]], *, key: Optional[PRNGKey] = None) -> NeuralNetwork

Add connections to the network.

Arguments:

  • network: A NeuralNetwork object.
  • connections: The directed edges to add. Must be a sequence of 2-tuples, or an adjacency dict mapping an existing neuron (by its NetworkX id) to its new outgoing connections. Connections that already exist are ignored.
  • key: The jax.random.PRNGKey used for new weight initialization. Optional, keyword-only argument. Defaults to jax.random.PRNGKey(0).

Returns:

A NeuralNetwork object with the specified connections added and original parameters retained.
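The connections argument accepts two equivalent shapes. As a minimal sketch, the following plain-Python model (not Connex's implementation; the helpers normalize_connections and add_edges are hypothetical) shows how an adjacency dict can be flattened into 2-tuples and merged into a copy of an edge set, so that already-existing edges are ignored and the input is left untouched. It models graph structure only, not the weight initialization controlled by key.

```python
from collections.abc import Mapping


def normalize_connections(connections):
    """Hypothetical helper: turn either accepted form into a list of (src, dst) tuples."""
    if isinstance(connections, Mapping):
        # Adjacency dict: {src: [dst, ...]} -> [(src, dst), ...]
        return [(src, dst) for src, dsts in connections.items() for dst in dsts]
    return [tuple(edge) for edge in connections]


def add_edges(edges, connections):
    """Hypothetical model of add_connections on a plain edge set.

    Returns a new set, so the input is left unmodified, and adding an edge
    that already exists is a no-op.
    """
    new_edges = set(edges)
    new_edges.update(normalize_connections(connections))
    return new_edges


existing = {(0, 2), (1, 2)}
# Both calls describe the same two new edges: 0 -> 3 and 2 -> 3.
as_tuples = add_edges(existing, [(0, 3), (2, 3)])
as_dict = add_edges(existing, {0: [3], 2: [3]})
print(sorted(as_tuples))  # [(0, 2), (0, 3), (1, 2), (2, 3)]
print(as_tuples == as_dict)  # True
print(add_edges(existing, [(0, 2)]) == existing)  # True: duplicate edge ignored
print(sorted(existing))  # [(0, 2), (1, 2)]: original left untouched
```

With the documented API, the dict form would be passed directly, e.g. connex.add_connections(network, {0: [3], 2: [3]}, key=jax.random.PRNGKey(1)), assuming network is an existing NeuralNetwork that already contains neurons 0, 2, and 3.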


connex.remove_connections(network: NeuralNetwork, connections: Union[Sequence[Tuple[Any, Any]], Mapping[Any, Sequence[Any]]]) -> NeuralNetwork

Remove connections from the network.

Arguments:

  • network: A NeuralNetwork object.
  • connections: The directed edges to remove. Must be a sequence of 2-tuples, or an adjacency dict mapping an existing neuron (by its NetworkX id) to the outgoing connections to remove. Connections that do not exist are ignored.

Returns:

A NeuralNetwork object with the specified connections removed and original parameters retained.
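The same two input shapes apply to removal. The following hypothetical plain-Python model (not Connex's code; remove_edges is an invented name) treats an edge that is not in the graph as silently skipped, which is an assumption of this sketch, and returns a new edge set rather than mutating the original.

```python
from collections.abc import Mapping


def normalize_connections(connections):
    """Hypothetical helper: accept 2-tuples or an adjacency dict; return (src, dst) tuples."""
    if isinstance(connections, Mapping):
        return [(src, dst) for src, dsts in connections.items() for dst in dsts]
    return [tuple(edge) for edge in connections]


def remove_edges(edges, connections):
    """Hypothetical model of remove_connections on a plain edge set.

    set.difference returns a new set and silently skips edges that are absent.
    """
    return set(edges).difference(normalize_connections(connections))


existing = {(0, 2), (1, 2), (2, 3)}
pruned = remove_edges(existing, {1: [2], 2: [4]})  # (2, 4) is not in the graph
print(sorted(pruned))  # [(0, 2), (2, 3)]
print(len(existing))  # 3: original left untouched
```

Here {1: [2], 2: [4]} is equivalent to the tuple list [(1, 2), (2, 4)].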


connex.add_input_neurons(network: NeuralNetwork, new_input_neurons: Sequence[Any], *, key: Optional[PRNGKey] = None) -> NeuralNetwork

Add input neurons to the network. Note that this function only adds neurons themselves, not any connections associated with the new neurons, effectively adding them as isolated nodes in the graph. Use connex.add_connections after this function has been called to add the desired connections.

Arguments:

  • network: The NeuralNetwork to add neurons to.
  • new_input_neurons: A sequence of new input neurons (more specifically, their identifiers/names) to add to the network. These must be unique, i.e. cannot already exist in the network. These must also specifically be input neurons. To add hidden or output neurons, use connex.add_hidden_neurons or connex.add_output_neurons.
  • key: The jax.random.PRNGKey used for new parameter initialization. Optional, keyword-only argument. Defaults to jax.random.PRNGKey(0).

Returns:

A NeuralNetwork with the new input neurons added and parameters from the original network retained. The new input neurons are placed before the existing input neurons. For example, if the previous input neurons were [0, 1] and new input neurons [2, 3] were added, the resulting input neurons would be [2, 3, 0, 1].
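The ordering rule for input neurons is easy to get backwards, so here is a small hypothetical model of it in plain Python (model_add_inputs is an illustration that works on plain lists, not Connex's implementation). It also mirrors the two documented constraints: new identifiers must not already exist, and the new neurons are added without any connections.

```python
def model_add_inputs(nodes, input_neurons, new_input_neurons):
    """Hypothetical model: returns (nodes, input_neurons) as new lists."""
    existing = set(nodes)
    clashes = [n for n in new_input_neurons if n in existing]
    if clashes:
        raise ValueError(f"Neurons already exist: {clashes}")
    # The new neurons are isolated: they join the node list but gain no edges.
    new_nodes = list(nodes) + list(new_input_neurons)
    # Documented ordering: new input neurons are placed before the existing ones.
    new_inputs = list(new_input_neurons) + list(input_neurons)
    return new_nodes, new_inputs


nodes, inputs = model_add_inputs([0, 1, 4, 5], [0, 1], [2, 3])
print(inputs)  # [2, 3, 0, 1]
```

If your code assembles network inputs positionally, it is worth re-checking the input order after calling connex.add_input_neurons, since the new neurons come first.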


connex.add_hidden_neurons(network: NeuralNetwork, new_hidden_neurons: Sequence[Any], *, key: Optional[PRNGKey] = None) -> NeuralNetwork

Add hidden neurons to the network. Note that this function only adds neurons themselves, not any connections associated with the new neurons, effectively adding them as isolated nodes in the graph. Use connex.add_connections after this function has been called to add the desired connections.

Arguments:

  • network: The NeuralNetwork to add neurons to.
  • new_hidden_neurons: A sequence of new hidden neurons (more specifically, their identifiers/names) to add to the network. These must be unique, i.e. cannot already exist in the network. These must also specifically be hidden neurons. To add input or output neurons, use connex.add_input_neurons or connex.add_output_neurons.
  • key: The jax.random.PRNGKey used for new parameter initialization. Optional, keyword-only argument. Defaults to jax.random.PRNGKey(0).

Returns:

A NeuralNetwork with the new hidden neurons added and parameters from the original network retained.
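Because connex.add_hidden_neurons adds isolated nodes, wiring a new hidden neuron into the network is a two-step process: add the neuron, then add its connections. The following sketch models those two steps on a plain adjacency dict; add_isolated_nodes and add_edges are hypothetical helpers, not part of Connex.

```python
import copy


def add_isolated_nodes(adjacency, new_nodes):
    """Hypothetical model: add unconnected nodes to a copy of an adjacency dict."""
    clashes = [n for n in new_nodes if n in adjacency]
    if clashes:
        raise ValueError(f"Neurons already exist: {clashes}")
    new_adj = copy.deepcopy(adjacency)
    for node in new_nodes:
        new_adj[node] = set()  # no outgoing edges, and nothing points to it yet
    return new_adj


def add_edges(adjacency, connections):
    """Hypothetical model: add (src, dst) edges to a copy of an adjacency dict."""
    new_adj = copy.deepcopy(adjacency)
    for src, dst in connections:
        new_adj[src].add(dst)
    return new_adj


net = {"in": {"out"}, "out": set()}
# Step 1: the hidden neuron "h" is added as an isolated node.
with_h = add_isolated_nodes(net, ["h"])
print(any("h" in succ for succ in with_h.values()), with_h["h"])  # False set()
# Step 2: wire it into a new path in -> h -> out.
wired = add_edges(with_h, [("in", "h"), ("h", "out")])
print(sorted(wired["in"]), wired["h"])  # ['h', 'out'] {'out'}
```

In Connex terms, the same sequence would be connex.add_hidden_neurons(network, ["h"]) followed by connex.add_connections(network, [("in", "h"), ("h", "out")]) on the returned network, assuming neurons named "in" and "out" already exist.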


connex.add_output_neurons(network: NeuralNetwork, new_output_neurons: Sequence[Any], *, key: Optional[PRNGKey] = None) -> NeuralNetwork

Add output neurons to the network. Note that this function only adds neurons themselves, not any connections associated with the new neurons, effectively adding them as isolated nodes in the graph. Use connex.add_connections after this function has been called to add any desired connections.

Arguments:

  • network: The NeuralNetwork to add neurons to.
  • new_output_neurons: A sequence of new output neurons (more specifically, their identifiers/names) to add to the network. These must be unique, i.e. cannot already exist in the network. These must also specifically be output neurons. To add input or hidden neurons, use connex.add_input_neurons or connex.add_hidden_neurons.
  • key: The jax.random.PRNGKey used for new parameter initialization. Optional, keyword-only argument. Defaults to jax.random.PRNGKey(0).

Returns:

A NeuralNetwork with the new output neurons added and parameters from the original network retained.
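A newly added output neuron has no incoming connections until connex.add_connections is called, so it receives no input from other neurons. A small hypothetical check in plain Python (outputs_without_incoming is an invented helper operating on an edge set, not a Connex function) can flag output neurons that have not been wired up yet.

```python
def outputs_without_incoming(edges, output_neurons):
    """Hypothetical check: output neurons that no connection feeds into yet."""
    has_incoming = {dst for _, dst in edges}
    return sorted(n for n in output_neurons if n not in has_incoming)


edges = {(0, 2), (1, 2)}  # inputs 0 and 1 feed the original output 2
outputs = {2, 3}  # 3 has just been added as an isolated output neuron
print(outputs_without_incoming(edges, outputs))  # [3]

# The follow-up step that add_connections would perform: feed the new output.
edges = edges | {(0, 3), (1, 3)}
print(outputs_without_incoming(edges, outputs))  # []
```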


connex.remove_neurons(network: NeuralNetwork, neurons: Sequence[Any]) -> NeuralNetwork

Remove neurons and any of their incoming/outgoing connections from the network.

Arguments:

  • network: The NeuralNetwork to remove neurons from.
  • neurons: A sequence of neurons (more specifically, their identifiers/names) to remove from the network. These can be input, hidden, or output neurons.

Returns:

A NeuralNetwork with the specified neurons removed and parameters from the original network retained.
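Removing a neuron also removes every edge that touches it. The following plain-Python model (remove_nodes is a hypothetical helper, not Connex's implementation) makes that explicit: removing hidden neuron 2 drops its incoming edges (0, 2) and (1, 2) and its outgoing edge (2, 3), leaving only the direct edge (0, 3).

```python
def remove_nodes(nodes, edges, neurons):
    """Hypothetical model: drop neurons plus every edge that touches them."""
    doomed = set(neurons)
    kept_nodes = [n for n in nodes if n not in doomed]
    # An edge survives only if neither endpoint is being removed.
    kept_edges = {(u, v) for u, v in edges if u not in doomed and v not in doomed}
    return kept_nodes, kept_edges


nodes = [0, 1, 2, 3]
edges = {(0, 2), (1, 2), (2, 3), (0, 3)}
new_nodes, new_edges = remove_nodes(nodes, edges, [2])
print(new_nodes)  # [0, 1, 3]
print(sorted(new_edges))  # [(0, 3)]
```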


connex.set_dropout_p(network: NeuralNetwork, dropout_p: Union[float, Mapping[Any, float]]) -> NeuralNetwork

Set the per-neuron dropout probabilities.

Arguments:

  • network: The NeuralNetwork whose dropout probabilities will be modified.
  • dropout_p: Either a float or a mapping from neuron (Any) to float. If a single float, all hidden neurons will have that dropout probability, and all input and output neurons will have dropout probability 0 by default. If a Mapping, it is assumed that dropout_p maps a neuron to its dropout probability, and all unspecified neurons will retain their current dropout probability.

Returns:

A copy of the network with dropout probabilities as specified. The original network (including unspecified dropout probabilities) is left unchanged.
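The two forms of dropout_p behave differently for neurons that are not mentioned. This hypothetical plain-Python model (model_set_dropout_p is an invented name; it operates on a dict of per-neuron probabilities, not on a NeuralNetwork) reproduces the documented rules: a single float sets every hidden neuron to that value and every input and output neuron to 0, while a mapping updates only the listed neurons.

```python
def model_set_dropout_p(current, hidden_neurons, dropout_p):
    """Hypothetical model of the two documented modes; returns a new dict."""
    if isinstance(dropout_p, (int, float)):
        # Single float: hidden neurons get dropout_p; input/output neurons get 0.
        return {n: (float(dropout_p) if n in hidden_neurons else 0.0) for n in current}
    # Mapping: update only the listed neurons; everything else keeps its value.
    updated = dict(current)
    updated.update(dropout_p)
    return updated


current = {"in": 0.0, "h1": 0.1, "h2": 0.1, "out": 0.0}
hidden = {"h1", "h2"}
uniform = model_set_dropout_p(current, hidden, 0.5)
targeted = model_set_dropout_p(current, hidden, {"h2": 0.3})
print(uniform)  # {'in': 0.0, 'h1': 0.5, 'h2': 0.5, 'out': 0.0}
print(targeted)  # {'in': 0.0, 'h1': 0.1, 'h2': 0.3, 'out': 0.0}
```

With the documented API these correspond to connex.set_dropout_p(network, 0.5) and connex.set_dropout_p(network, {"h2": 0.3}), assuming network has a hidden neuron named "h2".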