# Mirror of https://github.com/TheAlgorithms/Python.git
# Synced 2024-12-21 02:30:15 +00:00, commit bc8df6de31
#
# * [pre-commit.ci] pre-commit autoupdate
#   updates:
#   - [github.com/astral-sh/ruff-pre-commit: v0.2.2 → v0.3.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.2.2...v0.3.2)
#   - [github.com/pre-commit/mirrors-mypy: v1.8.0 → v1.9.0](https://github.com/pre-commit/mirrors-mypy/compare/v1.8.0...v1.9.0)
# * [pre-commit.ci] auto fixes from pre-commit.com hooks
#   for more information, see https://pre-commit.ci
# ---------
# Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
#
# 201 lines, 5.0 KiB, Python
"""
Sum of digits sequence
Problem 551

Let a(0), a(1),... be an integer sequence defined by:
     a(0) = 1
     for n >= 1, a(n) is the sum of the digits of all preceding terms

The sequence starts with 1, 1, 2, 4, 8, ...
You are given a(10^6) = 31054319.

Find a(10^15)
"""

# Exponents k used when a term is split as a(i) = b * 10^k + c.
ks = range(2, 20 + 1)
# base[k] == 10**k for every k from 0 up to the largest exponent in ks.
base = [10**k for k in range(ks[-1] + 1)]
# memo[digitsum(b)][c] holds the cached jumps for that state; each jump is a
# (difference, number of terms skipped, k) tuple.
memo: dict[int, dict[int, list[tuple[int, int, int]]]] = {}


def next_term(a_i, k, i, n):
    """
    Advance a_i in place to either the n-th term or the first term for which
    c >= 10^k, where terms are written in the form:
            a(i) = b * 10^k + c

    For any a(i), if digitsum(b) and c have the same value, the difference
    between subsequent terms will be the same until c >= 10^k. This difference
    is cached to greatly speed up the computation.

    Arguments:
    a_i -- list of digits, starting from the one's place, that represents
           the i-th term in the sequence (modified in place)
    k -- k when terms are written in the form a(i) = b*10^k + c.
         Terms are calculated until c >= 10^k or the n-th term is reached.
    i -- position along the sequence
    n -- term to calculate up to if k is large enough

    Return: a tuple of the difference between the ending term and the
            starting term, and the number of terms calculated. E.g. if the
            starting term is a_1 = 1 and the ending term is a_10 = 62, then
            (61, 9) is returned.
    """
    # ds_b - digitsum(b)
    ds_b = sum(a_i[k:])
    c = sum(digit * base[j] for j, digit in enumerate(a_i[:k]))

    diff, dn = 0, 0
    max_dn = n - i

    sub_memo = memo.get(ds_b)
    if sub_memo is None:
        sub_memo = {c: []}
        memo[ds_b] = sub_memo
    else:
        jumps = sub_memo.get(c)
        if jumps:
            # jumps is sorted by number of terms skipped, so scanning from the
            # back finds the largest jump that does not go over
            for jump_diff, jump_dn, jump_k in reversed(jumps):
                if jump_k <= k and jump_dn <= max_dn:
                    diff, dn = jump_diff, jump_dn
                    # the cached value is a difference, so add it to c and
                    # write the result back into the low digits
                    new_c = diff + c
                    for j in range(min(k, len(a_i))):
                        new_c, a_i[j] = divmod(new_c, 10)
                    # whatever is left over carries into b
                    if new_c > 0:
                        add(a_i, k, new_c)
                    break
        else:
            sub_memo[c] = []

    if dn >= max_dn or c + diff >= base[k]:
        return diff, dn

    if k > ks[0]:
        while True:
            # keep doing smaller jumps
            _diff, terms_jumped = next_term(a_i, k - 1, i + dn, n)
            diff += _diff
            dn += terms_jumped

            if dn >= max_dn or c + diff >= base[k]:
                break
    else:
        # would be too small a jump, just compute sequential terms instead
        _diff, terms_jumped = compute(a_i, k, i + dn, n)
        diff += _diff
        dn += terms_jumped

    jumps = sub_memo[c]

    # keep jumps sorted by # of terms skipped
    j = 0
    while j < len(jumps) and jumps[j][1] <= dn:
        j += 1

    # cache the jump for this value of digitsum(b) and c
    jumps.insert(j, (diff, dn, k))
    return diff, dn


def compute(a_i, k, i, n):
    """
    Same as next_term(a_i, k, i, n) but computes terms one at a time without
    memoizing results.

    a_i is padded with zeros to at least k digits and updated in place. Terms
    are generated until position n is reached or a carry leaves the lowest k
    digits; that carry is then added to the upper digits with add().

    Return: a tuple of the difference between the ending term and the
            starting term, and the number of terms calculated. (0, 0) is
            returned when i >= n, i.e. when there is nothing to compute.

    >>> digits = [1]
    >>> compute(digits, 2, 1, 10)
    (61, 9)
    >>> digits
    [2, 6]
    >>> compute([1], 2, 5, 5)
    (0, 0)
    """
    if i >= n:
        # no terms are computed, so the term count is 0 (not the position i)
        return 0, 0
    if k > len(a_i):
        a_i.extend([0] * (k - len(a_i)))

    # note: a_i -> b * 10^k + c
    #       ds_b -> digitsum(b)
    #       ds_c -> digitsum(c)
    start_i = i
    ds_b = sum(a_i[k:])
    ds_c = sum(a_i[:k])
    diff = 0

    while i < n:
        i += 1
        # the next term is the current term plus its digit sum
        addend = ds_c + ds_b
        diff += addend
        ds_c = 0
        for j in range(k):
            addend, a_i[j] = divmod(a_i[j] + addend, 10)
            ds_c += a_i[j]

        # a carry out of the low k digits changes b, so stop here
        if addend > 0:
            break

    if addend > 0:
        add(a_i, k, addend)
    return diff, i - start_i


def add(digits, k, addend):
|
|
"""
|
|
adds addend to digit array given in digits
|
|
starting at index k
|
|
"""
|
|
for j in range(k, len(digits)):
|
|
s = digits[j] + addend
|
|
if s >= 10:
|
|
quotient, digits[j] = divmod(s, 10)
|
|
addend = addend // 10 + quotient
|
|
else:
|
|
digits[j] = s
|
|
addend = addend // 10
|
|
|
|
if addend == 0:
|
|
break
|
|
|
|
while addend > 0:
|
|
addend, digit = divmod(addend, 10)
|
|
digits.append(digit)
|
|
|
|
|
|
def solution(n: int = 10**15) -> int:
    """
    Return the n-th term of the sequence.

    The term is held as a list of decimal digits (one's place first) that
    starts out as a(1) = 1 and is advanced with next_term() until position n
    is reached.

    Raises ValueError if n is negative.

    >>> solution(0)
    1

    >>> solution(10)
    62

    >>> solution(10**6)
    31054319

    >>> solution(10**15)
    73597483551591773
    """
    if n < 0:
        raise ValueError("n must be a non-negative integer")

    digits = [1]
    i = 1
    dn = 0  # number of terms advanced so far
    # `<` rather than `==`: for n = 0 the target n - i is negative, and an
    # equality test would never be satisfied
    while dn < n - i:
        _, terms_jumped = next_term(digits, 20, i + dn, n)
        dn += terms_jumped

    return sum(digit * 10**j for j, digit in enumerate(digits))


# Print the answer for the default n when run as a script.
if __name__ == "__main__":
    print(f"{solution() = }")