Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,13 +442,25 @@ async def get_plan(

offers = []
if effective_spec.configuration.ssh_config is None:
offers_with_backends = await get_create_instance_offers(
project=project,
profile=effective_spec.merged_profile,
requirements=get_fleet_requirements(effective_spec),
fleet_spec=effective_spec,
blocks=effective_spec.configuration.blocks,
)
requirements = get_fleet_requirements(effective_spec)
if _is_elastic_cloud_fleet_spec(effective_spec):
offers_with_backends = await offers_services.get_offers_by_requirements(
project=project,
profile=effective_spec.merged_profile,
requirements=requirements,
multinode=(
effective_spec.configuration.placement == InstanceGroupPlacement.CLUSTER
),
blocks=effective_spec.configuration.blocks,
)
else:
offers_with_backends = await get_create_instance_offers(
project=project,
profile=effective_spec.merged_profile,
requirements=requirements,
fleet_spec=effective_spec,
blocks=effective_spec.configuration.blocks,
)
Comment on lines +446 to +463
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having two separate branches for `get_offers_by_requirements` and `get_create_instance_offers` seems too verbose here. I think they can be replaced with a single `get_create_instance_offers` call that takes a parameter to enable or disable the `BACKENDS_WITH_CREATE_INSTANCE_SUPPORT` filtering.

offers = [offer for _, offer in offers_with_backends]

_remove_fleet_spec_sensitive_info(effective_spec)
Expand All @@ -468,6 +480,16 @@ async def get_plan(
return plan


def _is_elastic_cloud_fleet_spec(fleet_spec: FleetSpec) -> bool:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function name is a misnomer: it does not actually check that the fleet is elastic (i.e. that `nodes` is a range), only that `nodes.target == 0`. And since `target` cannot be less than `min`, checking `nodes.target == 0` is all we need.

# Node-count settings taken from the fleet configuration.
nodes = fleet_spec.configuration.nodes
# True only when every condition holds: no ssh_config is set on the
# configuration, a nodes spec is present, and both its `min` and `target`
# are 0.
return (
fleet_spec.configuration.ssh_config is None
and nodes is not None
and nodes.min == 0
and nodes.target == 0
)


async def get_create_instance_offers(
project: ProjectModel,
profile: Profile,
Expand Down
73 changes: 73 additions & 0 deletions src/tests/_internal/server/routers/test_fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dstack._internal.core.models.common import EntityReference
from dstack._internal.core.models.fleets import (
FleetConfiguration,
FleetNodesSpec,
FleetStatus,
InstanceGroupPlacement,
SSHHostParams,
Expand Down Expand Up @@ -2028,6 +2029,78 @@ async def test_returns_create_plan_for_new_fleet(
"action": "create",
}

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_offers_for_elastic_container_backend_fleet(
    self, test_db, session: AsyncSession, client: AsyncClient
):
    """Check that `get_plan` returns the backend's offer for a min=0/target=0 fleet.

    A single mocked RunPod backend exposes one offer; the plan response is
    expected to list exactly that offer, with matching totals and max price.
    """
    # Project with a single regular member who issues the request.
    user = await create_user(session=session, global_role=GlobalRole.USER)
    project = await create_project(session=session, owner=user)
    await add_project_member(
        session=session, project=project, user=user, project_role=ProjectRole.USER
    )

    runpod_offer = get_instance_offer_with_availability(
        backend=BackendType.RUNPOD,
        region="US-OR-1",
        price=0.7185,
    )
    # Fleet that starts with zero nodes and may grow to one.
    zero_target_nodes = FleetNodesSpec(min=0, target=0, max=1)
    fleet_spec = get_fleet_spec(conf=get_fleet_configuration(nodes=zero_target_nodes))
    plan_url = f"/api/project/{project.name}/fleets/get_plan"

    with patch(
        "dstack._internal.server.services.backends.get_project_backends"
    ) as get_backends:
        runpod_backend = Mock()
        get_backends.return_value = [runpod_backend]
        runpod_backend.TYPE = BackendType.RUNPOD
        get_offers = runpod_backend.compute.return_value.get_offers
        get_offers.return_value = [runpod_offer]
        response = await client.post(
            plan_url,
            headers=get_auth_headers(user.token),
            json={"spec": fleet_spec.dict()},
        )
        # The backend must have been queried for offers exactly once.
        get_offers.assert_called_once()

    payload = response.json()
    assert response.status_code == 200, payload
    assert payload["offers"] == [json.loads(runpod_offer.json())]
    assert payload["total_offers"] == 1
    assert payload["max_offer_price"] == runpod_offer.price

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_no_offers_for_non_elastic_container_backend_fleet(
    self, test_db, session: AsyncSession, client: AsyncClient
):
    """Check that `get_plan` returns no offers for a min=0/target=1 fleet.

    The mocked RunPod backend is still asked for offers once, but the plan
    response is expected to contain an empty offer list, a zero total and no
    max price.
    """
    # Project with a single regular member who issues the request.
    user = await create_user(session=session, global_role=GlobalRole.USER)
    project = await create_project(session=session, owner=user)
    await add_project_member(
        session=session, project=project, user=user, project_role=ProjectRole.USER
    )

    runpod_offer = get_instance_offer_with_availability(
        backend=BackendType.RUNPOD,
        region="US-OR-1",
        price=0.7185,
    )
    # Fleet that wants one node right away (target=1), unlike the zero-target case.
    one_target_nodes = FleetNodesSpec(min=0, target=1, max=1)
    fleet_spec = get_fleet_spec(conf=get_fleet_configuration(nodes=one_target_nodes))
    plan_url = f"/api/project/{project.name}/fleets/get_plan"

    with patch(
        "dstack._internal.server.services.backends.get_project_backends"
    ) as get_backends:
        runpod_backend = Mock()
        get_backends.return_value = [runpod_backend]
        runpod_backend.TYPE = BackendType.RUNPOD
        get_offers = runpod_backend.compute.return_value.get_offers
        get_offers.return_value = [runpod_offer]
        response = await client.post(
            plan_url,
            headers=get_auth_headers(user.token),
            json={"spec": fleet_spec.dict()},
        )
        # Offers are still fetched from the backend exactly once.
        get_offers.assert_called_once()

    payload = response.json()
    assert response.status_code == 200, payload
    assert payload["offers"] == []
    assert payload["total_offers"] == 0
    assert payload["max_offer_price"] is None

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_update_plan_for_existing_fleet(
Expand Down
Loading