Sfoglia il codice sorgente

Remove middle structs (#375)

* Remove middle structs

* Parse and process attestation in place

It helps a lot because we won't expand memory anymore

* Update comments

* Remove unusued PriceAttestation struct
Ali Behjati 3 anni fa
parent
commit
b23258112d

+ 118 - 124
ethereum/contracts/pyth/Pyth.sol

@@ -24,36 +24,13 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth {
         setPyth2WormholeEmitter(pyth2WormholeEmitter);
     }
 
-    function updatePriceBatchFromVm(bytes calldata encodedVm) private returns (PythInternalStructs.BatchPriceAttestation memory bpa) {
+    function updatePriceBatchFromVm(bytes calldata encodedVm) private {
         (IWormhole.VM memory vm, bool valid, string memory reason) = wormhole().parseAndVerifyVM(encodedVm);
 
         require(valid, reason);
         require(verifyPythVM(vm), "invalid data source chain/emitter ID");
 
-        PythInternalStructs.BatchPriceAttestation memory batch = parseBatchPriceAttestation(vm.payload);
-
-        uint freshPrices = 0;
-
-        for (uint i = 0; i < batch.attestations.length; i++) {
-            PythInternalStructs.PriceAttestation memory attestation = batch.attestations[i];
-    
-            PythInternalStructs.PriceInfo memory newPriceInfo = createNewPriceInfo(attestation);
-            PythInternalStructs.PriceInfo memory latestPrice = latestPriceInfo(attestation.priceId);
-
-            bool fresh = false;
-            if(newPriceInfo.price.publishTime > latestPrice.price.publishTime) {
-                freshPrices += 1;
-                fresh = true;
-                setLatestPriceInfo(attestation.priceId, newPriceInfo);
-            }
-
-            emit PriceFeedUpdate(attestation.priceId, fresh, vm.emitterChainId, vm.sequence, latestPrice.price.publishTime,
-                newPriceInfo.price.publishTime, newPriceInfo.price.price, newPriceInfo.price.conf);
-        }
-
-        emit BatchPriceFeedUpdate(vm.emitterChainId, vm.sequence, batch.attestations.length, freshPrices);
-
-        return batch;
+        parseAndProcessBatchPriceAttestation(vm);
     }
 
     function updatePriceFeeds(bytes[] calldata updateData) public override payable {
@@ -76,146 +53,163 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth {
         return singleUpdateFeeInWei() * updateData.length;
     }
 
-    function createNewPriceInfo(PythInternalStructs.PriceAttestation memory pa) private pure returns (PythInternalStructs.PriceInfo memory info) {
-        PythInternalStructs.PriceAttestationStatus status = PythInternalStructs.PriceAttestationStatus(pa.status);
-        if (status == PythInternalStructs.PriceAttestationStatus.TRADING) {
-            info.price.price = pa.price;
-            info.price.conf = pa.conf;
-            info.price.publishTime = pa.publishTime;
-            info.emaPrice.publishTime = pa.publishTime;
-        } else {
-            info.price.price = pa.prevPrice;
-            info.price.conf = pa.prevConf;
-            info.price.publishTime = pa.prevPublishTime;
-
-            // The EMA is last updated when the aggregate had trading status,
-            // so, we use prev_publish_time (the time when the aggregate last had trading status).
-            info.emaPrice.publishTime = pa.prevPublishTime;
-        }
-
-        info.price.expo = pa.expo;
-        info.emaPrice.price = pa.emaPrice;
-        info.emaPrice.conf = pa.emaConf;
-        info.emaPrice.expo = pa.expo;
-
-        return info;
-    }
-
     function verifyPythVM(IWormhole.VM memory vm) private view returns (bool valid) {
         return isValidDataSource(vm.emitterChainId, vm.emitterAddress); 
     }
 
-
-    function parseBatchPriceAttestation(bytes memory encoded) public pure returns (PythInternalStructs.BatchPriceAttestation memory bpa) {
+    function parseAndProcessBatchPriceAttestation(IWormhole.VM memory vm) internal {
+        bytes memory encoded = vm.payload;    
         uint index = 0;
 
         // Check header
-        bpa.header.magic = encoded.toUint32(index);
-        index += 4;
-        require(bpa.header.magic == 0x50325748, "invalid magic value");
-
-        bpa.header.versionMajor = encoded.toUint16(index);
-        index += 2;
-        require(bpa.header.versionMajor == 3, "invalid version major, expected 3");
-
-        bpa.header.versionMinor = encoded.toUint16(index);
-        index += 2;
-        require(bpa.header.versionMinor >= 0, "invalid version minor, expected 0 or more");
-
-        bpa.header.hdrSize = encoded.toUint16(index);
-        index += 2;
-
-        // NOTE(2022-04-19): Currently, only payloadId comes after
-        // hdrSize. Future extra header fields must be read using a
-        // separate offset to respect hdrSize, i.e.:
-        //
-        // uint hdrIndex = 0;
-        // bpa.header.payloadId = encoded.toUint8(index + hdrIndex);
-        // hdrIndex += 1;
-        //
-        // bpa.header.someNewField = encoded.toUint32(index + hdrIndex);
-        // hdrIndex += 4;
-        //
-        // // Skip remaining unknown header bytes
-        // index += bpa.header.hdrSize;
-
-        bpa.header.payloadId = encoded.toUint8(index);
-
-        // Skip remaining unknown header bytes
-        index += bpa.header.hdrSize;
-
-        // Payload ID of 2 required for batch headerBa
-        require(bpa.header.payloadId == 2, "invalid payload ID, expected 2 for BatchPriceAttestation");
+        {
+            uint32 magic = encoded.toUint32(index);
+            index += 4;
+            require(magic == 0x50325748, "invalid magic value");
+
+            uint16 versionMajor = encoded.toUint16(index);
+            index += 2;
+            require(versionMajor == 3, "invalid version major, expected 3");
+
+            uint16 versionMinor = encoded.toUint16(index);
+            index += 2;
+            require(versionMinor >= 0, "invalid version minor, expected 0 or more");
+
+            uint16 hdrSize = encoded.toUint16(index);
+            index += 2;
+
+            // NOTE(2022-04-19): Currently, only payloadId comes after
+            // hdrSize. Future extra header fields must be read using a
+            // separate offset to respect hdrSize, i.e.:
+            //
+            // uint hdrIndex = 0;
+            // bpa.header.payloadId = encoded.toUint8(index + hdrIndex);
+            // hdrIndex += 1;
+            //
+            // bpa.header.someNewField = encoded.toUint32(index + hdrIndex);
+            // hdrIndex += 4;
+            //
+            // // Skip remaining unknown header bytes
+            // index += bpa.header.hdrSize;
+
+            uint8 payloadId = encoded.toUint8(index);
+
+            // Skip remaining unknown header bytes
+            index += hdrSize;
+
+            // Payload ID of 2 required for batch headerBa
+            require(payloadId == 2, "invalid payload ID, expected 2 for BatchPriceAttestation");
+        }
 
         // Parse the number of attestations
-        bpa.nAttestations = encoded.toUint16(index);
+        uint16 nAttestations = encoded.toUint16(index);
         index += 2;
 
         // Parse the attestation size
-        bpa.attestationSize = encoded.toUint16(index);
+        uint16 attestationSize = encoded.toUint16(index);
         index += 2;
-        require(encoded.length == (index + (bpa.attestationSize * bpa.nAttestations)), "invalid BatchPriceAttestation size");
-
-        bpa.attestations = new PythInternalStructs.PriceAttestation[](bpa.nAttestations);
+        require(encoded.length == (index + (attestationSize * nAttestations)), "invalid BatchPriceAttestation size");
 
+        PythInternalStructs.PriceInfo memory info;
+        bytes32 priceId;
+        uint freshPrices = 0;
+        
         // Deserialize each attestation
-        for (uint j=0; j < bpa.nAttestations; j++) {
+        for (uint j=0; j < nAttestations; j++) {
             // NOTE: We don't advance the global index immediately.
             // attestationIndex is an attestation-local offset used
             // for readability and easier debugging.
             uint attestationIndex = 0;
 
-            // Attestation
-            bpa.attestations[j].productId = encoded.toBytes32(index + attestationIndex);
+            // Unused bytes32 product id
             attestationIndex += 32;
 
-            bpa.attestations[j].priceId = encoded.toBytes32(index + attestationIndex);
+            priceId = encoded.toBytes32(index + attestationIndex);
             attestationIndex += 32;
 
-            bpa.attestations[j].price = int64(encoded.toUint64(index + attestationIndex));
+            info.price.price = int64(encoded.toUint64(index + attestationIndex));
             attestationIndex += 8;
 
-            bpa.attestations[j].conf = encoded.toUint64(index + attestationIndex);
+            info.price.conf = encoded.toUint64(index + attestationIndex);
             attestationIndex += 8;
 
-            bpa.attestations[j].expo = int32(encoded.toUint32(index + attestationIndex));
+            info.price.expo = int32(encoded.toUint32(index + attestationIndex));
+            info.emaPrice.expo = info.price.expo;
             attestationIndex += 4;
 
-            bpa.attestations[j].emaPrice = int64(encoded.toUint64(index + attestationIndex));
+            info.emaPrice.price = int64(encoded.toUint64(index + attestationIndex));
             attestationIndex += 8;
 
-            bpa.attestations[j].emaConf = encoded.toUint64(index + attestationIndex);
+            info.emaPrice.conf = encoded.toUint64(index + attestationIndex);
             attestationIndex += 8;
 
-            bpa.attestations[j].status = encoded.toUint8(index + attestationIndex);
-            attestationIndex += 1;
-
-            bpa.attestations[j].numPublishers = encoded.toUint32(index + attestationIndex);
-            attestationIndex += 4;
-
-            bpa.attestations[j].maxNumPublishers = encoded.toUint32(index + attestationIndex);
-            attestationIndex += 4;
+            {
+                // Status is an enum (encoded as uint8) with the following values:
+                // 0 = UNKNOWN: The price feed is not currently updating for an unknown reason.
+                // 1 = TRADING: The price feed is updating as expected.
+                // 2 = HALTED: The price feed is not currently updating because trading in the product has been halted.
+                // 3 = AUCTION: The price feed is not currently updating because an auction is setting the price.
+                uint8 status = encoded.toUint8(index + attestationIndex);
+                attestationIndex += 1;
+
+                // Unused uint32 numPublishers
+                attestationIndex += 4;
+
+                // Unused uint32 numPublishers
+                attestationIndex += 4;
+
+                // Unused uint64 attestationTime
+                attestationIndex += 8;
+
+                info.price.publishTime = encoded.toUint64(index + attestationIndex);
+                info.emaPrice.publishTime = info.price.publishTime;
+                attestationIndex += 8;
+
+                if (status == 1) { // status == TRADING
+                    attestationIndex += 24;
+                } else {
+                    // If status is not trading then the latest available price is
+                    // the previous price info that are passed here.
+
+                    // Previous publish time
+                    info.price.publishTime = encoded.toUint64(index + attestationIndex);
+                    attestationIndex += 8;
+
+                    // Previous price
+                    info.price.price = int64(encoded.toUint64(index + attestationIndex));
+                    attestationIndex += 8;
+
+                    // Previous confidence
+                    info.price.conf = encoded.toUint64(index + attestationIndex);
+                    attestationIndex += 8;
+
+                    // The EMA is last updated when the aggregate had trading status,
+                    // so, we use previous publish time here too.
+                    info.emaPrice.publishTime = info.price.publishTime;
+                }
+            }
 
-            bpa.attestations[j].attestationTime = encoded.toUint64(index + attestationIndex);
-            attestationIndex += 8;
+            require(attestationIndex <= attestationSize, "INTERNAL: Consumed more than `attestationSize` bytes");
 
-            bpa.attestations[j].publishTime = encoded.toUint64(index + attestationIndex);
-            attestationIndex += 8;
+            // Respect specified attestation size for forward-compat
+            index += attestationSize;
 
-            bpa.attestations[j].prevPublishTime = encoded.toUint64(index + attestationIndex);
-            attestationIndex += 8;
+            // Store the attestation
+            PythInternalStructs.PriceInfo memory latestPrice = latestPriceInfo(priceId);
 
-            bpa.attestations[j].prevPrice = int64(encoded.toUint64(index + attestationIndex));
-            attestationIndex += 8;
+            bool fresh = false;
+            if(info.price.publishTime > latestPrice.price.publishTime) {
+                freshPrices += 1;
+                fresh = true;
+                setLatestPriceInfo(priceId, info);
+            }
 
-            bpa.attestations[j].prevConf = encoded.toUint64(index + attestationIndex);
-            attestationIndex += 8;
+            emit PriceFeedUpdate(priceId, fresh, vm.emitterChainId, vm.sequence, latestPrice.price.publishTime,
+                info.price.publishTime, info.price.price, info.price.conf);
+        }
 
-            require(attestationIndex <= bpa.attestationSize, "INTERNAL: Consumed more than `attestationSize` bytes");
 
-            // Respect specified attestation size for forward-compat
-            index += bpa.attestationSize;
-        }
+        emit BatchPriceFeedUpdate(vm.emitterChainId, vm.sequence, nAttestations, freshPrices);
     }
 
     function queryPriceFeed(bytes32 id) public view override returns (PythStructs.PriceFeed memory priceFeed){

+ 0 - 47
ethereum/contracts/pyth/PythInternalStructs.sol

@@ -9,40 +9,6 @@ import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
 contract PythInternalStructs {
     using BytesLib for bytes;
 
-    struct BatchPriceAttestation {
-        Header header;
-
-        uint16 nAttestations;
-        uint16 attestationSize;
-        PriceAttestation[] attestations;
-    }
-
-    struct Header {
-        uint32 magic;
-        uint16 versionMajor;
-        uint16 versionMinor;
-        uint16 hdrSize;
-        uint8 payloadId;
-    }
-
-    struct PriceAttestation {
-        bytes32 productId;
-        bytes32 priceId;
-        int64 price;
-        uint64 conf;
-        int32 expo;
-        int64 emaPrice;
-        uint64 emaConf;
-        uint8 status;
-        uint32 numPublishers;
-        uint32 maxNumPublishers;
-        uint64 attestationTime;
-        uint64 publishTime;
-        uint64 prevPublishTime;
-        int64 prevPrice;
-        uint64 prevConf;
-    }
-
     struct InternalPrice {
         int64 price;
         uint64 conf;
@@ -62,17 +28,4 @@ contract PythInternalStructs {
         uint16 chainId;
         bytes32 emitterAddress;
     }
-
-    /* PriceAttestationStatus represents the availability status of a price feed passed down in attestation.
-        UNKNOWN: The price feed is not currently updating for an unknown reason.
-        TRADING: The price feed is updating as expected.
-        HALTED: The price feed is not currently updating because trading in the product has been halted.
-        AUCTION: The price feed is not currently updating because an auction is setting the price.
-    */
-    enum PriceAttestationStatus {
-        UNKNOWN,
-        TRADING,
-        HALTED,
-        AUCTION
-    }
 }

+ 1 - 1
ethereum/forge-test/utils/PythTestUtils.t.sol

@@ -80,7 +80,7 @@ abstract contract PythTestUtils is Test, WormholeTestUtils {
             // Breaking this in two encodePackes because of the limited EVM stack.
             attestations = abi.encodePacked(
                 attestations,
-                uint8(PythInternalStructs.PriceAttestationStatus.TRADING),
+                uint8(1), // status = 1 = Trading
                 uint32(5), // Number of publishers. This field is not used.
                 uint32(10), // Maximum number of publishers. This field is not used.
                 uint64(prices[i].publishTime), // Attestation time. This field is not used.

+ 31 - 49
ethereum/test/pyth.js

@@ -248,14 +248,17 @@ contract("Pyth", function () {
     function generateRawBatchAttestation(
         publishTime,
         attestationTime,
-        priceVal
+        priceVal,
+        emaPriceVal,
     ) {
         const pubTs = u64ToHex(publishTime);
         const attTs = u64ToHex(attestationTime);
         const price = u64ToHex(priceVal);
+        const emaPrice = u64ToHex(emaPriceVal || priceVal);
         const replaced = RAW_BATCH.replace(RAW_BATCH_PUBLISH_TIME_REGEX, pubTs)
             .replace(RAW_BATCH_ATTESTATION_TIME_REGEX, attTs)
-            .replace(RAW_BATCH_PRICE_REGEX, price);
+            .replace(RAW_BATCH_PRICE_REGEX, price)
+            .replace(RAW_BATCH_EMA_PRICE_REGEX, emaPrice);
         return replaced;
     }
 
@@ -283,61 +286,40 @@ contract("Pyth", function () {
         return replaced;
     }
 
-    it("should parse batch price attestation correctly", async function () {
-        const magic = 0x50325748;
-        const versionMajor = 3;
-        const versionMinor = 0;
-
+    it.only("should parse batch price attestation correctly", async function () {
         let attestationTime = 1647273460; // re-used for publishTime
         let publishTime = 1647273465; // re-used for publishTime
         let priceVal = 1337;
+        let emaPriceVal = 2022;
         let rawBatch = generateRawBatchAttestation(
             publishTime,
             attestationTime,
-            priceVal
+            priceVal,
+            emaPriceVal
         );
-        let parsed = await this.pythProxy.parseBatchPriceAttestation(rawBatch);
-
-        // Check the header
-        assert.equal(parsed.header.magic, magic);
-        assert.equal(parsed.header.versionMajor, versionMajor);
-        assert.equal(parsed.header.versionMinor, versionMinor);
-        assert.equal(parsed.header.payloadId, 2);
-
-        assert.equal(parsed.nAttestations, RAW_BATCH_ATTESTATION_COUNT);
-        assert.equal(parsed.attestationSize, RAW_PRICE_ATTESTATION_SIZE);
-
-        assert.equal(parsed.attestations.length, parsed.nAttestations);
-
-        for (var i = 0; i < parsed.attestations.length; ++i) {
-            const prodId =
-                "0x" + (i + 1).toString(16).padStart(2, "0").repeat(32);
-            const priceByte = 255 - ((i + 1) % 256);
-            const priceId =
-                "0x" + priceByte.toString(16).padStart(2, "0").repeat(32);
-
-            assert.equal(parsed.attestations[i].productId, prodId);
-            assert.equal(parsed.attestations[i].priceId, priceId);
-            assert.equal(parsed.attestations[i].price, priceVal);
-            assert.equal(parsed.attestations[i].conf, 101);
-            assert.equal(parsed.attestations[i].expo, -3);
-            assert.equal(parsed.attestations[i].emaPrice, -42);
-            assert.equal(parsed.attestations[i].emaConf, 42);
-            assert.equal(parsed.attestations[i].status, 1);
-            assert.equal(parsed.attestations[i].numPublishers, 123212);
-            assert.equal(parsed.attestations[i].maxNumPublishers, 321232);
-            assert.equal(
-                parsed.attestations[i].attestationTime,
-                attestationTime
-            );
-            assert.equal(parsed.attestations[i].publishTime, publishTime);
-            assert.equal(parsed.attestations[i].prevPublishTime, 0xdeadbabe);
-            assert.equal(parsed.attestations[i].prevPrice, 0xdeadfacebeef);
-            assert.equal(parsed.attestations[i].prevConf, 0xbadbadbeef);
 
-            console.debug(
-                `attestation ${i + 1}/${parsed.attestations.length} parsed OK`
-            );
+        const receipt = await updatePriceFeeds(this.pythProxy, [rawBatch]);
+
+        expectEvent(receipt, 'PriceFeedUpdate', {
+            price: "1337",
+        });
+
+        for (var i = 1; i <= RAW_BATCH_ATTESTATION_COUNT; i++) {
+            const price_id =
+                "0x" +
+                (255 - (i % 256)).toString(16).padStart(2, "0").repeat(32);
+
+            const price = await this.pythProxy.getPriceUnsafe(price_id);
+            assert.equal(price.price, priceVal.toString());
+            assert.equal(price.conf, "101"); // The value is hardcoded in the RAW_BATCH.
+            assert.equal(price.publishTime, publishTime.toString());
+            assert.equal(price.expo, "-3"); // The value is hardcoded in the RAW_BATCH.
+
+            const emaPrice = await this.pythProxy.getEmaPriceUnsafe(price_id);
+            assert.equal(emaPrice.price, emaPriceVal.toString());
+            assert.equal(emaPrice.conf, "42"); // The value is hardcoded in the RAW_BATCH.
+            assert.equal(emaPrice.publishTime, publishTime.toString());
+            assert.equal(emaPrice.expo, "-3"); // The value is hardcoded in the RAW_BATCH.
         }
     });