mirror of
https://github.com/TheAlgorithms/Python.git
synced 2024-11-23 21:11:08 +00:00
e3f3d668be
* [pre-commit.ci] pre-commit autoupdate updates: - [github.com/astral-sh/ruff-pre-commit: v0.7.2 → v0.7.3](https://github.com/astral-sh/ruff-pre-commit/compare/v0.7.2...v0.7.3) - [github.com/abravalheri/validate-pyproject: v0.22 → v0.23](https://github.com/abravalheri/validate-pyproject/compare/v0.22...v0.23) * Update sudoku_solver.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Christian Clauss <cclauss@me.com>
221 lines
7.4 KiB
Python
221 lines
7.4 KiB
Python
"""
|
|
Please do not modify this file! It is published at https://norvig.com/sudoku.html with
|
|
only minimal changes to work with modern versions of Python. If you have improvements,
|
|
please make them in a separate file.
|
|
"""
|
|
|
|
import random
|
|
import time
|
|
|
|
|
|
def cross(items_a, items_b):
    """Return the list of a + b for every a in items_a and every b in items_b.

    The result is ordered with items_a varying slowest.
    """
    combined = []
    for head in items_a:
        combined.extend(head + tail for tail in items_b)
    return combined
|
|
|
|
|
|
# Board vocabulary: a square is named by a row letter followed by a column digit.
digits = "123456789"
rows = "ABCDEFGHI"
cols = digits
squares = cross(rows, cols)

# The 27 units, in order: 9 columns, then 9 rows, then 9 three-by-three boxes.
unitlist = [cross(rows, c) for c in cols]
unitlist += [cross(r, cols) for r in rows]
unitlist += [
    cross(rs, cs) for rs in ("ABC", "DEF", "GHI") for cs in ("123", "456", "789")
]

# units[s]: the units (column, row, box) that contain square s.
units = {s: [u for u in unitlist if s in u] for s in squares}

# peers[s]: every square sharing a unit with s, excluding s itself.
peers = {s: {s2 for u in units[s] for s2 in u if s2 != s} for s in squares}
|
|
|
|
|
|
def test():
    """Run sanity checks on the precomputed square, unit and peer tables.

    Raises AssertionError on the first failed check; prints a confirmation
    message when every check passes.
    """
    assert len(squares) == 81
    assert len(unitlist) == 27
    # Every square sits in exactly one column, one row and one box ...
    for sq in squares:
        assert len(units[sq]) == 3
    # ... and therefore has 8 + 8 + 4 = 20 distinct peers.
    for sq in squares:
        assert len(peers[sq]) == 20
    assert units["C2"] == [
        "A2 B2 C2 D2 E2 F2 G2 H2 I2".split(),
        "C1 C2 C3 C4 C5 C6 C7 C8 C9".split(),
        "A1 A2 A3 B1 B2 B3 C1 C2 C3".split(),
    ]
    assert peers["C2"] == set(
        "A2 B2 D2 E2 F2 G2 H2 I2 C1 C3 C4 C5 C6 C7 C8 C9 A1 A3 B1 B3".split()
    )
    print("All tests pass.")
|
|
|
|
|
|
def parse_grid(grid):
    """Build the candidate map {square: digits} for a textual grid.

    Every square starts with all nine digits as candidates; each given digit
    in the grid is then assigned (with constraint propagation). Returns False
    as soon as a given digit leads to a contradiction.
    """
    values = dict.fromkeys(squares, digits)
    for square, char in grid_values(grid).items():
        if char not in digits:
            continue  # Empty cell marker ('0' or '.'): nothing to assign.
        if not assign(values, square, char):
            return False  # The given digit cannot be placed on this square.
    return values
|
|
|
|
|
|
def grid_values(grid):
    """Map each square to its character in grid, with '0' or '.' for empties.

    Characters that are neither digits nor empty markers are ignored, so the
    grid text may contain arbitrary separators. Exactly 81 relevant characters
    must remain.
    """

    def is_cell(ch):
        return ch in digits or ch in "0."

    cells = list(filter(is_cell, grid))
    assert len(cells) == 81
    return dict(zip(squares, cells))
|
|
|
|
|
|
def assign(values, s, d):
    """Fix square s to digit d by eliminating every other candidate of s.

    Each elimination propagates to other squares. Returns the (mutated) values
    dict, or False if a contradiction is detected along the way.
    """
    # Snapshot the candidates to remove: values[s] shrinks while we iterate.
    for unwanted in values[s].replace(d, ""):
        if not eliminate(values, s, unwanted):
            return False
    return values
|
|
|
|
|
|
def eliminate(values, s, d):
    """Remove digit d from the candidates of square s and propagate.

    Returns the (mutated) values dict, or False if a contradiction is detected.
    """
    remaining = values[s]
    if d not in remaining:
        return values  # Already eliminated

    remaining = remaining.replace(d, "")
    values[s] = remaining

    # (1) If s is down to a single candidate, no peer of s may hold that digit.
    if not remaining:
        return False  # Contradiction: removed last value
    if len(remaining) == 1:
        for peer in peers[s]:
            if not eliminate(values, peer, remaining):
                return False

    # (2) If a unit of s has only one place left for d, then put d there.
    for unit in units[s]:
        places = [sq for sq in unit if d in values[sq]]
        if not places:
            return False  # Contradiction: no place for this value
        if len(places) == 1 and not assign(values, places[0], d):
            return False
    return values
|
|
|
|
|
|
def display(values):
    """Print the candidate map as a 2-D grid with box separators."""
    # One column wider than the longest candidate string so cells don't touch.
    width = max(len(values[s]) for s in squares) + 1
    separator = "+".join(["-" * (3 * width)] * 3)
    for r in rows:
        cells = []
        for c in cols:
            cell = values[r + c].center(width)
            if c in "36":
                cell += "|"  # Vertical border after the 3rd and 6th column.
            cells.append(cell)
        print("".join(cells))
        if r in "CF":
            print(separator)  # Horizontal border after the 3rd and 6th row.
    print()
|
|
|
|
|
|
def solve(grid):
    """Parse the textual grid and search for a solution.

    Returns the solved {square: digit} dict, or False if none is found.
    """
    candidates = parse_grid(grid)
    return search(candidates)
|
|
|
|
|
|
def some(seq):
    """Return the first truthy element of seq, or False if there is none."""
    return next((item for item in seq if item), False)
|
|
|
|
|
|
def search(values):
    """Depth-first search with propagation over the remaining candidates.

    Returns a solved {square: digit} dict, or False when values is False or no
    assignment of the open squares leads to a solution.
    """
    if values is False:
        return False  # Failed earlier
    if all(len(values[sq]) == 1 for sq in squares):
        return values  # Solved!
    # Choose the unfilled square with the fewest possibilities
    # (ties are broken by square name through tuple comparison).
    _, target = min((len(values[sq]), sq) for sq in squares if len(values[sq]) > 1)
    for digit in values[target]:
        # Each branch works on its own copy so a failed attempt leaves no trace.
        outcome = search(assign(values.copy(), target, digit))
        if outcome:
            return outcome
    return False
|
|
|
|
|
|
def solve_all(grids, name="", showif=0.0):
    """Attempt to solve a sequence of grids. Report results.

    When showif is a number of seconds, display puzzles that take longer.
    When showif is None, don't display any puzzles.

    grids may be any iterable of grid strings; it is materialised once.
    An empty iterable is a no-op. A summary line is printed only when more
    than one grid was given. Returns None.
    """
    grids = list(grids)
    if not grids:
        # Nothing to time; unpacking zip(*[]) below would raise ValueError.
        return

    def time_solve(grid):
        start = time.monotonic()
        values = solve(grid)
        t = time.monotonic() - start
        ## Display puzzles that take long enough
        if showif is not None and t > showif:
            display(grid_values(grid))
            if values:
                display(values)
            print(f"({t:.5f} seconds)\n")
        return (t, solved(values))

    times, results = zip(*[time_solve(grid) for grid in grids])
    if (n := len(grids)) > 1:
        total = sum(times)
        # A coarse monotonic clock can report 0 elapsed time for fast solves;
        # avoid dividing by zero when computing the puzzles-per-second rate.
        hz = "%d" % (n / total) if total > 0 else "inf"
        print(
            "Solved %d of %d %s puzzles (avg %.2f secs (%s Hz), max %.2f secs)."
            % (sum(results), n, name, total / n, hz, max(times))
        )
|
|
|
|
|
|
def solved(values):
    """Report whether values is a complete, valid solution.

    A puzzle is solved if each unit is a permutation of the digits 1 to 9.
    A failed search result (False) is never solved.
    """
    if values is False:
        return False
    wanted = set(digits)
    return all({values[s] for s in unit} == wanted for unit in unitlist)
|
|
|
|
|
|
def from_file(filename, sep="\n"):
    """Parse a file into a list of strings, separated by sep.

    Leading and trailing whitespace of the whole file is stripped before
    splitting. The file is closed before returning, even if reading fails.
    """
    with open(filename) as handle:
        return handle.read().strip().split(sep)
|
|
|
|
|
|
def random_puzzle(assignments=17):
    """Make a random puzzle with at least `assignments` filled squares.

    Squares are visited in random order and given a random remaining candidate;
    on a contradiction the attempt is abandoned and a fresh one is started.
    Note the resulting puzzle is not guaranteed to be solvable, but empirically
    about 99.8% of them are solvable. Some have multiple solutions.
    """
    candidates = {sq: digits for sq in squares}
    for sq in shuffled(squares):
        if not assign(candidates, sq, random.choice(candidates[sq])):
            break
        fixed = [candidates[cell] for cell in squares if len(candidates[cell]) == 1]
        # Need enough filled squares and at least 8 distinct digits among them.
        if len(fixed) < assignments or len(set(fixed)) < 8:
            continue
        return "".join(
            candidates[cell] if len(candidates[cell]) == 1 else "." for cell in squares
        )
    return random_puzzle(assignments)  # Give up and make a new puzzle
|
|
|
|
|
|
def shuffled(seq):
    """Return a new list holding the items of seq in random order.

    The input sequence itself is left untouched.
    """
    items = list(seq)
    random.shuffle(items)
    return items
|
|
|
|
|
|
# Sample puzzles, written one board row per line; '0' and '.' mark empty squares.
grid1 = (
    "003020600"
    "900305001"
    "001806400"
    "008102900"
    "700000008"
    "006708200"
    "002609500"
    "800203009"
    "005010300"
)
grid2 = (
    "4.....8.5"
    ".3......."
    "...7....."
    ".2.....6."
    "....8.4.."
    "....1...."
    "...6.3.7."
    "5..2....."
    "1.4......"
)
hard1 = (
    ".....6..."
    ".59.....8"
    "2....8..."
    ".45......"
    "..3......"
    "..6..3.54"
    "...325..6"
    "........."
    "........."
)
|
|
|
|
if __name__ == "__main__":
    # Verify the precomputed tables before solving anything.
    test()
    # Optional benchmark runs against puzzle files:
    # solve_all(from_file("easy50.txt", '========'), "easy", None)
    # solve_all(from_file("top95.txt"), "hard", None)
    # solve_all(from_file("hardest.txt"), "hardest", None)
    solve_all([random_puzzle() for _ in range(99)], "random", 100.0)
    for puzzle in (grid1, grid2):  # , hard1):  # Takes 22 sec to solve on my M1 Mac.
        # Show the grid after constraint propagation alone, then time a full solve.
        display(parse_grid(puzzle))
        began = time.monotonic()
        solve(puzzle)
        elapsed = time.monotonic() - began
        print(f"Solved: {elapsed:.5f} sec")
|