Skip to content

test_mod_arithmetic()

Documentation for tests/benchmark/compute/instruction/test_arithmetic.py::test_mod_arithmetic@8db70f93.

Generate fixtures for these test cases for Amsterdam with:

fill -v tests/benchmark/compute/instruction/test_arithmetic.py::test_mod_arithmetic --gas-benchmark-values 1

Benchmark ADDMOD and MULMOD instructions.

The program consists of code segments evaluating the "op chain":

mod[0] = calldataload(0)
mod[1] = (fixed_arg op args[indexes[0]]) % mod[0]
mod[2] = (fixed_arg op args[indexes[1]]) % mod[1]
...

The "args" are a pool of 15 constants pushed to the EVM stack at the program start. The "fixed_arg" is the 0xFF...FF constant pushed to the EVM stack by PUSH32 just before executing the "op". The order in which the args are accessed is chosen so that the mod value stays in the target range for as long as possible.

Source code in tests/benchmark/compute/instruction/test_arithmetic.py
(source lines 314–427)
@pytest.mark.repricing(mod_bits=191)
@pytest.mark.parametrize("mod_bits", [255, 191, 127, 63])
@pytest.mark.parametrize("opcode", [Op.ADDMOD, Op.MULMOD])
def test_mod_arithmetic(
    benchmark_test: BenchmarkTestFiller,
    pre: Alloc,
    fork: Fork,
    mod_bits: int,
    opcode: Op,
    gas_benchmark_value: int,
) -> None:
    """
    Benchmark ADDMOD and MULMOD instructions.

    The program consists of code segments evaluating the "op chain":
    mod[0] = calldataload(0)
    mod[1] = (fixed_arg op args[indexes[0]]) % mod[0]
    mod[2] = (fixed_arg op args[indexes[1]]) % mod[1]
    The "args" is a pool of 15 constants pushed to the EVM stack at the program
    start.
    The "fixed_arg" is the 0xFF...FF constant added to the EVM stack by PUSH32
    just before executing the "op".
    The order of accessing the numerators is selected in a way the mod value
    remains in the range as long as possible.
    """
    fixed_arg = 2**256 - 1
    num_args = 15

    max_code_size = fork.max_code_size()

    # Pick the modulus min value so that it is _unlikely_ to drop to the lower
    # word count.
    assert mod_bits >= 63
    mod_min = 2 ** (mod_bits - 63)

    # Select the random seed giving the longest found op chain. You can look
    # for a longer one by increasing the op_chain_len. This will activate the
    # while loop below.
    op_chain_len = 666
    match opcode, mod_bits:
        case Op.ADDMOD, 255:
            seed = 4
        case Op.ADDMOD, 191:
            seed = 2
        case Op.ADDMOD, 127:
            seed = 2
        case Op.ADDMOD, 63:
            seed = 64
        case Op.MULMOD, 255:
            seed = 5
        case Op.MULMOD, 191:
            seed = 389
        case Op.MULMOD, 127:
            seed = 5
        case Op.MULMOD, 63:
            # For this setup we were not able to find an op-chain longer than
            # 600.
            seed = 4193
            op_chain_len = 600
        case _:
            raise ValueError(f"{mod_bits}-bit {opcode} not supported.")

    while True:
        rng = random.Random(seed)
        args = [rng.randint(2**255, 2**256 - 1) for _ in range(num_args)]
        initial_mod = rng.randint(2 ** (mod_bits - 1), 2**mod_bits - 1)

        # Evaluate the op chain and collect the order of accessing numerators.
        op_fn = operator.add if opcode == Op.ADDMOD else operator.mul
        mod = initial_mod
        indexes: list[int] = []
        while mod >= mod_min and len(indexes) < op_chain_len:
            results = [op_fn(a, fixed_arg) % mod for a in args]
            # And pick the best one.
            i = max(range(len(results)), key=results.__getitem__)
            mod = results[i]
            indexes.append(i)

        # Disable if you want to find longer op chains.
        assert len(indexes) == op_chain_len
        if len(indexes) == op_chain_len:
            break
        seed += 1
        print(f"{seed=}")

    code_constant_pool = sum((Op.PUSH32[n] for n in args), Bytecode())
    code_segment = (
        Op.CALLDATALOAD(0)
        + sum(
            make_dup(len(args) - i) + Op.PUSH32[fixed_arg] + opcode
            for i in indexes
        )
        + Op.POP
    )
    # Construct the final code. Because of the usage of PUSH32 the code segment
    # is very long, so don't try to include multiple of these.
    code = (
        code_constant_pool
        + Op.JUMPDEST
        + code_segment
        + Op.JUMP(len(code_constant_pool))
    )
    assert (max_code_size - len(code_segment)) < len(code) <= max_code_size

    tx = Transaction(
        to=pre.deploy_contract(code=code),
        data=initial_mod.to_bytes(32, byteorder="big"),
        gas_limit=gas_benchmark_value,
        sender=pre.fund_eoa(),
    )

    benchmark_test(
        tx=tx,
    )

Parametrized Test Cases

This test generates 8 parametrized test cases across 3 forks.