diff --git a/erpnext/accounts/doctype/subscriptions/subscriptions.py b/erpnext/accounts/doctype/subscriptions/subscriptions.py index 6cd8934fb4..0cccaebf1b 100644 --- a/erpnext/accounts/doctype/subscriptions/subscriptions.py +++ b/erpnext/accounts/doctype/subscriptions/subscriptions.py @@ -5,7 +5,7 @@ from __future__ import unicode_literals import frappe from frappe.model.document import Document -from frappe.utils.data import now, nowdate, getdate, cint, add_days, date_diff, get_last_day, get_first_day +from frappe.utils.data import now, nowdate, getdate, cint, add_days, date_diff, get_last_day, get_first_day, add_to_date from frappe import _ @@ -25,17 +25,58 @@ class Subscriptions(Document): if self.trial_period_start and self.is_trialling(): self.current_invoice_start = self.trial_period_start elif not date: - current_invoice = self.get_current_invoice() - if not current_invoice: - self.current_invoice_start = nowdate() - else: - self.current_invoice_start = current_invoice.posting_date + self.current_invoice_start = nowdate() def set_current_invoice_end(self): if self.is_trialling(): self.current_invoice_end = self.trial_period_end else: - self.current_invoice_end = get_last_day(self.current_invoice_start) + billing_cycle_info = self.get_billing_cycle() + if billing_cycle_info: + self.current_invoice_end = add_to_date(self.current_invoice_start, **billing_cycle_info) + else: + self.current_invoice_end = get_last_day(self.current_invoice_start) + + def get_billing_cycle(self): + return self.get_billing_cycle_data() + + def validate_plans_billing_cycle(self, billing_cycle_data): + if billing_cycle_data and len(billing_cycle_data) != 1: + frappe.throw(_('You can only have Plans with the same billing cycle in a Subscription')) + + def get_billing_cycle_and_interval(self): + plan_names = [plan.plan for plan in self.plans] + billing_info = frappe.db.sql( + 'select distinct `billing_interval`, `billing_interval_count` ' + 'from `tabSubscription Plan` ' + 'where name in %s', + (plan_names,), as_dict=1 + ) + + return billing_info + + def get_billing_cycle_data(self): + billing_info = self.get_billing_cycle_and_interval() + + self.validate_plans_billing_cycle(billing_info) + + if billing_info: + data = dict() + interval = billing_info[0]['billing_interval'] + interval_count = billing_info[0]['billing_interval_count'] + if interval not in ['Day', 'Week']: + data['days'] = -1 + if interval == 'Day': + data['days'] = interval_count - 1 + elif interval == 'Month': + data['months'] = interval_count + elif interval == 'Year': + data['years'] == interval_count + # todo: test week + elif interval == 'Week': + data['days'] = interval_count * 7 - 1 + + return data def before_save(self): self.set_status() @@ -89,6 +130,7 @@ class Subscriptions(Document): def validate(self): self.validate_trial_period() + self.validate_plans_billing_cycle(self.get_billing_cycle_and_interval()) def validate_trial_period(self): if self.trial_period_start and self.trial_period_end: