You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

136 lines
5.0 KiB
Plaintext

// Compute kernel entry point declared for this shader.
#pragma kernel StepAndSolve
// Keyword variants (each compiled with the keyword off and on). Of these, only INIT_VERLET is
// tested in the code visible in this file; the others are presumably consumed by the includes - TODO confirm.
#pragma multi_compile __ HAS_COLLIDERS
#pragma multi_compile __ COLLIDE_WITH_SOURCE_MESH
#pragma multi_compile __ USE_FORWARD_COLLISION
#pragma multi_compile __ INIT_VERLET
// Silence compiler warnings 4714 and 4000 for this file.
#pragma warning( disable : 4714 )
#pragma warning( disable : 4000 )
#include "VerletSimulationInclude.hlsl"
#include "Thread.hlsl"
// Early-out for threads whose id.x is at or past verletNodesCount. Expects a thread-ID variable
// named `id` in the enclosing scope; verletNodesCount is not declared in this file (presumably in an include).
#define DISCARD_NODE if((int)id.x >= verletNodesCount) return;
// Per-node rest data, indexed with the same node index as verletNodes in StepAndSolve.
// The kernel reads position, vertMoveAbility, pinnedAmount and worldSourceMeshNormal from it.
RWStructuredBuffer<VerletRestNode> _RestPositions;
/*
* We need to make sure that a strand is never manipulated from two different thread groups.
* Hence the nearestPow: we dispatch (strandCount * nearestPow) threads.
* We then calculate the strand index and localNodeIndex based on the nearest power-of-two data.
* That strandIndex and localNodeIndex are then used, multiplied with the actual strandPointsCount, to get the actual index (idx).
*
* Note: As a result, a lot of the dispatched thread invocations go unused. Unused count = (nearestPow - strandPointsCount) * strandCount.
*
* But from my tests, this is still faster than dispatching Solve solverIterations times from C# code. With this approach we only need
* one dispatch with a loop in the compute shader.
*/
// Derives the strand direction at curNode and hands it to writeNormal as the tangent for the
// node's two fur mesh vertices.
//
// isNextNodeFixed   - when true the direction is taken from nPrev to curNode,
//                     otherwise from curNode to nNext.
// nPrev / nNext     - neighbouring nodes; only their positions are read.
// curNode           - the node whose normal is being written; only its position is read.
// furMeshLeftIndex / furMeshRightIndex - fur mesh buffer offsets forwarded to writeNormal.
// sourceMeshNormal  - forwarded unchanged to writeNormal.
//
// A camera-facing "right" vector used to be derived here; that computation is disabled.
void calculateAndWriteNormal(bool isNextNodeFixed, VerletNode nPrev, VerletNode nNext, VerletNode curNode, const uint furMeshLeftIndex, const uint furMeshRightIndex,
                             const float3 sourceMeshNormal)
{
    // Default segment: current node -> next node.
    float3 segmentStart = curNode.position;
    float3 segmentEnd = nNext.position;

    if (isNextNodeFixed)
    {
        // The next node is fixed, so use the previous node -> current node segment instead.
        segmentStart = nPrev.position;
        segmentEnd = curNode.position;
    }

    const float3 strandDirection = normalize(segmentEnd - segmentStart);
    writeNormal(furMeshLeftIndex, furMeshRightIndex, sourceMeshNormal, strandDirection);
}
// Compute kernel: handles one verlet node per thread.
// For every surviving thread it runs velocityGravityAndCollision on the node. Nodes whose local
// index is greater than numberOfFixedNodesInStrand additionally get their tangent computed, go
// through constraints(), and have their position and normal written to the fur mesh buffer.
// All other nodes only get a normal written and are stored back unchanged.
//
// id - dispatch thread ID; calculateIndicesAndDiscard turns it into index / localPointIndex /
//      hairStrandIndex and returns true for threads that must not do any work.
//
// THREAD_SIZE is not defined in this file; presumably it expands to the [numthreads(...)]
// attribute (see Thread.hlsl) - TODO confirm.
THREAD_SIZE
void StepAndSolve(uint3 id : SV_DispatchThreadID)
{
uint index, localPointIndex, hairStrandIndex;
// Threads flagged by the helper leave here, before any of the group syncs further down.
if (calculateIndicesAndDiscard(id, index, localPointIndex, hairStrandIndex)) return;
const VerletRestNode resPosition = _RestPositions[index];
// Snapshot of the node taken BEFORE velocityGravityAndCollision runs.
VerletNode currentNode = verletNodes[index];
const bool isFixedNode = isFixed(currentNode);
velocityGravityAndCollision(
resPosition.position,
isFixedNode,
resPosition.vertMoveAbility,
resPosition.pinnedAmount,
index
);
// Each node maps to two consecutive fur mesh vertices (left = 2 * index, right = 2 * index + 1),
// each scaled by furMeshBufferStride to get the buffer offset.
const uint furMeshLeftIndex = index * 2 * furMeshBufferStride;
const uint furMeshRightIndex = (index * 2 + 1) * furMeshBufferStride;
// Blend the shape constraint strength from the root value to the tip value by vertMoveAbility.
const float shapeConstraintIntensity = lerp(shapeConstraintRoot,shapeConstraintTip, resPosition.vertMoveAbility);
//Append the offset between the verlet node and the rest position vertex.
// Strictly greater: the node with localPointIndex == numberOfFixedNodesInStrand takes the else branch.
if (localPointIndex > numberOfFixedNodesInStrand)
{
// NOTE(review): for the last node of a strand, index + 1 addresses whatever follows in the buffers
// (next strand or past the end) - presumably handled through isNextNodeFixed; confirm.
const VerletRestNode nextRestPosition = _RestPositions[index + 1];
const VerletRestNode prevRestPosition = _RestPositions[index - 1];
const VerletNode nextNode = verletNodes[index + 1];
const bool isNextNodeFixed = isFixed(nextNode);
float3 tangent = calculateTangent(resPosition.position, prevRestPosition.position, nextRestPosition.position, isNextNodeFixed);
verletNodes[index].tangent = tangent;
// NOTE(review): this sync sits inside a thread-dependent branch and after an early return, so
// not every thread of the group reaches it - confirm this is valid on the target platforms.
// NOTE(review): verletNodes is declared elsewhere; if it is a device buffer rather than
// groupshared memory, a device-memory barrier may be required for the tangent write above to be
// visible to neighbouring threads - TODO confirm.
GroupMemoryBarrierWithGroupSync();
#if defined(INIT_VERLET)
// Initialisation variant: store how the computed tangents differ from the normalized
// rest-position directions (previous->current and current->next).
const float3 myRotVec = normalize(resPosition.position - prevRestPosition.position);
const float3 myRotVec2 = normalize(nextRestPosition.position-resPosition.position );
// Reads the tangent written by the thread handling the previous node.
verletNodes[index].rotationDiffFromPrevNode = verletNodes[index-1].tangent - myRotVec;
verletNodes[index].rotationDiffFromPrevNode2 = tangent - myRotVec2;
GroupMemoryBarrierWithGroupSync();
#endif
// The previous rest position passed below is the same buffer element as prevRestPosition,
// fetched a second time.
constraints(
index,
resPosition.position,
nextRestPosition.position,
_RestPositions[index - 1].position,
isNextNodeFixed,
shapeConstraintIntensity
);
// Re-read the node after constraints() has run, since it is handed the index and the earlier snapshot may be stale.
VerletNode curNode = verletNodes[index];
const float3 restPosLeft = loadSourceVertex(furMeshLeftIndex);
const float3 restPosRight = loadSourceVertex(furMeshRightIndex);
// Offset of the simulated node from the left source vertex; the same offset is applied to both
// vertices, so the left target is algebraically curNode.position.
const float3 pLeftDiff = curNode.position - restPosLeft;
const float3 leftTargetPosition = restPosLeft + pLeftDiff;
const float3 rightTargetPosition = restPosRight + pLeftDiff;
writeVector3(furMeshLeftIndex, 0, leftTargetPosition);
writeVector3(furMeshRightIndex, 0, rightTargetPosition);
// NOTE(review): verletNodes[index - 1] may be modified concurrently by the thread that owns
// that node - confirm the ordering is acceptable.
calculateAndWriteNormal(
isNextNodeFixed,
verletNodes[index - 1],
nextNode,
curNode,
furMeshLeftIndex,
furMeshRightIndex,
resPosition.worldSourceMeshNormal
);
// Pull the node back toward the left source vertex by pinnedAmount * 2 * deltaTime (not clamped).
curNode.position = lerp(leftTargetPosition, restPosLeft, resPosition.pinnedAmount * 2 * deltaTime);
verletNodes[index] = curNode;
}
else
{
// Nodes in the fixed section: the current node is passed as both "previous" and "current", and
// with isNextNodeFixed = false the tangent runs from this node to the next one.
calculateAndWriteNormal(
false,
currentNode,
verletNodes[index + 1],
currentNode,
furMeshLeftIndex,
furMeshRightIndex,
resPosition.worldSourceMeshNormal
);
// NOTE(review): this stores the snapshot taken before velocityGravityAndCollision; if that call
// modified verletNodes[index], the change is overwritten here - confirm this is intended.
verletNodes[index] = currentNode;
}
}