Add Viterbi algorithm (#7509)

* Added Viterbi algorithm Fixes: #7465

Squashed commits

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added doctest for validators

* moved all extracted functions to the main function

* Forgot a type hint

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Carlos Villar 2022-10-29 15:44:18 +02:00 committed by GitHub
parent efb4a3aee8
commit 7b521b66cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -0,0 +1,400 @@
from typing import Any
def viterbi(
    observations_space: list,
    states_space: list,
    initial_probabilities: dict,
    transition_probabilities: dict,
    emission_probabilities: dict,
) -> list:
    """
    Viterbi algorithm: return the most likely sequence of hidden states
    for a given sequence of observations.
    https://en.wikipedia.org/wiki/Viterbi_algorithm

    :param observations_space: the observed sequence, in order (list of str)
    :param states_space: the possible hidden states (list of str)
    :param initial_probabilities: ``{state: probability}`` for the first step
    :param transition_probabilities: ``{from_state: {to_state: probability}}``
    :param emission_probabilities: ``{state: {observation: probability}}``
    :return: one state per observation, in observation order
    :raises ValueError: if any argument is empty or has the wrong type

    The dynamic-programming table is indexed by time step, so a sequence in
    which the same observation occurs more than once is handled correctly.

    Wikipedia example
    >>> observations = ["normal", "cold", "dizzy"]
    >>> states = ["Healthy", "Fever"]
    >>> start_p = {"Healthy": 0.6, "Fever": 0.4}
    >>> trans_p = {
    ...     "Healthy": {"Healthy": 0.7, "Fever": 0.3},
    ...     "Fever": {"Healthy": 0.4, "Fever": 0.6},
    ... }
    >>> emit_p = {
    ...     "Healthy": {"normal": 0.5, "cold": 0.4, "dizzy": 0.1},
    ...     "Fever": {"normal": 0.1, "cold": 0.3, "dizzy": 0.6},
    ... }
    >>> viterbi(observations, states, start_p, trans_p, emit_p)
    ['Healthy', 'Healthy', 'Fever']

    A single observation and a sequence with a repeated observation
    >>> viterbi(["dizzy"], states, start_p, trans_p, emit_p)
    ['Fever']
    >>> viterbi(["normal", "normal", "dizzy"], states, start_p, trans_p, emit_p)
    ['Healthy', 'Healthy', 'Fever']

    >>> viterbi((), states, start_p, trans_p, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: There's an empty parameter

    >>> viterbi(observations, (), start_p, trans_p, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: There's an empty parameter

    >>> viterbi(observations, states, {}, trans_p, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: There's an empty parameter

    >>> viterbi(observations, states, start_p, {}, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: There's an empty parameter

    >>> viterbi(observations, states, start_p, trans_p, {})
    Traceback (most recent call last):
        ...
    ValueError: There's an empty parameter

    >>> viterbi("invalid", states, start_p, trans_p, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: observations_space must be a list

    >>> viterbi(["valid", 123], states, start_p, trans_p, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: observations_space must be a list of strings

    >>> viterbi(observations, "invalid", start_p, trans_p, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: states_space must be a list

    >>> viterbi(observations, ["valid", 123], start_p, trans_p, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: states_space must be a list of strings

    >>> viterbi(observations, states, "invalid", trans_p, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: initial_probabilities must be a dict

    >>> viterbi(observations, states, {2:2}, trans_p, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: initial_probabilities all keys must be strings

    >>> viterbi(observations, states, {"a":2}, trans_p, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: initial_probabilities all values must be float

    >>> viterbi(observations, states, start_p, "invalid", emit_p)
    Traceback (most recent call last):
        ...
    ValueError: transition_probabilities must be a dict

    >>> viterbi(observations, states, start_p, {"a":2}, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: transition_probabilities all values must be dict

    >>> viterbi(observations, states, start_p, {2:{2:2}}, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: transition_probabilities all keys must be strings

    >>> viterbi(observations, states, start_p, {"a":{2:2}}, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: transition_probabilities all keys must be strings

    >>> viterbi(observations, states, start_p, {"a":{"b":2}}, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: transition_probabilities nested dictionary all values must be float

    >>> viterbi(observations, states, start_p, trans_p, "invalid")
    Traceback (most recent call last):
        ...
    ValueError: emission_probabilities must be a dict

    >>> viterbi(observations, states, start_p, trans_p, None)
    Traceback (most recent call last):
        ...
    ValueError: There's an empty parameter
    """
    _validation(
        observations_space,
        states_space,
        initial_probabilities,
        transition_probabilities,
        emission_probabilities,
    )

    # trellis[t][state] is the probability of the best path that ends in
    # `state` after the first t + 1 observations.  It is indexed by time step
    # (not by observation value) so repeated observations cannot overwrite
    # each other's entries.
    first_observation = observations_space[0]
    trellis: list = [
        {
            state: initial_probabilities[state]
            * emission_probabilities[state][first_observation]
            for state in states_space
        }
    ]
    # back_pointers[t - 1][state] is the state at step t - 1 on the best path
    # that ends in `state` at step t (there is no entry for step 0).
    back_pointers: list = []

    for observation in observations_space[1:]:
        previous_column = trellis[-1]
        column: dict = {}
        column_pointers: dict = {}
        for state in states_space:
            # argmax over the prior state; strict ">" keeps the first
            # state of states_space on ties
            arg_max = ""
            max_probability = -1.0
            for k_state in states_space:
                probability = (
                    previous_column[k_state]
                    * transition_probabilities[k_state][state]
                    * emission_probabilities[state][observation]
                )
                if probability > max_probability:
                    max_probability = probability
                    arg_max = k_state
            column[state] = max_probability
            column_pointers[state] = arg_max
        trellis.append(column)
        back_pointers.append(column_pointers)

    # Most probable state at the final step (first one wins on ties)
    final_column = trellis[-1]
    last_state = max(states_space, key=lambda state: final_column[state])

    # Walk the back-pointers from the last step to the first
    result = [last_state]
    for column_pointers in reversed(back_pointers):
        result.append(column_pointers[result[-1]])
    result.reverse()
    return result
def _validation(
    observations_space: Any,
    states_space: Any,
    initial_probabilities: Any,
    transition_probabilities: Any,
    emission_probabilities: Any,
) -> None:
    """
    Run every argument check used by ``viterbi``: first emptiness, then the
    two lists, then the three probability dictionaries.  Returns ``None``;
    a failed check surfaces as the ``ValueError`` raised by the helper.

    >>> observations = ["normal", "cold", "dizzy"]
    >>> states = ["Healthy", "Fever"]
    >>> start_p = {"Healthy": 0.6, "Fever": 0.4}
    >>> trans_p = {
    ...     "Healthy": {"Healthy": 0.7, "Fever": 0.3},
    ...     "Fever": {"Healthy": 0.4, "Fever": 0.6},
    ... }
    >>> emit_p = {
    ...     "Healthy": {"normal": 0.5, "cold": 0.4, "dizzy": 0.1},
    ...     "Fever": {"normal": 0.1, "cold": 0.3, "dizzy": 0.6},
    ... }
    >>> _validation(observations, states, start_p, trans_p, emit_p)

    >>> _validation([], states, start_p, trans_p, emit_p)
    Traceback (most recent call last):
        ...
    ValueError: There's an empty parameter
    """
    probability_tables = (
        initial_probabilities,
        transition_probabilities,
        emission_probabilities,
    )
    _validate_not_empty(observations_space, states_space, *probability_tables)
    _validate_lists(observations_space, states_space)
    _validate_dicts(*probability_tables)
def _validate_not_empty(
observations_space: Any,
states_space: Any,
initial_probabilities: Any,
transition_probabilities: Any,
emission_probabilities: Any,
) -> None:
"""
>>> _validate_not_empty(["a"], ["b"], {"c":0.5},
... {"d": {"e": 0.6}}, {"f": {"g": 0.7}})
>>> _validate_not_empty(["a"], ["b"], {"c":0.5}, {}, {"f": {"g": 0.7}})
Traceback (most recent call last):
...
ValueError: There's an empty parameter
>>> _validate_not_empty(["a"], ["b"], None, {"d": {"e": 0.6}}, {"f": {"g": 0.7}})
Traceback (most recent call last):
...
ValueError: There's an empty parameter
"""
if not all(
[
observations_space,
states_space,
initial_probabilities,
transition_probabilities,
emission_probabilities,
]
):
raise ValueError("There's an empty parameter")
def _validate_lists(observations_space: Any, states_space: Any) -> None:
    """
    Check ``observations_space`` and then ``states_space`` with
    ``_validate_list``, naming the offending argument in the error.

    >>> _validate_lists(["a"], ["b"])

    >>> _validate_lists(1234, ["b"])
    Traceback (most recent call last):
        ...
    ValueError: observations_space must be a list

    >>> _validate_lists(["a"], [3])
    Traceback (most recent call last):
        ...
    ValueError: states_space must be a list of strings
    """
    named_arguments = (
        (observations_space, "observations_space"),
        (states_space, "states_space"),
    )
    for candidate, label in named_arguments:
        _validate_list(candidate, label)
def _validate_list(_object: Any, var_name: str) -> None:
"""
>>> _validate_list(["a"], "mock_name")
>>> _validate_list("a", "mock_name")
Traceback (most recent call last):
...
ValueError: mock_name must be a list
>>> _validate_list([0.5], "mock_name")
Traceback (most recent call last):
...
ValueError: mock_name must be a list of strings
"""
if not isinstance(_object, list):
raise ValueError(f"{var_name} must be a list")
else:
for x in _object:
if not isinstance(x, str):
raise ValueError(f"{var_name} must be a list of strings")
def _validate_dicts(
    initial_probabilities: Any,
    transition_probabilities: Any,
    emission_probabilities: Any,
) -> None:
    """
    Check the three probability tables in order: ``initial_probabilities``
    as a flat dict of floats, then ``transition_probabilities`` and
    ``emission_probabilities`` as nested dicts.

    >>> _validate_dicts({"c":0.5}, {"d": {"e": 0.6}}, {"f": {"g": 0.7}})

    >>> _validate_dicts("invalid", {"d": {"e": 0.6}}, {"f": {"g": 0.7}})
    Traceback (most recent call last):
        ...
    ValueError: initial_probabilities must be a dict

    >>> _validate_dicts({"c":0.5}, {2: {"e": 0.6}}, {"f": {"g": 0.7}})
    Traceback (most recent call last):
        ...
    ValueError: transition_probabilities all keys must be strings

    >>> _validate_dicts({"c":0.5}, {"d": {"e": 0.6}}, {"f": {2: 0.7}})
    Traceback (most recent call last):
        ...
    ValueError: emission_probabilities all keys must be strings

    >>> _validate_dicts({"c":0.5}, {"d": {"e": 0.6}}, {"f": {"g": "h"}})
    Traceback (most recent call last):
        ...
    ValueError: emission_probabilities nested dictionary all values must be float
    """
    _validate_dict(initial_probabilities, "initial_probabilities", float)
    nested_tables = (
        (transition_probabilities, "transition_probabilities"),
        (emission_probabilities, "emission_probabilities"),
    )
    for table, label in nested_tables:
        _validate_nested_dict(table, label)
def _validate_nested_dict(_object: Any, var_name: str) -> None:
    """
    Check a two-level table: the outer object must be a dict with string
    keys and dict values, and each inner dict must have string keys and
    float values.

    >>> _validate_nested_dict({"a":{"b": 0.5}}, "mock_name")

    >>> _validate_nested_dict("invalid", "mock_name")
    Traceback (most recent call last):
        ...
    ValueError: mock_name must be a dict

    >>> _validate_nested_dict({"a": 8}, "mock_name")
    Traceback (most recent call last):
        ...
    ValueError: mock_name all values must be dict

    >>> _validate_nested_dict({"a":{2: 0.5}}, "mock_name")
    Traceback (most recent call last):
        ...
    ValueError: mock_name all keys must be strings

    >>> _validate_nested_dict({"a":{"b": 4}}, "mock_name")
    Traceback (most recent call last):
        ...
    ValueError: mock_name nested dictionary all values must be float
    """
    # Outer level first, so every value is known to be a dict below
    _validate_dict(_object, var_name, dict)
    for inner_table in _object.values():
        _validate_dict(inner_table, var_name, float, nested=True)
def _validate_dict(
_object: Any, var_name: str, value_type: type, nested: bool = False
) -> None:
"""
>>> _validate_dict({"b": 0.5}, "mock_name", float)
>>> _validate_dict("invalid", "mock_name", float)
Traceback (most recent call last):
...
ValueError: mock_name must be a dict
>>> _validate_dict({"a": 8}, "mock_name", dict)
Traceback (most recent call last):
...
ValueError: mock_name all values must be dict
>>> _validate_dict({2: 0.5}, "mock_name",float, True)
Traceback (most recent call last):
...
ValueError: mock_name all keys must be strings
>>> _validate_dict({"b": 4}, "mock_name", float,True)
Traceback (most recent call last):
...
ValueError: mock_name nested dictionary all values must be float
"""
if not isinstance(_object, dict):
raise ValueError(f"{var_name} must be a dict")
if not all(isinstance(x, str) for x in _object):
raise ValueError(f"{var_name} all keys must be strings")
if not all(isinstance(x, value_type) for x in _object.values()):
nested_text = "nested dictionary " if nested else ""
raise ValueError(
f"{var_name} {nested_text}all values must be {value_type.__name__}"
)
# When executed as a script, run every doctest defined in this module.
if __name__ == "__main__":
    import doctest

    doctest.testmod()