Python/project_euler/problem_551/sol1.py
pre-commit-ci[bot] bc8df6de31
[pre-commit.ci] pre-commit autoupdate ()
* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/astral-sh/ruff-pre-commit: v0.2.2 → v0.3.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.2.2...v0.3.2)
- [github.com/pre-commit/mirrors-mypy: v1.8.0 → v1.9.0](https://github.com/pre-commit/mirrors-mypy/compare/v1.8.0...v1.9.0)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-03-13 07:52:41 +01:00

201 lines
5.0 KiB
Python

"""
Sum of digits sequence
Problem 551
Let a(0), a(1),... be an integer sequence defined by:
a(0) = 1
for n >= 1, a(n) is the sum of the digits of all preceding terms
The sequence starts with 1, 1, 2, 4, 8, ...
You are given a(10^6) = 31054319.
Find a(10^15)
"""
# Split exponents k used by next_term, where a term is viewed as
# a(i) = b * 10^k + c.  ks[0] is the level at which next_term stops recursing
# and falls back to compute().
ks = range(2, 20 + 1)
# base[j] == 10**j for every j from 0 up to the largest k.
base = [10**k for k in range(ks[-1] + 1)]
# Jump cache filled by next_term:
#   memo[digitsum(b)][c] -> list of (diff, dn, k) tuples, kept sorted by dn.
memo: dict[int, dict[int, list[tuple[int, int, int]]]] = {}
def next_term(a_i, k, i, n):
    """
    Advance a_i in place, using cached jumps where possible, until either the
    n-th term is reached or the low part c overflows (c >= 10^k) when the
    terms are written in the form:
        a(i) = b * 10^k + c

    For any a(i), if digitsum(b) and c have the same value, the differences
    between subsequent terms will be the same until c >= 10^k. These
    differences are cached in the module-level memo to greatly speed up the
    computation.

    Arguments:
    a_i -- list of decimal digits, starting from the one's place, that
        represents the i-th term in the sequence; modified in place
    k -- k when terms are written in the form a(i) = b * 10^k + c.
        Terms are calculated until c >= 10^k or the n-th term is reached.
    i -- position of a_i along the sequence
    n -- position of the term to calculate up to if k is large enough

    Return: a tuple of the difference between the ending term and the
    starting term, and the number of terms calculated. E.g. if the starting
    term is a(1) = 1 and the ending term is a(10) = 62, then (61, 9) is
    returned.
    """
    # Number of digits of a_i that belong to c (a_i may be shorter than k).
    low_len = min(len(a_i), k)
    # ds_b - digitsum(b)
    ds_b = sum(a_i[k:])
    c = sum(digit * base[j] for j, digit in enumerate(a_i[:low_len]))

    diff, dn = 0, 0
    max_dn = n - i

    sub_memo = memo.setdefault(ds_b, {})
    jumps = sub_memo.get(c)
    if jumps:
        # Jumps are sorted by terms skipped, so scanning from the end finds
        # the largest usable jump: one recorded for a split no wider than k
        # that does not go past the n-th term.
        for cached_diff, cached_dn, cached_k in reversed(jumps):
            if cached_k <= k and cached_dn <= max_dn:
                diff, dn = cached_diff, cached_dn
                # since the difference between jumps is cached, add c
                carry = diff + c
                for j in range(low_len):
                    carry, a_i[j] = divmod(carry, 10)
                if carry > 0:
                    add(a_i, k, carry)
                break
    else:
        sub_memo[c] = []

    if dn >= max_dn or c + diff >= base[k]:
        return diff, dn

    if k > ks[0]:
        while True:
            # keep doing smaller jumps
            step_diff, step_dn = next_term(a_i, k - 1, i + dn, n)
            diff += step_diff
            dn += step_dn
            if dn >= max_dn or c + diff >= base[k]:
                break
    else:
        # would be too small a jump, just compute sequential terms instead
        step_diff, step_dn = compute(a_i, k, i + dn, n)
        diff += step_diff
        dn += step_dn

    # Re-read the list: a recursive call that hit the same (digitsum(b), c)
    # key may have replaced it in the meantime.
    jumps = sub_memo[c]

    # keep jumps sorted by # of terms skipped: insert before the first entry
    # that skips more terms than this jump
    position = next(
        (idx for idx, jump in enumerate(jumps) if jump[1] > dn), len(jumps)
    )

    # cache the jump for this value digitsum(b) and c
    jumps.insert(position, (diff, dn, k))
    return diff, dn
def compute(a_i, k, i, n):
    """
    Same as next_term(a_i, k, i, n) but computes the terms one at a time
    without memoizing results.

    Arguments:
    a_i -- list of decimal digits, starting from the one's place, of the
        i-th term; modified in place (and zero-padded to at least k digits)
    k -- k when terms are written in the form a(i) = b * 10^k + c
    i -- position of a_i along the sequence
    n -- position of the term to calculate up to

    Return: a tuple of the difference between the ending term and the
    starting term, and the number of terms calculated. Stops as soon as a
    carry leaves the low k digits or the n-th term is reached. If i >= n
    nothing is computed, a_i is left untouched and (0, 0) is returned.
    """
    if i >= n:
        # Zero terms were computed, so the term count is 0 (not i).
        return 0, 0
    if k > len(a_i):
        a_i.extend([0] * (k - len(a_i)))

    # note: a_i -> b * 10^k + c
    # ds_b -> digitsum(b)
    # ds_c -> digitsum(c)
    start_i = i
    ds_b = sum(a_i[k:])
    ds_c = sum(a_i[:k])
    diff = 0
    carry = 0

    while i < n:
        i += 1
        # The next term is the current term plus its own digit sum.
        carry = ds_b + ds_c
        diff += carry
        ds_c = 0
        # Add into the low k digits, recomputing digitsum(c) on the way.
        for j in range(k):
            carry, a_i[j] = divmod(a_i[j] + carry, 10)
            ds_c += a_i[j]

        # A carry out of c changes b, so ds_b is stale: stop here.
        if carry > 0:
            break

    if carry > 0:
        add(a_i, k, carry)
    return diff, i - start_i
def add(digits, k, addend):
    """
    Add addend * 10^k to the number stored in digits, in place.

    Arguments:
    digits -- list of decimal digits starting from the one's place
    k -- index of the digit at which the addend's one's place is added
    addend -- integer to add; expected to be non-negative and may have
        more than one digit

    Note: if k is greater than len(digits) the list is not zero-padded; the
    addend's digits are simply appended at the end of the list. next_term
    relies on this when the term is shorter than k digits.
    """
    carry = addend
    for j in range(k, len(digits)):
        if carry == 0:
            return
        # Fold the whole remaining carry into this digit; the quotient is
        # exactly what still has to be added one position higher.
        carry, digits[j] = divmod(digits[j] + carry, 10)

    # Whatever is left extends the number with new high-order digits.
    while carry > 0:
        carry, digit = divmod(carry, 10)
        digits.append(digit)
def solution(n: int = 10**15) -> int:
    """
    Return a(n), the n-th term of the digit-sum sequence (a(0) = a(1) = 1).

    Raises ValueError if n is negative.

    >>> solution(0)
    1
    >>> solution(10)
    62
    >>> solution(10**6)
    31054319
    >>> solution(10**15)
    73597483551591773
    """
    if n < 0:
        raise ValueError("n must be a non-negative integer")
    if n == 0:
        # a(0) = 1 by definition. The loop below starts from a(1) and could
        # never satisfy its exit condition for n = 0.
        return 1

    # digits holds the current term, least significant digit first; it starts
    # out as a(1) = 1 at position i = 1.
    digits = [1]
    i = 1
    dn = 0
    while True:
        # Only the number of terms advanced is needed here: the term itself
        # is updated in place inside digits. ks[-1] is the widest split that
        # base has an entry for.
        _, terms_jumped = next_term(digits, ks[-1], i + dn, n)
        dn += terms_jumped
        if dn == n - i:
            break

    return sum(digit * 10**j for j, digit in enumerate(digits))
# Script entry point: solve for the default n = 10**15 and print the result.
if __name__ == "__main__":
    print(f"{solution() = }")