
54 lines
1.5 KiB
Python

"""
This script demonstrates an implementation of the Gaussian Error Linear Unit (GELU)
function.
* https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions

GELU is a high-performing neural network activation function, defined as
x * Phi(x), where Phi is the cumulative distribution function of the standard
normal distribution. This script uses the approximation x * sigmoid(1.702 * x),
applied elementwise to an array of real numbers.

The approximation comes from the paper that introduced GELU:
* https://arxiv.org/abs/1606.08415
"""
import numpy as np


def sigmoid(vector: np.ndarray) -> np.ndarray:
    """
    The sigmoid function takes an array x of real numbers and returns
    1 / (1 + e^(-x)), applied elementwise.
    https://en.wikipedia.org/wiki/Sigmoid_function

    >>> sigmoid(np.array([-1.0, 1.0, 2.0]))
    array([0.26894142, 0.73105858, 0.88079708])
    """
    return 1 / (1 + np.exp(-vector))
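

# A hedged sketch of a numerically stable sigmoid (the helper name stable_sigmoid
# is hypothetical and not part of the original script). In sigmoid above,
# np.exp(-vector) overflows to inf for large negative inputs such as -1000 and
# emits a RuntimeWarning, even though the final result still rounds to 0.0.
# Evaluating the equivalent form e^x / (1 + e^x) when x < 0 keeps every exponent
# non-positive, so nothing overflows. Shown for a single float with only the
# standard library, as an assumption-light illustration rather than a drop-in
# replacement for the vectorized function.


def stable_sigmoid(x: float) -> float:
    """
    Numerically stable logistic sigmoid for one real number.

    >>> stable_sigmoid(0.0)
    0.5
    >>> stable_sigmoid(-1000.0)
    0.0
    """
    import math  # local import keeps this sketch self-contained

    if x >= 0:
        # Here e^(-x) <= 1, so the textbook form cannot overflow
        return 1.0 / (1.0 + math.exp(-x))
    # For x < 0, e^x < 1; this form is algebraically equal to 1 / (1 + e^(-x))
    exp_x = math.exp(x)
    return exp_x / (1.0 + exp_x)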


def gaussian_error_linear_unit(vector: np.ndarray) -> np.ndarray:
    """
    Implements the Gaussian Error Linear Unit (GELU) function using the
    approximation x * sigmoid(1.702 * x).

    Parameters:
        vector (np.ndarray): A numpy array of real values of any shape; GELU is
        applied elementwise.

    Returns:
        np.ndarray: An array of the same shape as the input, after applying GELU.

    Examples:
    >>> gaussian_error_linear_unit(np.array([-1.0, 1.0, 2.0]))
    array([-0.15420423,  0.84579577,  1.93565862])
    >>> gaussian_error_linear_unit(np.array([-3]))
    array([-0.01807131])
    """
    return vector * sigmoid(1.702 * vector)
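

# A hedged comparison sketch (the helpers gelu_exact and gelu_tanh_approximation
# are hypothetical and not part of the original script). GELU is defined as
# x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2))), and gaussian_error_linear_unit
# above uses the sigmoid approximation from the GELU paper, which also gives a
# tanh-based approximation. At x = 1 the sigmoid form returns about 0.8458 (see
# the doctest above) while the exact value is about 0.8413. These scalar
# versions use only the standard library so the forms can be checked directly.


def gelu_exact(x: float) -> float:
    """
    Exact GELU for one real number: 0.5 * x * (1 + erf(x / sqrt(2))).

    >>> round(gelu_exact(1.0), 6)
    0.841345
    >>> gelu_exact(0.0)
    0.0
    """
    import math  # local import keeps this sketch self-contained

    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh_approximation(x: float) -> float:
    """
    Tanh approximation of GELU for one real number:
    0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).

    >>> abs(gelu_tanh_approximation(1.0) - gelu_exact(1.0)) < 1e-3
    True
    """
    import math  # local import keeps this sketch self-contained

    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))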


if __name__ == "__main__":
    import doctest

    doctest.testmod()