Просмотр исходного кода

fix(pulse): Miscellaneous contract fixes (#2454)

* fix stuff

* more stuff

* pulsse fixes

* fix

* fix tests

* gr
Jayant Krishnamurthy 8 месяцев назад
Родитель
Сommit
92a1737c47

+ 50 - 9
target_chains/ethereum/contracts/contracts/pulse/IPulse.sol

@@ -6,11 +6,32 @@ import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
 import "./PulseEvents.sol";
 import "./PulseState.sol";
 
-interface IPulseConsumer {
+abstract contract IPulseConsumer {
+    // This method is called by Pulse to provide the price updates to the consumer.
+    // It asserts that the msg.sender is the Pulse contract. It is not meant to be
+    // overridden by the consumer.
+    function _pulseCallback(
+        uint64 sequenceNumber,
+        PythStructs.PriceFeed[] memory priceFeeds
+    ) external {
+        address pulse = getPulse();
+        require(pulse != address(0), "Pulse address not set");
+        require(msg.sender == pulse, "Only Pulse can call this function");
+
+        pulseCallback(sequenceNumber, priceFeeds);
+    }
+
+    // getPulse returns the Pulse contract address. The method is being used to check that the
+    // callback is indeed from the Pulse contract. The consumer is expected to implement this method.
+    function getPulse() internal view virtual returns (address);
+
+    // This method is expected to be implemented by the consumer to handle the price updates.
+    // It will be called by _pulseCallback after _pulseCallback ensures that the call is
+    // indeed from Pulse contract.
     function pulseCallback(
         uint64 sequenceNumber,
         PythStructs.PriceFeed[] memory priceFeeds
-    ) external;
+    ) internal virtual;
 }
 
 interface IPulse is PulseEvents {
@@ -18,10 +39,11 @@ interface IPulse is PulseEvents {
     /**
      * @notice Requests price updates with a callback
      * @dev The msg.value must be equal to getFee(callbackGasLimit)
-     * @param callbackGasLimit The amount of gas allocated for the callback execution
+     * @param provider The provider to fulfill the request
      * @param publishTime The minimum publish time for price updates, it should be less than or equal to block.timestamp + 60
      * @param priceIds The price feed IDs to update. Maximum 10 price feeds per request.
      *        Requests requiring more feeds should be split into multiple calls.
+     * @param callbackGasLimit The amount of gas allocated for the callback execution
      * @return sequenceNumber The sequence number assigned to this request
      * @dev Security note: The 60-second future limit on publishTime prevents a DoS vector where
      *      attackers could submit many low-fee requests for far-future updates when gas prices
@@ -30,7 +52,8 @@ interface IPulse is PulseEvents {
      *      the fee estimation unreliable.
      */
     function requestPriceUpdatesWithCallback(
-        uint256 publishTime,
+        address provider,
+        uint64 publishTime,
         bytes32[] calldata priceIds,
         uint256 callbackGasLimit
     ) external payable returns (uint64 sequenceNumber);
@@ -39,11 +62,13 @@ interface IPulse is PulseEvents {
      * @notice Executes the callback for a price update request
      * @dev Requires 1.5x the callback gas limit to account for cross-contract call overhead
      * For example, if callbackGasLimit is 1M, the transaction needs at least 1.5M gas + some gas for some other operations in the function before the callback
+     * @param providerToCredit The provider to credit for fulfilling the request. This may not be the provider that submitted the request (if the exclusivity period has elapsed).
      * @param sequenceNumber The sequence number of the request
      * @param updateData The raw price update data from Pyth
      * @param priceIds The price feed IDs to update, must match the request
      */
     function executeCallback(
+        address providerToCredit,
         uint64 sequenceNumber,
         bytes[] calldata updateData,
         bytes32[] calldata priceIds
@@ -59,15 +84,22 @@ interface IPulse is PulseEvents {
 
     /**
      * @notice Calculates the total fee required for a price update request
-     * @dev Total fee = base Pyth protocol fee + gas costs for callback
+     * @dev Total fee = base Pyth protocol fee + base provider fee + provider fee per feed + gas costs for callback
+     * @param provider The provider to fulfill the request
      * @param callbackGasLimit The amount of gas allocated for callback execution
+     * @param priceIds The price feed IDs to update.
      * @return feeAmount The total fee in wei that must be provided as msg.value
      */
     function getFee(
-        uint256 callbackGasLimit
+        address provider,
+        uint256 callbackGasLimit,
+        bytes32[] calldata priceIds
     ) external view returns (uint128 feeAmount);
 
-    function getAccruedFees() external view returns (uint128 accruedFeesInWei);
+    function getAccruedPythFees()
+        external
+        view
+        returns (uint128 accruedFeesInWei);
 
     function getRequest(
         uint64 sequenceNumber
@@ -83,9 +115,18 @@ interface IPulse is PulseEvents {
 
     function withdrawAsFeeManager(address provider, uint128 amount) external;
 
-    function registerProvider(uint128 feeInWei) external;
+    function registerProvider(
+        uint128 baseFeeInWei,
+        uint128 feePerFeedInWei,
+        uint128 feePerGasInWei
+    ) external;
 
-    function setProviderFee(uint128 newFeeInWei) external;
+    function setProviderFee(
+        address provider,
+        uint128 newBaseFeeInWei,
+        uint128 newFeePerFeedInWei,
+        uint128 newFeePerGasInWei
+    ) external;
 
     function getProviderInfo(
         address provider

+ 102 - 43
target_chains/ethereum/contracts/contracts/pulse/Pulse.sol

@@ -53,17 +53,19 @@ abstract contract Pulse is IPulse, PulseState {
         }
     }
 
+    // TODO: there can be a separate wrapper function that defaults the provider (or uses the cheapest or something).
     function requestPriceUpdatesWithCallback(
-        uint256 publishTime,
+        address provider,
+        uint64 publishTime,
         bytes32[] calldata priceIds,
         uint256 callbackGasLimit
     ) external payable override returns (uint64 requestSequenceNumber) {
-        address provider = _state.defaultProvider;
         require(
             _state.providers[provider].isRegistered,
             "Provider not registered"
         );
 
+        // FIXME: this comment is wrong. (we're not using tx.gasprice)
         // NOTE: The 60-second future limit on publishTime prevents a DoS vector where
         //      attackers could submit many low-fee requests for far-future updates when gas prices
         //      are low, forcing executors to fulfill them later when gas prices might be much higher.
@@ -75,7 +77,7 @@ abstract contract Pulse is IPulse, PulseState {
         }
         requestSequenceNumber = _state.currentSequenceNumber++;
 
-        uint128 requiredFee = getFee(callbackGasLimit);
+        uint128 requiredFee = getFee(provider, callbackGasLimit, priceIds);
         if (msg.value < requiredFee) revert InsufficientFee();
 
         Request storage req = allocRequest(requestSequenceNumber);
@@ -85,21 +87,20 @@ abstract contract Pulse is IPulse, PulseState {
         req.requester = msg.sender;
         req.numPriceIds = uint8(priceIds.length);
         req.provider = provider;
+        req.fee = SafeCast.toUint128(msg.value - _state.pythFeeInWei);
 
         // Copy price IDs to storage
         for (uint8 i = 0; i < priceIds.length; i++) {
             req.priceIds[i] = priceIds[i];
         }
-
-        _state.providers[provider].accruedFeesInWei += SafeCast.toUint128(
-            msg.value - _state.pythFeeInWei
-        );
         _state.accruedFeesInWei += _state.pythFeeInWei;
 
         emit PriceUpdateRequested(req, priceIds);
     }
 
+    // TODO: does this need to be payable? Any cost paid to Pyth could be taken out of the provider's accrued fees.
     function executeCallback(
+        address providerToCredit,
         uint64 sequenceNumber,
         bytes[] calldata updateData,
         bytes32[] calldata priceIds
@@ -111,7 +112,7 @@ abstract contract Pulse is IPulse, PulseState {
             block.timestamp < req.publishTime + _state.exclusivityPeriodSeconds
         ) {
             require(
-                msg.sender == req.provider,
+                providerToCredit == req.provider,
                 "Only assigned provider during exclusivity period"
             );
         }
@@ -127,19 +128,41 @@ abstract contract Pulse is IPulse, PulseState {
             }
         }
 
-        // Parse price feeds first to measure gas usage
-        PythStructs.PriceFeed[] memory priceFeeds = IPyth(_state.pyth)
-            .parsePriceFeedUpdates(
-                updateData,
-                priceIds,
-                SafeCast.toUint64(req.publishTime),
-                SafeCast.toUint64(req.publishTime)
-            );
+        // TODO: should this use parsePriceFeedUpdatesUnique? also, do we need to add 1 to maxPublishTime?
+        IPyth pyth = IPyth(_state.pyth);
+        uint256 pythFee = pyth.getUpdateFee(updateData);
+        PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{
+            value: pythFee
+        }(
+            updateData,
+            priceIds,
+            SafeCast.toUint64(req.publishTime),
+            SafeCast.toUint64(req.publishTime)
+        );
+
+        // TODO: if this effect occurs here, we need to guarantee that executeCallback can never revert.
+        // If executeCallback can revert, then funds can be permanently locked in the contract.
+        // TODO: there also needs to be some penalty mechanism in case the expected provider doesn't execute the callback.
+        // This should take funds from the expected provider and give to providerToCredit. The penalty should probably scale
+        // with time in order to ensure that the callback eventually gets executed.
+        // (There may be exploits with ^ though if the consumer contract is malicious ?)
+        _state.providers[providerToCredit].accruedFeesInWei += SafeCast
+            .toUint128((req.fee + msg.value) - pythFee);
 
         clearRequest(sequenceNumber);
 
+        // TODO: I'm pretty sure this is going to use a lot of gas because it's doing a storage lookup for each sequence number.
+        // a better solution would be a doubly-linked list of active requests.
+        // After successful callback, update firstUnfulfilledSeq if needed
+        while (
+            _state.firstUnfulfilledSeq < _state.currentSequenceNumber &&
+            !isActive(findRequest(_state.firstUnfulfilledSeq))
+        ) {
+            _state.firstUnfulfilledSeq++;
+        }
+
         try
-            IPulseConsumer(req.requester).pulseCallback{
+            IPulseConsumer(req.requester)._pulseCallback{
                 gas: req.callbackGasLimit
             }(sequenceNumber, priceFeeds)
         {
@@ -149,7 +172,7 @@ abstract contract Pulse is IPulse, PulseState {
             // Explicit revert/require
             emit PriceUpdateCallbackFailed(
                 sequenceNumber,
-                msg.sender,
+                providerToCredit,
                 priceIds,
                 req.requester,
                 reason
@@ -158,20 +181,12 @@ abstract contract Pulse is IPulse, PulseState {
             // Out of gas or other low-level errors
             emit PriceUpdateCallbackFailed(
                 sequenceNumber,
-                msg.sender,
+                providerToCredit,
                 priceIds,
                 req.requester,
                 "low-level error (possibly out of gas)"
             );
         }
-
-        // After successful callback, update firstUnfulfilledSeq if needed
-        while (
-            _state.firstUnfulfilledSeq < _state.currentSequenceNumber &&
-            !isActive(findRequest(_state.firstUnfulfilledSeq))
-        ) {
-            _state.firstUnfulfilledSeq++;
-        }
     }
 
     function emitPriceUpdate(
@@ -182,13 +197,16 @@ abstract contract Pulse is IPulse, PulseState {
         int64[] memory prices = new int64[](priceFeeds.length);
         uint64[] memory conf = new uint64[](priceFeeds.length);
         int32[] memory expos = new int32[](priceFeeds.length);
-        uint256[] memory publishTimes = new uint256[](priceFeeds.length);
+        uint64[] memory publishTimes = new uint64[](priceFeeds.length);
 
         for (uint i = 0; i < priceFeeds.length; i++) {
             prices[i] = priceFeeds[i].price.price;
             conf[i] = priceFeeds[i].price.conf;
             expos[i] = priceFeeds[i].price.expo;
-            publishTimes[i] = priceFeeds[i].price.publishTime;
+            // Safe cast because this is a unix timestamp in seconds.
+            publishTimes[i] = SafeCast.toUint64(
+                priceFeeds[i].price.publishTime
+            );
         }
 
         emit PriceUpdateExecuted(
@@ -203,14 +221,25 @@ abstract contract Pulse is IPulse, PulseState {
     }
 
     function getFee(
-        uint256 callbackGasLimit
+        address provider,
+        uint256 callbackGasLimit,
+        bytes32[] calldata priceIds
     ) public view override returns (uint128 feeAmount) {
         uint128 baseFee = _state.pythFeeInWei; // Fixed fee to Pyth
-        uint128 providerFeeInWei = _state
-            .providers[_state.defaultProvider]
-            .feeInWei; // Provider's per-gas rate
+        // Note: The provider needs to set its fees to include the fee charged by the Pyth contract.
+        // Ideally, we would be able to automatically compute the pyth fees from the priceIds, but the
+        // fee computation on IPyth assumes it has the full updated data.
+        uint128 providerBaseFee = _state.providers[provider].baseFeeInWei;
+        uint128 providerFeedFee = SafeCast.toUint128(
+            priceIds.length * _state.providers[provider].feePerFeedInWei
+        );
+        uint128 providerFeeInWei = _state.providers[provider].feePerGasInWei; // Provider's per-gas rate
         uint256 gasFee = callbackGasLimit * providerFeeInWei; // Total provider fee based on gas
-        feeAmount = baseFee + SafeCast.toUint128(gasFee); // Total fee user needs to pay
+        feeAmount =
+            baseFee +
+            providerBaseFee +
+            providerFeedFee +
+            SafeCast.toUint128(gasFee); // Total fee user needs to pay
     }
 
     function getPythFeeInWei()
@@ -222,7 +251,7 @@ abstract contract Pulse is IPulse, PulseState {
         pythFeeInWei = _state.pythFeeInWei;
     }
 
-    function getAccruedFees()
+    function getAccruedPythFees()
         public
         view
         override
@@ -244,6 +273,7 @@ abstract contract Pulse is IPulse, PulseState {
         shortHash = uint8(hash[0] & NUM_REQUESTS_MASK);
     }
 
+    // TODO: move out governance functions into a separate PulseGovernance contract
     function withdrawFees(uint128 amount) external override {
         require(msg.sender == _state.admin, "Only admin can withdraw fees");
         require(_state.accruedFeesInWei >= amount, "Insufficient balance");
@@ -336,22 +366,51 @@ abstract contract Pulse is IPulse, PulseState {
         emit FeesWithdrawn(msg.sender, amount);
     }
 
-    function registerProvider(uint128 feeInWei) external override {
+    function registerProvider(
+        uint128 baseFeeInWei,
+        uint128 feePerFeedInWei,
+        uint128 feePerGasInWei
+    ) external override {
         ProviderInfo storage provider = _state.providers[msg.sender];
         require(!provider.isRegistered, "Provider already registered");
-        provider.feeInWei = feeInWei;
+        provider.baseFeeInWei = baseFeeInWei;
+        provider.feePerFeedInWei = feePerFeedInWei;
+        provider.feePerGasInWei = feePerGasInWei;
         provider.isRegistered = true;
-        emit ProviderRegistered(msg.sender, feeInWei);
+        emit ProviderRegistered(msg.sender, feePerGasInWei);
     }
 
-    function setProviderFee(uint128 newFeeInWei) external override {
+    function setProviderFee(
+        address provider,
+        uint128 newBaseFeeInWei,
+        uint128 newFeePerFeedInWei,
+        uint128 newFeePerGasInWei
+    ) external override {
         require(
-            _state.providers[msg.sender].isRegistered,
+            _state.providers[provider].isRegistered,
             "Provider not registered"
         );
-        uint128 oldFee = _state.providers[msg.sender].feeInWei;
-        _state.providers[msg.sender].feeInWei = newFeeInWei;
-        emit ProviderFeeUpdated(msg.sender, oldFee, newFeeInWei);
+        require(
+            msg.sender == provider ||
+                msg.sender == _state.providers[provider].feeManager,
+            "Only provider or fee manager can invoke this method"
+        );
+
+        uint128 oldBaseFee = _state.providers[provider].baseFeeInWei;
+        uint128 oldFeePerFeed = _state.providers[provider].feePerFeedInWei;
+        uint128 oldFeePerGas = _state.providers[provider].feePerGasInWei;
+        _state.providers[provider].baseFeeInWei = newBaseFeeInWei;
+        _state.providers[provider].feePerFeedInWei = newFeePerFeedInWei;
+        _state.providers[provider].feePerGasInWei = newFeePerGasInWei;
+        emit ProviderFeeUpdated(
+            provider,
+            oldBaseFee,
+            oldFeePerFeed,
+            oldFeePerGas,
+            newBaseFeeInWei,
+            newFeePerFeedInWei,
+            newFeePerGasInWei
+        );
     }
 
     function getProviderInfo(

+ 1 - 0
target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol

@@ -4,6 +4,7 @@ pragma solidity ^0.8.0;
 
 error NoSuchProvider();
 error NoSuchRequest();
+// TODO: add expected / provided values
 error InsufficientFee();
 error Unauthorized();
 error InvalidCallbackGas();

+ 7 - 3
target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol

@@ -13,7 +13,7 @@ interface PulseEvents {
         int64[] prices,
         uint64[] conf,
         int32[] expos,
-        uint256[] publishTimes
+        uint64[] publishTimes
     );
 
     event FeesWithdrawn(address indexed recipient, uint128 amount);
@@ -35,8 +35,12 @@ interface PulseEvents {
     event ProviderRegistered(address indexed provider, uint128 feeInWei);
     event ProviderFeeUpdated(
         address indexed provider,
-        uint128 oldFee,
-        uint128 newFee
+        uint128 oldBaseFee,
+        uint128 oldFeePerFeed,
+        uint128 oldFeePerGas,
+        uint128 newBaseFee,
+        uint128 newFeePerFeed,
+        uint128 newFeePerGas
     );
     event DefaultProviderUpdated(address oldProvider, address newProvider);
 

+ 8 - 2
target_chains/ethereum/contracts/contracts/pulse/PulseState.sol

@@ -11,16 +11,22 @@ contract PulseState {
 
     struct Request {
         uint64 sequenceNumber;
-        uint256 publishTime;
+        uint64 publishTime;
+        // TODO: this is going to absolutely explode gas costs. Need to do something smarter here.
+        // possible solution is to hash the price ids and store the hash instead.
+        // The ids themselves can be retrieved from the event.
         bytes32[MAX_PRICE_IDS] priceIds;
         uint8 numPriceIds; // Actual number of price IDs used
         uint256 callbackGasLimit;
         address requester;
         address provider;
+        uint128 fee;
     }
 
     struct ProviderInfo {
-        uint128 feeInWei;
+        uint128 baseFeeInWei;
+        uint128 feePerFeedInWei;
+        uint128 feePerGasInWei;
         uint128 accruedFeesInWei;
         address feeManager;
         bool isRegistered;

+ 261 - 71
target_chains/ethereum/contracts/forge-test/Pulse.t.sol

@@ -12,13 +12,22 @@ import "../contracts/pulse/PulseEvents.sol";
 import "../contracts/pulse/PulseErrors.sol";
 
 contract MockPulseConsumer is IPulseConsumer {
+    address private _pulse;
     uint64 public lastSequenceNumber;
     PythStructs.PriceFeed[] private _lastPriceFeeds;
 
+    constructor(address pulse) {
+        _pulse = pulse;
+    }
+
+    function getPulse() internal view override returns (address) {
+        return _pulse;
+    }
+
     function pulseCallback(
         uint64 sequenceNumber,
         PythStructs.PriceFeed[] memory priceFeeds
-    ) external override {
+    ) internal override {
         lastSequenceNumber = sequenceNumber;
         for (uint i = 0; i < priceFeeds.length; i++) {
             _lastPriceFeeds.push(priceFeeds[i]);
@@ -35,10 +44,20 @@ contract MockPulseConsumer is IPulseConsumer {
 }
 
 contract FailingPulseConsumer is IPulseConsumer {
+    address private _pulse;
+
+    constructor(address pulse) {
+        _pulse = pulse;
+    }
+
+    function getPulse() internal view override returns (address) {
+        return _pulse;
+    }
+
     function pulseCallback(
         uint64,
         PythStructs.PriceFeed[] memory
-    ) external pure override {
+    ) internal pure override {
         revert("callback failed");
     }
 }
@@ -46,14 +65,25 @@ contract FailingPulseConsumer is IPulseConsumer {
 contract CustomErrorPulseConsumer is IPulseConsumer {
     error CustomError(string message);
 
+    address private _pulse;
+
+    constructor(address pulse) {
+        _pulse = pulse;
+    }
+
+    function getPulse() internal view override returns (address) {
+        return _pulse;
+    }
+
     function pulseCallback(
         uint64,
         PythStructs.PriceFeed[] memory
-    ) external pure override {
+    ) internal pure override {
         revert CustomError("callback failed");
     }
 }
 
+// FIXME: this shouldn't be IPulseConsumer.
 contract PulseTest is Test, PulseEvents, IPulseConsumer {
     ERC1967Proxy public proxy;
     PulseUpgradeable public pulse;
@@ -64,7 +94,11 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
     address public defaultProvider;
     // Constants
     uint128 constant PYTH_FEE = 1 wei;
-    uint128 constant DEFAULT_PROVIDER_FEE = 1 wei;
+    uint128 constant DEFAULT_PROVIDER_FEE_PER_GAS = 1 wei;
+    uint128 constant DEFAULT_PROVIDER_BASE_FEE = 1 wei;
+    uint128 constant DEFAULT_PROVIDER_FEE_PER_FEED = 10 wei;
+    uint constant MOCK_PYTH_FEE_PER_FEED = 10 wei;
+
     uint128 constant CALLBACK_GAS_LIMIT = 1_000_000;
     bytes32 constant BTC_PRICE_FEED_ID =
         0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43;
@@ -97,8 +131,12 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
             15
         );
         vm.prank(defaultProvider);
-        pulse.registerProvider(DEFAULT_PROVIDER_FEE);
-        consumer = new MockPulseConsumer();
+        pulse.registerProvider(
+            DEFAULT_PROVIDER_BASE_FEE,
+            DEFAULT_PROVIDER_FEE_PER_FEED,
+            DEFAULT_PROVIDER_FEE_PER_GAS
+        );
+        consumer = new MockPulseConsumer(address(proxy));
     }
 
     // Helper function to create price IDs array
@@ -136,8 +174,17 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
     function mockParsePriceFeedUpdates(
         PythStructs.PriceFeed[] memory priceFeeds
     ) internal {
+        uint expectedFee = MOCK_PYTH_FEE_PER_FEED * priceFeeds.length;
+
         vm.mockCall(
             address(pyth),
+            abi.encodeWithSelector(IPyth.getUpdateFee.selector),
+            abi.encode(expectedFee)
+        );
+
+        vm.mockCall(
+            address(pyth),
+            expectedFee,
             abi.encodeWithSelector(IPyth.parsePriceFeedUpdates.selector),
             abi.encode(priceFeeds)
         );
@@ -154,8 +201,10 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
     }
 
     // Helper function to calculate total fee
+    // FIXME: I think this helper probably needs to take some arguments.
     function calculateTotalFee() internal view returns (uint128) {
-        return pulse.getFee(CALLBACK_GAS_LIMIT);
+        return
+            pulse.getFee(defaultProvider, CALLBACK_GAS_LIMIT, createPriceIds());
     }
 
     // Helper function to setup consumer request
@@ -166,17 +215,18 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         returns (
             uint64 sequenceNumber,
             bytes32[] memory priceIds,
-            uint256 publishTime
+            uint64 publishTime
         )
     {
         priceIds = createPriceIds();
-        publishTime = block.timestamp;
+        publishTime = SafeCast.toUint64(block.timestamp);
         vm.deal(consumerAddress, 1 gwei);
 
         uint128 totalFee = calculateTotalFee();
 
         vm.prank(consumerAddress);
         sequenceNumber = pulse.requestPriceUpdatesWithCallback{value: totalFee}(
+            defaultProvider,
             publishTime,
             priceIds,
             CALLBACK_GAS_LIMIT
@@ -190,7 +240,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         vm.txGasPrice(30 gwei);
 
         bytes32[] memory priceIds = createPriceIds();
-        uint256 publishTime = block.timestamp;
+        uint64 publishTime = SafeCast.toUint64(block.timestamp);
 
         // Fund the consumer contract with enough ETH for higher gas price
         vm.deal(address(consumer), 1 ether);
@@ -215,7 +265,8 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
             numPriceIds: 2,
             callbackGasLimit: CALLBACK_GAS_LIMIT,
             requester: address(consumer),
-            provider: defaultProvider
+            provider: defaultProvider,
+            fee: totalFee - PYTH_FEE
         });
 
         vm.expectEmit();
@@ -223,6 +274,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
 
         vm.prank(address(consumer));
         pulse.requestPriceUpdatesWithCallback{value: totalFee}(
+            defaultProvider,
             publishTime,
             priceIds,
             CALLBACK_GAS_LIMIT
@@ -256,7 +308,8 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         vm.prank(address(consumer));
         vm.expectRevert(InsufficientFee.selector);
         pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Intentionally low fee
-            block.timestamp,
+            defaultProvider,
+            SafeCast.toUint64(block.timestamp),
             priceIds,
             CALLBACK_GAS_LIMIT
         );
@@ -264,7 +317,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
 
     function testExecuteCallback() public {
         bytes32[] memory priceIds = createPriceIds();
-        uint256 publishTime = block.timestamp;
+        uint64 publishTime = SafeCast.toUint64(block.timestamp);
 
         // Fund the consumer contract
         vm.deal(address(consumer), 1 gwei);
@@ -274,12 +327,13 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         vm.prank(address(consumer));
         uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{
             value: totalFee
-        }(publishTime, priceIds, CALLBACK_GAS_LIMIT);
+        }(defaultProvider, publishTime, priceIds, CALLBACK_GAS_LIMIT);
 
         // Step 2: Create mock price feeds and setup Pyth response
         PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
             publishTime
         );
+        // FIXME: this test doesn't ensure the Pyth fee is paid.
         mockParsePriceFeedUpdates(priceFeeds);
 
         // Create arrays for expected event data
@@ -295,7 +349,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         expectedExpos[0] = MOCK_PRICE_FEED_EXPO;
         expectedExpos[1] = MOCK_PRICE_FEED_EXPO;
 
-        uint256[] memory expectedPublishTimes = new uint256[](2);
+        uint64[] memory expectedPublishTimes = new uint64[](2);
         expectedPublishTimes[0] = publishTime;
         expectedPublishTimes[1] = publishTime;
 
@@ -315,7 +369,12 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         bytes[] memory updateData = createMockUpdateData(priceFeeds);
 
         vm.prank(defaultProvider);
-        pulse.executeCallback(sequenceNumber, updateData, priceIds);
+        pulse.executeCallback(
+            defaultProvider,
+            sequenceNumber,
+            updateData,
+            priceIds
+        );
 
         // Verify callback was executed
         assertEq(consumer.lastSequenceNumber(), sequenceNumber);
@@ -338,7 +397,9 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
     }
 
     function testExecuteCallbackFailure() public {
-        FailingPulseConsumer failingConsumer = new FailingPulseConsumer();
+        FailingPulseConsumer failingConsumer = new FailingPulseConsumer(
+            address(proxy)
+        );
 
         (
             uint64 sequenceNumber,
@@ -362,11 +423,18 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         );
 
         vm.prank(defaultProvider);
-        pulse.executeCallback(sequenceNumber, updateData, priceIds);
+        pulse.executeCallback(
+            defaultProvider,
+            sequenceNumber,
+            updateData,
+            priceIds
+        );
     }
 
     function testExecuteCallbackCustomErrorFailure() public {
-        CustomErrorPulseConsumer failingConsumer = new CustomErrorPulseConsumer();
+        CustomErrorPulseConsumer failingConsumer = new CustomErrorPulseConsumer(
+            address(proxy)
+        );
 
         (
             uint64 sequenceNumber,
@@ -390,7 +458,12 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         );
 
         vm.prank(defaultProvider);
-        pulse.executeCallback(sequenceNumber, updateData, priceIds);
+        pulse.executeCallback(
+            defaultProvider,
+            sequenceNumber,
+            updateData,
+            priceIds
+        );
     }
 
     function testExecuteCallbackWithInsufficientGas() public {
@@ -412,6 +485,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         vm.prank(defaultProvider);
         vm.expectRevert(); // Just expect any revert since it will be an out-of-gas error
         pulse.executeCallback{gas: 100000}(
+            defaultProvider,
             sequenceNumber,
             updateData,
             priceIds
@@ -421,14 +495,14 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
     function testExecuteCallbackWithFutureTimestamp() public {
         // Setup request with future timestamp
         bytes32[] memory priceIds = createPriceIds();
-        uint256 futureTime = block.timestamp + 10; // 10 seconds in future
+        uint64 futureTime = SafeCast.toUint64(block.timestamp + 10); // 10 seconds in future
         vm.deal(address(consumer), 1 gwei);
 
         uint128 totalFee = calculateTotalFee();
         vm.prank(address(consumer));
         uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{
             value: totalFee
-        }(futureTime, priceIds, CALLBACK_GAS_LIMIT);
+        }(defaultProvider, futureTime, priceIds, CALLBACK_GAS_LIMIT);
 
         // Try to execute callback before the requested timestamp
         PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
@@ -439,7 +513,12 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
 
         vm.prank(defaultProvider);
         // Should succeed because we're simulating receiving future-dated price updates
-        pulse.executeCallback(sequenceNumber, updateData, priceIds);
+        pulse.executeCallback(
+            defaultProvider,
+            sequenceNumber,
+            updateData,
+            priceIds
+        );
 
         // Compare price feeds array length
         PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds();
@@ -456,7 +535,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
 
     function testRevertOnTooFarFutureTimestamp() public {
         bytes32[] memory priceIds = createPriceIds();
-        uint256 farFutureTime = block.timestamp + 61; // Just over 1 minute
+        uint64 farFutureTime = SafeCast.toUint64(block.timestamp + 61); // Just over 1 minute
         vm.deal(address(consumer), 1 gwei);
 
         uint128 totalFee = calculateTotalFee();
@@ -464,6 +543,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
 
         vm.expectRevert("Too far in future");
         pulse.requestPriceUpdatesWithCallback{value: totalFee}(
+            defaultProvider,
             farFutureTime,
             priceIds,
             CALLBACK_GAS_LIMIT
@@ -485,12 +565,22 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
 
         // First execution
         vm.prank(defaultProvider);
-        pulse.executeCallback(sequenceNumber, updateData, priceIds);
+        pulse.executeCallback(
+            defaultProvider,
+            sequenceNumber,
+            updateData,
+            priceIds
+        );
 
         // Second execution should fail
         vm.prank(defaultProvider);
         vm.expectRevert(NoSuchRequest.selector);
-        pulse.executeCallback(sequenceNumber, updateData, priceIds);
+        pulse.executeCallback(
+            defaultProvider,
+            sequenceNumber,
+            updateData,
+            priceIds
+        );
     }
 
     function testGetFee() public {
@@ -500,12 +590,22 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         gasLimits[1] = 500_000;
         gasLimits[2] = 1_000_000;
 
+        bytes32[] memory priceIds = createPriceIds();
+
         for (uint256 i = 0; i < gasLimits.length; i++) {
             uint256 gasLimit = gasLimits[i];
             uint128 expectedFee = SafeCast.toUint128(
-                DEFAULT_PROVIDER_FEE * gasLimit
+                DEFAULT_PROVIDER_BASE_FEE +
+                    DEFAULT_PROVIDER_FEE_PER_FEED *
+                    priceIds.length +
+                    DEFAULT_PROVIDER_FEE_PER_GAS *
+                    gasLimit
             ) + PYTH_FEE;
-            uint128 actualFee = pulse.getFee(gasLimit);
+            uint128 actualFee = pulse.getFee(
+                defaultProvider,
+                gasLimit,
+                priceIds
+            );
             assertEq(
                 actualFee,
                 expectedFee,
@@ -514,8 +614,13 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         }
 
         // Test with zero gas limit
-        uint128 expectedMinFee = PYTH_FEE;
-        uint128 actualMinFee = pulse.getFee(0);
+        uint128 expectedMinFee = SafeCast.toUint128(
+            PYTH_FEE +
+                DEFAULT_PROVIDER_BASE_FEE +
+                DEFAULT_PROVIDER_FEE_PER_FEED *
+                priceIds.length
+        );
+        uint128 actualMinFee = pulse.getFee(defaultProvider, 0, priceIds);
         assertEq(
             actualMinFee,
             expectedMinFee,
@@ -530,14 +635,15 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
 
         vm.prank(address(consumer));
         pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}(
-            block.timestamp,
+            defaultProvider,
+            SafeCast.toUint64(block.timestamp),
             priceIds,
             CALLBACK_GAS_LIMIT
         );
 
         // Get admin's balance before withdrawal
         uint256 adminBalanceBefore = admin.balance;
-        uint128 accruedFees = pulse.getAccruedFees();
+        uint128 accruedFees = pulse.getAccruedPythFees();
 
         // Withdraw fees as admin
         vm.prank(admin);
@@ -550,7 +656,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
             "Admin balance should increase by withdrawn amount"
         );
         assertEq(
-            pulse.getAccruedFees(),
+            pulse.getAccruedPythFees(),
             0,
             "Contract should have no fees after withdrawal"
         );
@@ -580,7 +686,8 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
 
         vm.prank(address(consumer));
         pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}(
-            block.timestamp,
+            defaultProvider,
+            SafeCast.toUint64(block.timestamp),
             priceIds,
             CALLBACK_GAS_LIMIT
         );
@@ -662,7 +769,12 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
                 priceIds[0]
             )
         );
-        pulse.executeCallback(sequenceNumber, updateData, wrongPriceIds);
+        pulse.executeCallback(
+            defaultProvider,
+            sequenceNumber,
+            updateData,
+            wrongPriceIds
+        );
     }
 
     function testRevertOnTooManyPriceIds() public {
@@ -685,7 +797,8 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
             )
         );
         pulse.requestPriceUpdatesWithCallback{value: totalFee}(
-            block.timestamp,
+            defaultProvider,
+            SafeCast.toUint64(block.timestamp),
             priceIds,
             CALLBACK_GAS_LIMIT
         );
@@ -696,26 +809,36 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         uint128 providerFee = 1000;
 
         vm.prank(provider);
-        pulse.registerProvider(providerFee);
+        pulse.registerProvider(providerFee, providerFee, providerFee);
 
         PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider);
-        assertEq(info.feeInWei, providerFee);
+        assertEq(info.feePerGasInWei, providerFee);
         assertTrue(info.isRegistered);
     }
 
     function testSetProviderFee() public {
         address provider = address(0x123);
-        uint128 initialFee = 1000;
-        uint128 newFee = 2000;
+        uint128 initialBaseFee = 1000;
+        uint128 initialFeePerFeed = 2000;
+        uint128 initialFeePerGas = 3000;
+        uint128 newFeePerFeed = 4000;
+        uint128 newBaseFee = 5000;
+        uint128 newFeePerGas = 6000;
 
         vm.prank(provider);
-        pulse.registerProvider(initialFee);
+        pulse.registerProvider(
+            initialBaseFee,
+            initialFeePerFeed,
+            initialFeePerGas
+        );
 
         vm.prank(provider);
-        pulse.setProviderFee(newFee);
+        pulse.setProviderFee(provider, newBaseFee, newFeePerFeed, newFeePerGas);
 
         PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider);
-        assertEq(info.feeInWei, newFee);
+        assertEq(info.baseFeeInWei, newBaseFee);
+        assertEq(info.feePerFeedInWei, newFeePerFeed);
+        assertEq(info.feePerGasInWei, newFeePerGas);
     }
 
     function testDefaultProvider() public {
@@ -723,7 +846,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         uint128 providerFee = 1000;
 
         vm.prank(provider);
-        pulse.registerProvider(providerFee);
+        pulse.registerProvider(providerFee, providerFee, providerFee);
 
         vm.prank(admin);
         pulse.setDefaultProvider(provider);
@@ -736,21 +859,23 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         uint128 providerFee = 1000;
 
         vm.prank(provider);
-        pulse.registerProvider(providerFee);
-
-        vm.prank(admin);
-        pulse.setDefaultProvider(provider);
+        pulse.registerProvider(providerFee, providerFee, providerFee);
 
         bytes32[] memory priceIds = new bytes32[](1);
         priceIds[0] = bytes32(uint256(1));
 
-        uint128 totalFee = pulse.getFee(CALLBACK_GAS_LIMIT);
+        uint128 totalFee = pulse.getFee(provider, CALLBACK_GAS_LIMIT, priceIds);
 
         vm.deal(address(consumer), totalFee);
         vm.prank(address(consumer));
         uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{
             value: totalFee
-        }(block.timestamp, priceIds, CALLBACK_GAS_LIMIT);
+        }(
+            provider,
+            SafeCast.toUint64(block.timestamp),
+            priceIds,
+            CALLBACK_GAS_LIMIT
+        );
 
         PulseState.Request memory req = pulse.getRequest(sequenceNumber);
         assertEq(req.provider, provider);
@@ -787,7 +912,11 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         // Register a second provider
         address secondProvider = address(0x456);
         vm.prank(secondProvider);
-        pulse.registerProvider(DEFAULT_PROVIDER_FEE);
+        pulse.registerProvider(
+            DEFAULT_PROVIDER_BASE_FEE,
+            DEFAULT_PROVIDER_FEE_PER_FEED,
+            DEFAULT_PROVIDER_FEE_PER_GAS
+        );
 
         // Setup request
         (
@@ -804,20 +933,32 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         bytes[] memory updateData = createMockUpdateData(priceFeeds);
 
         // Try to execute with second provider during exclusivity period
-        vm.prank(secondProvider);
         vm.expectRevert("Only assigned provider during exclusivity period");
-        pulse.executeCallback(sequenceNumber, updateData, priceIds);
+        pulse.executeCallback(
+            secondProvider,
+            sequenceNumber,
+            updateData,
+            priceIds
+        );
 
         // Original provider should succeed
-        vm.prank(defaultProvider);
-        pulse.executeCallback(sequenceNumber, updateData, priceIds);
+        pulse.executeCallback(
+            defaultProvider,
+            sequenceNumber,
+            updateData,
+            priceIds
+        );
     }
 
     function testExecuteCallbackAfterExclusivity() public {
         // Register a second provider
         address secondProvider = address(0x456);
         vm.prank(secondProvider);
-        pulse.registerProvider(DEFAULT_PROVIDER_FEE);
+        pulse.registerProvider(
+            DEFAULT_PROVIDER_BASE_FEE,
+            DEFAULT_PROVIDER_FEE_PER_FEED,
+            DEFAULT_PROVIDER_FEE_PER_GAS
+        );
 
         // Setup request
         (
@@ -838,14 +979,23 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
 
         // Second provider should now succeed
         vm.prank(secondProvider);
-        pulse.executeCallback(sequenceNumber, updateData, priceIds);
+        pulse.executeCallback(
+            defaultProvider,
+            sequenceNumber,
+            updateData,
+            priceIds
+        );
     }
 
     function testExecuteCallbackWithCustomExclusivityPeriod() public {
         // Register a second provider
         address secondProvider = address(0x456);
         vm.prank(secondProvider);
-        pulse.registerProvider(DEFAULT_PROVIDER_FEE);
+        pulse.registerProvider(
+            DEFAULT_PROVIDER_BASE_FEE,
+            DEFAULT_PROVIDER_FEE_PER_FEED,
+            DEFAULT_PROVIDER_FEE_PER_GAS
+        );
 
         // Set custom exclusivity period
         vm.prank(admin);
@@ -867,14 +1017,22 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
 
         // Try at 29 seconds (should fail for second provider)
         vm.warp(block.timestamp + 29);
-        vm.prank(secondProvider);
         vm.expectRevert("Only assigned provider during exclusivity period");
-        pulse.executeCallback(sequenceNumber, updateData, priceIds);
+        pulse.executeCallback(
+            secondProvider,
+            sequenceNumber,
+            updateData,
+            priceIds
+        );
 
         // Try at 31 seconds (should succeed for second provider)
         vm.warp(block.timestamp + 2);
-        vm.prank(secondProvider);
-        pulse.executeCallback(sequenceNumber, updateData, priceIds);
+        pulse.executeCallback(
+            secondProvider,
+            sequenceNumber,
+            updateData,
+            priceIds
+        );
     }
 
     function testGetFirstActiveRequests() public {
@@ -902,10 +1060,11 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
     }
 
     function createTestRequests(bytes32[] memory priceIds) private {
-        uint256 publishTime = block.timestamp;
+        uint64 publishTime = SafeCast.toUint64(block.timestamp);
         for (uint i = 0; i < 5; i++) {
             vm.deal(address(this), 1 ether);
             pulse.requestPriceUpdatesWithCallback{value: 1 ether}(
+                defaultProvider,
                 publishTime,
                 priceIds,
                 1000000
@@ -919,15 +1078,25 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
     ) private {
         // Create mock price feeds and setup Pyth response
         PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
-            block.timestamp
+            SafeCast.toUint64(block.timestamp)
         );
         mockParsePriceFeedUpdates(priceFeeds);
         updateData = createMockUpdateData(priceFeeds);
 
         vm.deal(defaultProvider, 2 ether); // Increase ETH allocation to prevent OutOfFunds
         vm.startPrank(defaultProvider);
-        pulse.executeCallback{value: 1 ether}(2, updateData, priceIds);
-        pulse.executeCallback{value: 1 ether}(4, updateData, priceIds);
+        pulse.executeCallback{value: 1 ether}(
+            defaultProvider,
+            2,
+            updateData,
+            priceIds
+        );
+        pulse.executeCallback{value: 1 ether}(
+            defaultProvider,
+            4,
+            updateData,
+            priceIds
+        );
         vm.stopPrank();
     }
 
@@ -1000,9 +1169,24 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
     ) private {
         vm.deal(defaultProvider, 3 ether); // Increase ETH allocation
         vm.startPrank(defaultProvider);
-        pulse.executeCallback{value: 1 ether}(1, updateData, priceIds);
-        pulse.executeCallback{value: 1 ether}(3, updateData, priceIds);
-        pulse.executeCallback{value: 1 ether}(5, updateData, priceIds);
+        pulse.executeCallback{value: 1 ether}(
+            defaultProvider,
+            1,
+            updateData,
+            priceIds
+        );
+        pulse.executeCallback{value: 1 ether}(
+            defaultProvider,
+            3,
+            updateData,
+            priceIds
+        );
+        pulse.executeCallback{value: 1 ether}(
+            defaultProvider,
+            5,
+            updateData,
+            priceIds
+        );
         vm.stopPrank();
     }
 
@@ -1017,7 +1201,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         // Setup test data
         bytes32[] memory priceIds = new bytes32[](1);
         priceIds[0] = bytes32(uint256(1));
-        uint256 publishTime = block.timestamp;
+        uint64 publishTime = SafeCast.toUint64(block.timestamp);
         uint256 callbackGasLimit = 1000000;
 
         // Create mock price feeds and setup Pyth response
@@ -1031,6 +1215,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         for (uint i = 0; i < 20; i++) {
             vm.deal(address(this), 1 ether);
             pulse.requestPriceUpdatesWithCallback{value: 1 ether}(
+                defaultProvider,
                 publishTime,
                 priceIds,
                 callbackGasLimit
@@ -1041,6 +1226,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
                 vm.deal(defaultProvider, 1 ether);
                 vm.prank(defaultProvider);
                 pulse.executeCallback{value: 1 ether}(
+                    defaultProvider,
                     uint64(i + 1),
                     updateData,
                     priceIds
@@ -1071,11 +1257,15 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer {
         );
     }
 
+    function getPulse() internal view override returns (address) {
+        return address(pulse);
+    }
+
     // Mock implementation of pulseCallback
     function pulseCallback(
         uint64 sequenceNumber,
         PythStructs.PriceFeed[] memory priceFeeds
-    ) external override {
+    ) internal override {
         // Just accept the callback, no need to do anything with the data
         // This prevents the revert we're seeing
     }