#pragma kernel StepAndSolve
#pragma multi_compile __ HAS_COLLIDERS
#pragma multi_compile __ COLLIDE_WITH_SOURCE_MESH
#pragma multi_compile __ USE_FORWARD_COLLISION
#pragma multi_compile __ INIT_VERLET

#pragma warning( disable : 4714 )
#pragma warning( disable : 4000 )

#include "VerletSimulationInclude.hlsl"
#include "Thread.hlsl"

#define DISCARD_NODE if((int)id.x >= verletNodesCount) return;

// Per-node rest data, indexed with the same index as verletNodes.
// The element type is VerletRestNode: StepAndSolve reads elements into
// `const VerletRestNode` locals.
RWStructuredBuffer<VerletRestNode> _RestPositions;

/*
 * Dispatch layout
 * ---------------
 * A strand must never be manipulated from two different thread groups, which
 * is why the dispatch is sized with the nearest power of two:
 * we dispatch (strandCount * nearestPow) threads.
 *
 * The strand index and the local node index are first derived from that
 * power-of-two layout. They are then combined with the real strandPointsCount
 * to obtain the actual node index (idx).
 *
 * Note: as a result a lot of threads are unused:
 *     unused count = (nearestPow - strandPointsCount) * strandCount
 *
 * In testing this was still faster than dispatching Solve solverIterations
 * times from C# code. With this approach only one dispatch is needed, with
 * the iteration loop living inside the compute shader.
 */

/*
 * Derives a tangent for curNode from its neighbours and forwards it, together
 * with the source mesh normal, to writeNormal for both fur mesh vertices.
 *
 * isNextNodeFixed   - when true the tangent points from nPrev to curNode,
 *                     otherwise from curNode to nNext.
 * nPrev / nNext     - neighbouring nodes; only the one selected by
 *                     isNextNodeFixed is read.
 * curNode           - node the tangent is computed for.
 * furMeshLeftIndex  - passed through to writeNormal unchanged.
 * furMeshRightIndex - passed through to writeNormal unchanged.
 * sourceMeshNormal  - passed through to writeNormal unchanged.
 */
void calculateAndWriteNormal(bool isNextNodeFixed, VerletNode nPrev, VerletNode nNext, VerletNode curNode,
                             const uint furMeshLeftIndex, const uint furMeshRightIndex,
                             const float3 sourceMeshNormal)
{
    float3 tangent;
    if (isNextNodeFixed)
    {
        // The next node cannot be used, so look backwards along the strand.
        tangent = normalize(curNode.position - nPrev.position);
    }
    else
    {
        tangent = normalize(nNext.position - curNode.position);
    }

    // float3 right = normalize(cross( tangent, normalize(curNode.position - worldSpaceCameraPos)));
    writeNormal(furMeshLeftIndex, furMeshRightIndex, sourceMeshNormal, tangent);
}

/*
 * Kernel entry point: runs one simulation step for a single verlet node and
 * writes the resulting position and normal data for its two fur mesh vertices.
 *
 * id - dispatch thread id; converted to node / strand indices by
 *      calculateIndicesAndDiscard, which also decides whether the thread
 *      has any work to do.
 */
THREAD_SIZE
void StepAndSolve(uint3 id : SV_DispatchThreadID)
{
    uint index, localPointIndex, hairStrandIndex;
    // A true return value means this thread has nothing to process.
    // Presumably the three indices are out parameters filled by the helper - confirm in the include.
    if (calculateIndicesAndDiscard(id, index, localPointIndex, hairStrandIndex))
        return;

    const VerletRestNode resPosition = _RestPositions[index];
    VerletNode currentNode = verletNodes[index];
    const bool isFixedNode = isFixed(currentNode);

    velocityGravityAndCollision(
        resPosition.position,
        isFixedNode,
        resPosition.vertMoveAbility,
        resPosition.pinnedAmount,
        index
    );

    // Each node owns two consecutive fur mesh vertices (left and right).
    const uint furMeshLeftIndex = index * 2 * furMeshBufferStride;
    const uint furMeshRightIndex = (index * 2 + 1) * furMeshBufferStride;

    // Blend the constraint strength between the root and tip settings.
    const float shapeConstraintIntensity = lerp(shapeConstraintRoot, shapeConstraintTip, resPosition.vertMoveAbility);

    // Append the offset between the verlet node and the rest position vertex.
    //
    // NOTE(review): the barriers below sit inside a branch that depends on
    // localPointIndex, and threads may already have returned above. A
    // *WithGroupSync barrier waits for all threads of the group to reach it,
    // so this relies on compiler/driver leniency - confirm it is intended.
    if (localPointIndex > numberOfFixedNodesInStrand)
    {
        const VerletRestNode nextRestPosition = _RestPositions[index + 1];
        const VerletRestNode prevRestPosition = _RestPositions[index - 1];

        const VerletNode nextNode = verletNodes[index + 1];
        const bool isNextNodeFixed = isFixed(nextNode);

        float3 tangent = calculateTangent(resPosition.position, prevRestPosition.position,
                                          nextRestPosition.position, isNextNodeFixed);
        verletNodes[index].tangent = tangent;

        // Neighbouring threads read the tangent written above. A group-memory
        // barrier only orders groupshared accesses; the all-memory variant
        // also orders device-memory (UAV) accesses, so the write is visible
        // wherever verletNodes is declared.
        AllMemoryBarrierWithGroupSync();

#if defined(INIT_VERLET)
        const float3 myRotVec = normalize(resPosition.position - prevRestPosition.position);
        const float3 myRotVec2 = normalize(nextRestPosition.position - resPosition.position);

        // Reads the previous node's tangent, hence the barrier above.
        verletNodes[index].rotationDiffFromPrevNode = verletNodes[index - 1].tangent - myRotVec;
        verletNodes[index].rotationDiffFromPrevNode2 = tangent - myRotVec2;

        // Same reasoning as above: make these writes visible to the rest of
        // the group before the constraint pass runs.
        AllMemoryBarrierWithGroupSync();
#endif

        constraints(
            index,
            resPosition.position,
            nextRestPosition.position,
            _RestPositions[index - 1].position,
            isNextNodeFixed,
            shapeConstraintIntensity
        );

        // Re-read the node after the constraint call rather than reusing the
        // copy taken at the top of the kernel.
        VerletNode curNode = verletNodes[index];

        const float3 restPosLeft = loadSourceVertex(furMeshLeftIndex);
        const float3 restPosRight = loadSourceVertex(furMeshRightIndex);

        // Both vertices are moved by the offset measured on the left vertex.
        const float3 pLeftDiff = curNode.position - restPosLeft;
        const float3 leftTargetPosition = restPosLeft + pLeftDiff;
        const float3 rightTargetPosition = restPosRight + pLeftDiff;

        writeVector3(furMeshLeftIndex, 0, leftTargetPosition);
        writeVector3(furMeshRightIndex, 0, rightTargetPosition);

        calculateAndWriteNormal(
            isNextNodeFixed,
            verletNodes[index - 1],
            nextNode,
            curNode,
            furMeshLeftIndex,
            furMeshRightIndex,
            resPosition.worldSourceMeshNormal
        );

        // Pull the node back towards the left rest vertex, scaled by
        // pinnedAmount and the frame time.
        curNode.position = lerp(leftTargetPosition, restPosLeft, resPosition.pinnedAmount * 2 * deltaTime);
        verletNodes[index] = curNode;
    }
    else
    {
        // Nodes at or below numberOfFixedNodesInStrand: only the normal is
        // written (tangent always points to the next node), and the node copy
        // taken at the top of the kernel is stored back.
        calculateAndWriteNormal(
            false,
            currentNode,
            verletNodes[index + 1],
            currentNode,
            furMeshLeftIndex,
            furMeshRightIndex,
            resPosition.worldSourceMeshNormal
        );
        verletNodes[index] = currentNode;
    }
}