get crypt_id by most_common

This commit is contained in:
toto 2021-10-29 15:09:29 +02:00
parent 9132520d01
commit ead909e2ca

View File

@ -8,6 +8,7 @@ import tarfile
import sqlite3 import sqlite3
import pathlib import pathlib
import hashlib import hashlib
import collections as coll
from b2sdk.v2 import B2Api from b2sdk.v2 import B2Api
from crypt import encrypt_file, decrypt_file from crypt import encrypt_file, decrypt_file
@ -287,11 +288,10 @@ class DataBase:
crypt_id_list.append(cursor.fetchone()['crypt_id']) crypt_id_list.append(cursor.fetchone()['crypt_id'])
except TypeError: except TypeError:
pass pass
try: if len(crypt_id_list) != 0:
if len(list(set(crypt_id_list))) == 1: id = most_common(crypt_id_list)
return crypt_id_list[0] else:
id = most_frequent(crypt_id_list) # if not already/find in bdd
except ValueError:
cursor.execute("""SELECT IFNULL(max(id) + 1, 0) as crypt_id FROM crypt""") cursor.execute("""SELECT IFNULL(max(id) + 1, 0) as crypt_id FROM crypt""")
return cursor.fetchone()['crypt_id'] return cursor.fetchone()['crypt_id']
params = {'id': id, params = {'id': id,
@ -302,10 +302,13 @@ class DataBase:
AND name NOT IN ({name}) AND name NOT IN ({name})
AND path NOT IN ({path})""".format(**params)) AND path NOT IN ({path})""".format(**params))
neighbour = cursor.fetchall() neighbour = cursor.fetchall()
# if they have a neighbour don't overwrite it
if len(neighbour) > 0: if len(neighbour) > 0:
cursor.execute("""SELECT IFNULL(max(id) + 1, 0) as crypt_id FROM crypt""") cursor.execute("""SELECT IFNULL(max(id) + 1, 0) as crypt_id FROM crypt""")
return cursor.fetchone()['crypt_id'] return cursor.fetchone()['crypt_id']
else: else:
# if they are different, define the same id for the files of this archive
if len(set(crypt_id_list)) > 1:
cursor.execute("""UPDATE files SET crypt_id={id} cursor.execute("""UPDATE files SET crypt_id={id}
WHERE name IN ({name}) WHERE name IN ({name})
AND path IN ({path})""".format(**params)) AND path IN ({path})""".format(**params))
@ -389,6 +392,18 @@ def most_frequent(list):
return max(set(list), key=list.count) return max(set(list), key=list.count)
def most_common(lst):
    """Return the most frequent item of *lst*.

    When several items share the highest frequency, the smallest of them
    is returned so the result does not depend on the order of *lst*.

    Args:
        lst: A non-empty sequence of hashable, mutually comparable items.

    Returns:
        The item with the highest occurrence count (smallest on ties).

    Raises:
        ValueError: If *lst* is empty.
    """
    if not lst:
        raise ValueError("most_common() arg is an empty sequence")
    counts = coll.Counter(lst)
    # Rank by descending count first, then ascending item. Frequency always
    # wins; the item value is only a tie-breaker, and no magic upper bound
    # on the item value is needed.
    best_item, _ = min(counts.items(), key=lambda pair: (-pair[1], pair[0]))
    return best_item
def dict_factory(cursor, row): def dict_factory(cursor, row):
d = {} d = {}
for idx, col in enumerate(cursor.description): for idx, col in enumerate(cursor.description):